# Arup Guha
# 10/30/2020
# Example of RSA (right now only intended for small cases.)

import random

# Returns base to the power power mod mod.
def modPow(base,power,mod):
    """Return base to the power power, mod mod (base**power % mod).

    Uses iterative square-and-multiply, so the work is about log2(power)
    loop iterations. Unlike a recursive version, this cannot hit Python's
    recursion limit for large exponents such as real RSA private keys.

    power must be a non-negative integer.

    Raises:
        ValueError: if power is negative.
        ZeroDivisionError: if mod is 0.
    """
    if power < 0:
        raise ValueError("power must be non-negative")

    # 1 % mod (rather than 1) makes the mod == 1 case return 0.
    result = 1 % mod
    base %= mod
    while power > 0:
        # Low bit of the exponent is set: fold the current square in.
        if power % 2 == 1:
            result = (result * base) % mod
        base = (base * base) % mod
        power //= 2
    return result

# Returns the greatest common divisor of a and b.
def gcd(a,b):
    """Return the greatest common divisor of a and b via Euclid's algorithm.

    Repeatedly replaces (a, b) with (b, a % b) until the second value is 0.
    The result's sign follows Python's % semantics and is not normalized.
    """
    x, y = a, b
    while y != 0:
        x, y = y, x % y
    return x

# Returns a inverse mod n. Does this REALLY slowly...will be replaced.
def modInv(a,n):
    """Return the inverse of a mod n, or -1 if none exists.

    Uses the iterative extended Euclidean algorithm (O(log n)) instead of
    brute-force search. When the inverse exists it is returned in the range
    [1, n-1]. -1 is returned when gcd(a, n) != 1 or when n <= 1.
    """
    # The old brute-force search over range(1, n) found nothing for n <= 1.
    if n <= 1:
        return -1

    # Invariant: old_s * a is congruent to old_r (mod n), and likewise s * a to r.
    old_r, r = a % n, n
    old_s, s = 1, 0
    while r != 0:
        quotient = old_r // r
        old_r, r = r, old_r - quotient * r
        old_s, s = s, old_s - quotient * s

    # old_r is now gcd(a, n); an inverse exists only if it is 1.
    if old_r != 1:
        return -1
    return old_s % n

def main():
    """Interactive RSA demo.

    Reads two primes p and q, builds a key pair (n, e, d), then encrypts a
    user-supplied plaintext and decrypts it again.

    Raises:
        ValueError: if p and q are not two distinct primes, are too small to
            admit a public exponent, or the plaintext is out of range.
    """

    # Simple trial division; adequate for the small primes this demo targets.
    def is_prime(k):
        if k < 2:
            return False
        i = 2
        while i * i <= k:
            if k % i == 0:
                return False
            i += 1
        return True

    # Get p and q.
    p = int(input("Enter a prime number p."))
    q = int(input("Enter a different prime number q."))

    # phi(n) = (p-1)*(q-1) only holds for two distinct primes; otherwise
    # decryption silently fails to recover the plaintext.
    if not (is_prime(p) and is_prime(q)):
        raise ValueError("p and q must both be prime.")
    if p == q:
        raise ValueError("p and q must be different primes.")

    # Calculate n and phi.
    n = p*q
    print("Public key n is",n)
    phin = (p-1)*(q-1)

    # e is drawn from [2, phin-1]; that range is empty when phin <= 2 (p=2, q=3).
    if phin <= 2:
        raise ValueError("p and q are too small; (p-1)*(q-1) must exceed 2.")

    # Find a value for e. The loop terminates because e = phin-1 is always
    # coprime to phin.
    e = random.randint(2, phin-1)
    while gcd(e,phin) != 1:
        e = random.randint(2, phin-1)

    # Calculate d.
    d = modInv(e,phin)
    print("The public key e is", e)
    print("The corresponding private key d is", d)

    # Ask user to enter plaintext. It must be below n to survive the round trip
    # (m = n encrypts to 0).
    plain = int(input("Enter a plaintext in between 1 and "+str(n-1)+"."))
    if not 1 <= plain <= n-1:
        raise ValueError("Plaintext must be between 1 and "+str(n-1)+".")

    # Encrypt via modular exponentiation.
    cipher = modPow(plain, e, n)
    print("The ciphertext is", cipher)

    # Decrypt with secret key.
    plainback = modPow(cipher, d, n)
    print("We recovered", plainback)

# Go! Run the interactive demo only when executed as a script, so importing
# this module (e.g. to reuse modPow/gcd/modInv) does not prompt for input.
if __name__ == "__main__":
    main()
    
    
