// Template for COT 5937 Programming Assignment #1
// Provided: a utility function that reads a normal text file into a hex
// message, as well as a function that takes the hex version of a
// plaintext message and writes the actual plaintext to a specified file.

import java.io.*;
import java.util.*;

/**
 * Skeleton of an AES implementation for COT 5937 Programming Assignment #1.
 *
 * <p>The message is held as an array of ints, each packing four consecutive
 * message bytes with the first byte in the most significant position. The
 * message is padded with spaces up to a multiple of 16 bytes. The S-box and
 * inverse S-box are loaded from text files when the object is constructed.
 *
 * <p>The cipher operations ({@link #encrypt()}, {@link #decrypt()},
 * {@link #setKey(String)}, {@link #keyExpansion()} and the round helpers)
 * are unimplemented stubs left for the assignment.
 */
public class AES {

  // Constants to identify the current status of the message.
  static final boolean PLAIN = false;
  static final boolean CIPHER = true;

  private int[] message;      // Message bytes, packed 4 per int (first byte highest).
  private String inputfile;   // Name of the input file.
  private String outputfile;  // Name of the file writeOutput() writes to.
  private int[] key;          // The key; not populated yet (setKey is a stub).
  private boolean msgstatus;  // PLAIN or CIPHER: what message currently holds.

  private int msglength;      // Unpadded length of the input file in bytes.
  private int[] sbox;         // 256-entry substitution box.
  private int[] inv_sbox;     // 256-entry inverse substitution box.

  /**
   * Small smoke test of the provided I/O helpers; no encryption is done.
   * Prompts for an input file, writes its hex form to "cipher.txt", and
   * writes the (space-padded) plaintext back out to "inputtest.txt".
   */
  public static void main(String[] args) throws IOException {
    BufferedReader stdin = new BufferedReader(new InputStreamReader(System.in));
    System.out.println("Enter the input file.");
    String input = stdin.readLine();
    if (input == null) {
      // Standard input was closed before a file name was entered.
      System.err.println("No input file given.");
      return;
    }
    String output = "cipher.txt";
    AES test = new AES(input, output, "sbox.txt", "invsbox.txt");
    test.readMessage();
    test.writeOutput();
    test.writePlainFile("inputtest.txt");
  }

  /**
   * Creates a new AES object. Measures the input file to size the padded
   * message buffer and loads both S-boxes; the message contents are not
   * read until {@link #readMessage()} is called.
   *
   * @param inp         name of the input file
   * @param out         name of the file {@link #writeOutput()} writes to
   * @param sboxfile    S-box file: 16 lines of 16 two-digit hex values
   * @param invsboxfile inverse S-box file, same format
   * @throws IOException if a file cannot be read or an S-box file is malformed
   */
  public AES(String inp, String out, String sboxfile,
             String invsboxfile) throws IOException {

    inputfile = inp;
    outputfile = out;
    msgstatus = PLAIN;

    // Each int packs 4 bytes; the padded length is a multiple of 16,
    // so this division is exact.
    message = new int[getMessageLength() / 4];
    readSbox(sboxfile);
    readInvSbox(invsboxfile);
  }

  // Reads in the sbox from file (format described on readTable).
  private void readSbox(String file) throws IOException {
    sbox = readTable(file);
  }

  // Reads in the inverse sbox from file, in the same format as readSbox.
  private void readInvSbox(String file) throws IOException {
    inv_sbox = readTable(file);
  }

  /**
   * Reads a 256-entry table from a text file of 16 lines, each holding 16
   * two-character hex values separated by white space. The value in row i,
   * column j is stored at index 16*i + j. Extra tokens on a line are ignored.
   *
   * @throws IOException if the file cannot be read or is malformed
   */
  private static int[] readTable(String file) throws IOException {
    int[] table = new int[256];
    try (BufferedReader fin = new BufferedReader(new FileReader(file))) {
      for (int i = 0; i < 16; i++) {
        String line = fin.readLine();
        if (line == null) {
          throw new IOException(file + ": expected 16 lines, found " + i);
        }
        StringTokenizer tok = new StringTokenizer(line);
        for (int j = 0; j < 16; j++) {
          if (!tok.hasMoreTokens()) {
            throw new IOException(file + ": line " + (i + 1)
                + " has fewer than 16 values");
          }
          table[16 * i + j] = parseHexByte(tok.nextToken(), file, i + 1);
        }
      }
    }
    return table;
  }

  // Converts a two-character hex token to its value (0-255), reporting the
  // file and line number if the token is malformed.
  private static int parseHexByte(String token, String file, int lineno)
      throws IOException {
    String problem = file + ": line " + lineno + ": bad hex byte \"" + token + "\"";
    if (token.length() != 2) {
      throw new IOException(problem);
    }
    try {
      return 16 * hexVal(token.charAt(0)) + hexVal(token.charAt(1));
    } catch (IllegalArgumentException e) {
      throw new IOException(problem, e);
    }
  }

  /**
   * Counts the bytes in the input file, records that count in msglength,
   * and returns it rounded up to the next multiple of 16 (the message size
   * after padding).
   */
  private int getMessageLength() throws IOException {
    int bytecount = 0;
    // Count raw bytes, not decoded characters, and read until end of
    // stream: Reader.ready() is not a reliable end-of-file test.
    try (InputStream fin = new FileInputStream(inputfile)) {
      byte[] buf = new byte[4096];
      int n;
      while ((n = fin.read(buf)) != -1) {
        bytecount += n;
      }
    }

    msglength = bytecount;

    // Rounds up to the nearest multiple of 16 bytes.
    return 16 * ((bytecount + 15) / 16);
  }

  /**
   * Reads the input file's bytes into message, four bytes per int with the
   * first byte most significant, then pads the remainder with spaces.
   * Afterwards the message is in PLAIN status.
   *
   * @throws EOFException if the file is now shorter than when it was
   *                      measured by the constructor
   */
  public void readMessage() throws IOException {

    // Initialize the message buffer.
    Arrays.fill(message, 0);

    try (InputStream fin = new BufferedInputStream(new FileInputStream(inputfile))) {
      for (int bytecnt = 0; bytecnt < msglength; bytecnt++) {
        int c = fin.read();
        if (c == -1) {
          throw new EOFException(inputfile + ": file is shorter than the "
              + msglength + " bytes measured at construction");
        }
        message[bytecnt / 4] = (message[bytecnt / 4] << 8) | c;
      }
    }

    // Take care of the padding with spaces.
    for (int bytecnt = msglength; bytecnt < 4 * message.length; bytecnt++) {
      message[bytecnt / 4] = (message[bytecnt / 4] << 8) | ' ';
    }

    msgstatus = PLAIN;
  }

  // Converts an integer (0-15) to the appropriate uppercase HEX character.
  public static char convToHex(int d) {
    if (d < 10)
      return (char) (d + '0');
    else
      return (char) (d - 10 + 'A');
  }

  /**
   * Returns the integer value (0-15) of a HEX character. Both uppercase and
   * lowercase letters are accepted.
   *
   * @throws IllegalArgumentException if c is not a hex digit
   */
  public static int hexVal(char c) {
    if (c >= '0' && c <= '9')
      return c - '0';
    if (c >= 'A' && c <= 'F')
      return c - 'A' + 10;
    if (c >= 'a' && c <= 'f')
      return c - 'a' + 10;
    throw new IllegalArgumentException("Not a hex digit: '" + c + "'");
  }

  /**
   * Should only be called if the message is in CIPHER status. Writes the
   * message to the output file as uppercase hex, 64 hex characters (8 ints,
   * i.e. 32 bytes = 2 AES blocks) per line. A final partial line is not
   * newline-terminated.
   */
  public void writeOutput() throws IOException {

    try (BufferedWriter fout = new BufferedWriter(new FileWriter(outputfile))) {
      for (int i = 0; i < message.length; i++) {
        char[] hex = new char[8];
        int temp = message[i];

        // Fill in the 8 HEX chars of the int, least significant digit last.
        for (int j = 7; j >= 0; j--) {
          hex[j] = convToHex(temp & 15);
          temp >>>= 4;
        }
        fout.write(hex);

        // Go to the next line once 64 HEX characters have been written.
        if (i % 8 == 7)
          fout.write('\n');
      }
    }
  }

  /**
   * Should only be called in PLAIN status. Writes the message (including
   * its space padding) to the given file as raw bytes rather than as hex,
   * so the input file's bytes are reproduced exactly, whatever their encoding.
   */
  public void writePlainFile(String plain) throws IOException {

    try (OutputStream fout = new BufferedOutputStream(new FileOutputStream(plain))) {
      for (int word : message) {
        // Emit the 4 packed bytes, most significant (first) byte first.
        for (int shift = 24; shift >= 0; shift -= 8)
          fout.write((word >>> shift) & 0xFF);
      }
    }
  }

  // Runs AES encryption on the message, assuming that the mode is PLAIN.
  // After completion, the mode is changed to CIPHER.
  // TODO: not implemented yet.
  public void encrypt() {

  }

  // Runs AES decryption on the message, assuming that the mode is CIPHER.
  // After completion, the mode is changed to PLAIN.
  // TODO: not implemented yet.
  public void decrypt() {

  }

  /**
   * Sets the cipher key. Not implemented yet: always returns false to
   * signal that no key was installed.
   *
   * @param the_key the key; expected format still to be defined (TODO)
   * @return true if the key was accepted
   */
  public boolean setKey(String the_key) {
    // TODO: parse the_key into key.
    return false;
  }

  // Expands key into the round keys. TODO: not implemented yet.
  public void keyExpansion() {

  }

  // Assumes rounds for the key are labelled 0 through 10.
  // TODO: not implemented yet.
  private void addRoundKey(int which_round) {

  }

  // TODO: not implemented yet.
  private void subBytes() {

  }

  // TODO: not implemented yet.
  private void subInvBytes() {

  }

  // TODO: not implemented yet.
  private void shiftRow() {

  }

  // TODO: not implemented yet.
  private void invShiftRow() {

  }

  // TODO: not implemented yet.
  private void mixCol() {

  }

  // TODO: not implemented yet.
  private void invMixCol() {

  }

}
