// Arup Guha
// 9/28/2023
// Code to solve CIS 3362 Homework #3 Problem #5

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// Number of candidate values available for each matrix entry (the size of POSSIBLE).
#define NUMCHOICE 10
// Dimension of the square key matrix, which is also the number of letters per block.
#define N 3

// Multiplies the N x N matrix m by the vector plain (mod 26) and writes the
// resulting letters into cipher.
void hillmult(int m[][N], int plain[], char cipher[]);
// Recursively fills matrix slots k and onward with every candidate value and
// hands each completed matrix to process().
void go(int m[][3], int k);
// Decrypts the ciphertext with m and prints the result if it contains the crib words.
void process(int m[][N]);
// Returns true iff pat occurs inside full.
int contains(char* full, char* pat);
// Prints the matrix m, one row per line.
void print(int m[][N]);
// Returns the determinant of the 3 x 3 matrix m, reduced mod 26.
int det(int m[][N]);

// Running count of complete matrices that go() has passed to process();
// it drives the periodic progress message.
int completed;

// All numbers in the decryption matrix come from here: go() tries each of
// these NUMCHOICE values in every slot of the matrix.
int POSSIBLE[NUMCHOICE] = {0,4,5,6,8,11,16,21,22,25};

// Ciphertext to decrypt, written as uppercase letters. process() consumes it
// N letters at a time, so its length is expected to be a multiple of N.
char inp[1000] = "HJLMYYBHFIKQSQEQAKWMSRYRKZPESYBGSCVDBHFBFLDRYKCPBHFQIAXTXDKKEERJTLYPKOTSBDAMXGBHFACMSVXXNJQIVHHTECOJAADHMGZHTQWWQOWUPXMECNJESHUUTNGUBDAOFVWQDUQHKOWJSAYFBOIVJBVOBGACFMKNMSERKYGYGRYERQAWQOBHFPFWOGAIIIVCQQOXYFICVDBHFGRMYKFIIIVCQOFVMFZYYYKBUMRTWHXUOOBBEQRYWUOLVFJTQTONUWZDPDBHFXDUGKNWMRNGUBHFOKOCKHSHKSWTVLZECBFPZLCKHSJCJJASB";

// Entry point: resets the progress counter and starts the exhaustive search
// over every candidate matrix, beginning with slot 0.
int main(void)
{
    // Scratch key matrix; go() assigns every slot before process() reads it.
    int key[N][N];

    completed = 0;
    go(key, 0);

    return 0;
}

// Recursive odometer over all candidate matrices. Slots 0..k-1 of m (counted in
// row-major order) are already fixed; this enumerates every assignment of the
// remaining slots from POSSIBLE and processes each completed matrix.
void go(int m[][3], int k) {

    // Still slots left to fill: try every candidate value in slot k and recurse.
    if (k != N*N) {
        int row = k/N;
        int col = k%N;
        int choice = 0;

        while (choice < NUMCHOICE) {
            m[row][col] = POSSIBLE[choice];
            go(m, k+1);
            choice++;
        }
        return;
    }

    // Every slot is fixed, so test this matrix and count it.
    process(m);
    completed += 1;

    // Status update once per million matrices tried.
    if (completed%1000000 == 0)
        printf("Did %d\n", completed);
}

// Returns the determinant of m reduced mod 26, always in the range 0..25.
// Only works for 3 by 3: it sums the three wrapped "down-right" diagonal
// products and subtracts the three wrapped "down-left" diagonal products.
// The products are computed in plain int arithmetic, so the entries must be
// small enough that the sum does not overflow (values in 0..25 are fine).
int det(int m[][N]) {
    int res = 0;

    for (int i=0; i<3; i++) {
        // Diagonal starting at column i of row 0 and moving right, wrapping around.
        res += m[0][i]*m[1][(i+1)%3]*m[2][(i+2)%3];

        // Anti-diagonal ending at column i of row 2, wrapping around.
        res -= m[0][(i+2)%3]*m[1][(i+1)%3]*m[2][i];
    }

    // In C the result of % takes the sign of the dividend, so a negative
    // determinant leaves a remainder in -25..-1; shift it back into 0..25.
    res %= 26;
    if (res < 0) res += 26;
    return res;
}

// Tries m as the decryption matrix for the global ciphertext inp.
// Matrices whose determinant shares a factor with 26 are rejected immediately.
// Otherwise inp is decrypted N letters at a time, and if the result contains
// both "PRIZE" and "THE" the candidate plaintext and the matrix are printed.
// The text is expected to be a multiple of N letters long; if it is not, the
// trailing partial block is left out of the decrypted output.
void process(int m[][N]) {

    // We can throw out any matrix with a determinant that isn't relatively prime with 26.
    int mydet = det(m);
    if (mydet%2 == 0 || mydet%13 == 0) return;

    // Only whole blocks can be decrypted, so round the length down to a multiple of N.
    size_t len = strlen(inp);
    size_t usable = len - len%N;

    // Decrypted text goes here; sized like inp so it always has room for the terminator.
    char res[sizeof inp];

    // Go block by block.
    for (size_t i=0; i<usable; i+=N) {

        // hillmult works on numbers, so convert this block's letters to 0..25.
        int p[N];
        for (int j=0; j<N; j++) p[j] = inp[i+j]-'A';

        // Multiply by the matrix and write the N resulting letters straight into res.
        hillmult(m, p, res+i);
    }

    // Terminate after the last complete block so res is always a valid string.
    res[usable] = '\0';

    // If these two words are in the supposed plaintext, print so it can be examined.
    if (contains(res, "PRIZE") && contains(res, "THE")) {
        printf("%s\n", res);
        print(m);
    }
}

// Returns true iff pat is contained in full.
// Both arguments must be null-terminated strings. The result is 1 when pat
// occurs as a contiguous substring of full and 0 otherwise; an empty pat is
// contained in every string.
int contains(char* full, char* pat) {

    // strstr returns NULL exactly when there is no occurrence of pat in full.
    return strstr(full, pat) != NULL;
}

// Writes the N x N matrix m to standard output, one row per line, with each
// entry right-aligned in a field five characters wide.
void print(int m[][N]) {

    // Walk the cells in row-major order and end the line after each row's last column.
    for (int cell=0; cell<N*N; cell++) {
        printf("%5d", m[cell/N][cell%N]);
        if (cell%N == N-1)
            printf("\n");
    }
}

// Computes the matrix-vector product m * plain modulo 26 and stores the result
// as letters: cipher[i] is 'A' plus the reduced dot product of row i of m with
// plain. plain holds N numeric letter values and cipher receives N characters
// (no terminator is written). process() calls this with ciphertext values in
// plain to obtain plaintext letters.
void hillmult(int m[][N], int plain[], char cipher[]) {

    for (int row=0; row<N; row++) {

        // Row of the matrix that produces output letter number 'row'.
        int *coef = m[row];

        // Dot product of this row with plain, reduced mod 26 after every term.
        int sum = 0;
        for (int col=0; col<N; col++) {
            sum += coef[col]*plain[col];
            sum %= 26;
        }

        cipher[row] = (char)('A'+sum);
    }
}


