// Arup Guha
// 6/23/04
// Class to manage the state matrix for AES, for Program #1.

import java.io.*;

public class AESmatrix {

  // The 4x4 AES state, indexed mat[row][col]. Each entry holds one byte of
  // the state wrapped in an AESpoly.
  private AESpoly[][] mat;

  // Stores the multiplicative matrix for the Mix Columns part. The values are
  // listed in column-major order (the layout AESmatrix(int[]) expects), so
  // the rows of the matrix are {2,3,1,1}, {1,2,3,1}, {1,1,2,3}, {3,1,1,2}.
  static final int[] colVals = {2, 1, 1, 3, 3, 2, 1, 1,
                                1, 3, 2, 1, 1, 1, 3, 2};

  // Stores the inverse of the colVals matrix, also in column-major order.
  // Its rows are {14,11,13,9}, {9,14,11,13}, {13,9,14,11}, {11,13,9,14}.
  static final int[] invColVals = {14, 9, 13, 11, 11, 14, 9, 13,
                                   13, 11, 14, 9, 9, 13, 11, 14};

  // Creates an AESmatrix whose 16 entries are all null. Callers are expected
  // to fill every entry with setPoly before using the other operations,
  // since those dereference the entries.
  public AESmatrix() {
    mat = new AESpoly[4][4];
  }

  // Creates an AESmatrix from 16 byte values given in column-major order:
  // vals[4*col + row] becomes the entry at (row, col).
  public AESmatrix(int[] vals) {
    mat = new AESpoly[4][4];
    for (int col = 0; col < 4; col++)
      for (int row = 0; row < 4; row++)
        mat[row][col] = new AESpoly(vals[4*col + row]);
  }

  // Creates an AESmatrix from the four ints message[startindex] through
  // message[startindex+3]. Int i becomes column i, with its most significant
  // byte in row 0 and its least significant byte in row 3.
  public AESmatrix(int[] message, int startindex) {
    mat = new AESpoly[4][4];

    for (int col = 0; col < 4; col++) {
      int word = message[startindex + col];
      mat[0][col] = new AESpoly(byteOf(word, 24));
      mat[1][col] = new AESpoly(byteOf(word, 16));
      mat[2][col] = new AESpoly(byteOf(word, 8));
      mat[3][col] = new AESpoly(byteOf(word, 0));
    }
  }

  // Returns the byte of word that starts at bit position shift, in 0..255.
  // The unsigned shift plus mask matters for the top byte: a plain ">> 24"
  // sign-extends, turning e.g. 0x80000000 into -128 instead of 128.
  private static int byteOf(int word, int shift) {
    return (word >>> shift) & 0xFF;
  }

  // Sets the entry at (row, col) of the current object to val. The AESpoly
  // reference is stored directly, not copied.
  public void setPoly(int row, int col, AESpoly val) {
    mat[row][col] = val;
  }

  // Performs the Shift Rows operation and returns the result as a new
  // matrix; the current object is not modified. Row i is rotated left by i
  // positions. The entries of the result are the same AESpoly references as
  // in this matrix.
  public AESmatrix shiftRows() {
    AESmatrix ans = new AESmatrix();

    for (int i = 0; i < 4; i++)
      for (int j = 0; j < 4; j++)
        ans.setPoly(i, j, mat[i][(i+j)%4]);

    return ans;
  }

  // Performs the inverse of the Shift Rows operation and returns the result
  // as a new matrix. Row i is rotated right by i positions. The entries of
  // the result are the same AESpoly references as in this matrix.
  public AESmatrix invShiftRows() {
    AESmatrix ans = new AESmatrix();

    for (int i = 0; i < 4; i++)
      for (int j = 0; j < 4; j++)
        ans.setPoly(i, (i+j)%4, mat[i][j]);

    return ans;
  }

  // Performs the Mix Columns operation on the current object. It returns
  // colVals-matrix * this as a new state matrix.
  public AESmatrix mixCols() {
    return leftMultiply(colVals);
  }

  // Inverts the Mix Columns operation on the current object. It returns
  // invColVals-matrix * this as a new state matrix.
  public AESmatrix invMixCols() {
    return leftMultiply(invColVals);
  }

  // Returns F * this, where F is the matrix built from factorVals
  // (column-major, as in AESmatrix(int[])). This is ordinary matrix
  // multiplication with AESpoly.add and AESpoly.mult as the entry
  // operations; each entry starts from AESpoly(0) and accumulates k = 0..3
  // in order.
  private AESmatrix leftMultiply(int[] factorVals) {
    AESmatrix factor = new AESmatrix(factorVals);
    AESmatrix ans = new AESmatrix();

    for (int i = 0; i < 4; i++) {
      for (int j = 0; j < 4; j++) {
        AESpoly sum = new AESpoly(0);
        for (int k = 0; k < 4; k++)
          sum = sum.add(factor.mat[i][k].mult(mat[k][j]));
        ans.setPoly(i, j, sum);
      }
    }
    return ans;
  }

  // Writes the state matrix back into message[startindex] through
  // message[startindex+3]. This is the inverse of the layout used by
  // AESmatrix(int[], int): column i becomes int i, with row 0 as its most
  // significant byte.
  // NOTE(review): assumes AESpoly.getVal() returns a value in 0..255;
  // larger values would bleed into neighbouring bytes.
  public void writeAns(int[] message, int startindex) {
    for (int col = 0; col < 4; col++) {

      // Pack the four bytes of this column, row 0 first.
      int val = 0;
      for (int row = 0; row < 4; row++)
        val = (val << 8) + mat[row][col].getVal();

      message[startindex + col] = val;
    }
  }

}
