# aes02.py

# =========================================================
# encrypt/decrypt using AES secret key
# ---------------------------------------------------------
# From: www.YouTube.com/watch-UB2VX4vNUa0
# Python AES Encryption/Decryption using PyCrypto tutorial
# =========================================================

from Crypto.Cipher import AES
from Crypto import Random
import sys
import os
import os.path
from os import listdir
from os.path import isfile, join

class Encryptor:
    """Encrypt and decrypt byte strings and files with AES in CBC mode.

    Ciphertext layout is ``iv + AES-CBC(padded plaintext)``, where the IV is
    ``AES.block_size`` random bytes and the plaintext is padded with NUL
    bytes.

    NOTE(review): because padding is plain NUL bytes that are removed with
    ``rstrip`` on decryption, any plaintext that itself ends in NUL bytes
    loses them on a round trip. The scheme is kept as-is so that files
    already encrypted with it can still be decrypted.
    """

    # Suffix appended to encrypted files and stripped again on decryption.
    ENC_SUFFIX = '.enc'

    def __init__(self, key):
        """Store *key*, the AES key used by the file-level methods."""
        self.key = key

    # -----------------------------------------------------
    # pad the end of the string with b'\0' characters
    # make string length a multiple of AES.block_size
    # -----------------------------------------------------

    def pad(self, s):
        """Return *s* extended with NUL bytes to a multiple of the block size.

        Between 1 and ``AES.block_size`` bytes are always appended: input
        that is already block-aligned gains one full block of padding.
        """
        return s + b'\0' * (AES.block_size - len(s) % AES.block_size)

    # -----------------------------------------------------
    # encrypt a string
    # -----------------------------------------------------

    def encrypt(self, message, key, key_size=256):
        """Return ``iv + ciphertext`` for the bytes *message* under *key*.

        *key_size* is accepted for backward compatibility but is not used.
        """
        message = self.pad(message)
        # A fresh random IV per message; it is stored in front of the
        # ciphertext so that decrypt() can recover it.
        iv = Random.new().read(AES.block_size)
        cipher = AES.new(key, AES.MODE_CBC, iv)
        return iv + cipher.encrypt(message)

    # -----------------------------------------------------
    # encrypt a file
    # -----------------------------------------------------

    def encrypt_file(self, file_name):
        """Encrypt *file_name* into ``file_name + '.enc'`` and delete it.

        The original file is removed only after the encrypted copy has
        been written.
        """
        with open(file_name, 'rb') as fo:
            plain_text = fo.read()
        encrypted = self.encrypt(plain_text, self.key)
        with open(file_name + self.ENC_SUFFIX, 'wb') as fo:
            fo.write(encrypted)
        os.remove(file_name)

    # -----------------------------------------------------
    # decrypt a string
    # -----------------------------------------------------

    def decrypt(self, cipherText, key):
        """Return the plaintext of *cipherText* (``iv + ciphertext``).

        All trailing NUL bytes are stripped from the result, which removes
        the padding added by pad() (and any NUL bytes the plaintext
        originally ended with).
        """
        iv = cipherText[:AES.block_size]
        cipher = AES.new(key, AES.MODE_CBC, iv)
        plain_text = cipher.decrypt(cipherText[AES.block_size:])
        return plain_text.rstrip(b'\0')

    # -----------------------------------------------------
    # decrypt a file
    # -----------------------------------------------------

    @staticmethod
    def _is_encrypted_name(file_name):
        """Return True if *file_name* is ``<something>.enc``."""
        base_name = os.path.basename(file_name)
        suffix = Encryptor.ENC_SUFFIX
        return base_name.endswith(suffix) and base_name != suffix

    def decrypt_file(self, file_name):
        """Decrypt *file_name* into the same path without ``.enc``.

        The encrypted file is deleted after the plaintext has been written.

        Raises:
            ValueError: if *file_name* does not end in ``.enc``. Checked
                before anything is read or written, because blindly
                dropping the last four characters of an arbitrary name
                would write to the wrong path and then delete the input.
        """
        if not self._is_encrypted_name(file_name):
            raise ValueError(
                'not an encrypted file (expected a name ending in %r): %r'
                % (self.ENC_SUFFIX, file_name))
        with open(file_name, 'rb') as fo:
            cipher_text = fo.read()
        decrypted = self.decrypt(cipher_text, self.key)
        with open(file_name[:-len(self.ENC_SUFFIX)], 'wb') as fo:
            fo.write(decrypted)
        os.remove(file_name)

    # -----------------------------------------------------
    # list the files below the script directory
    # -----------------------------------------------------

    def getAllFiles(self):
        """Return paths of all files below the directory of this script.

        The tree is walked recursively. Files named ``aes02.py`` or
        ``data.txt.enc``, and files sharing the name of this source file,
        are left out so that the script and the stored password are never
        processed in bulk.
        """
        script_path = os.path.realpath(__file__)
        excluded = {
            'aes02.py',
            'data.txt.enc',
            # Also protect the script when it has been renamed.
            os.path.basename(script_path),
        }
        found = []
        for dir_name, _sub_dirs, file_list in os.walk(
                os.path.dirname(script_path)):
            for fname in file_list:
                if fname not in excluded:
                    found.append(os.path.join(dir_name, fname))
        return found

    # -----------------------------------------------------
    # encrypt every file returned by getAllFiles()
    # -----------------------------------------------------

    def encrypt_all_files(self):
        """Encrypt every file returned by getAllFiles()."""
        for file_name in self.getAllFiles():
            self.encrypt_file(file_name)

    # -----------------------------------------------------
    # decrypt every encrypted file returned by getAllFiles()
    # -----------------------------------------------------

    def decrypt_all_files(self):
        """Decrypt every ``.enc`` file returned by getAllFiles().

        Files without the ``.enc`` suffix are skipped: they were not
        produced by encrypt_file(), and feeding them to decrypt_file()
        would destroy them.
        """
        for file_name in self.getAllFiles():
            if self._is_encrypted_name(file_name):
                self.decrypt_file(file_name)

# ---------------------------------------------------------
# main
# ---------------------------------------------------------

# NOTE(security): the AES key is hard-coded in the source, so anyone who can
# read this script can decrypt every file it produces. It is left unchanged
# here because existing encrypted files depend on it.
key = b'\xbf\xc0\x85)\x10nc\x94\x02)j\xdf\xcb\xc4\x94\x9d(\x9e[EX\xc8\xd5\xbfI{\xa2$\x05(\xd5\x18'

# get name of script

pyscript = os.path.basename(sys.argv[0])

enc = Encryptor(key)

# Plaintext / encrypted names of the file holding the program password.
_PASSWORD_FILE = 'data.txt'
_PASSWORD_FILE_ENC = 'data.txt.enc'

# The leading spaces on the continuation lines are part of the displayed text.
_MENU_PROMPT = '''1. Press '1' to encrypt a file.
            2. Press '2' to decrypt a file.
            3. Press '3' to encrypt all files in the directory.
            4. Press '4' to decrypt all files in the directory.
            5. Press '5' to exit.
            Enter option: '''


def clear():
    """Clear the terminal and return the shell command's exit status."""
    return os.system('cls' if os.name == 'nt' else 'clear')


def _read_stored_password():
    """Return the stored password without writing plaintext to disk.

    The encrypted password file is decrypted in memory and left untouched,
    so a wrong guess can neither remove it nor leave the plaintext behind.
    """
    import io

    with open(_PASSWORD_FILE_ENC, 'rb') as fo:
        raw = enc.decrypt(fo.read(), key)
    # Decode exactly as a default text-mode open() of the file would.
    lines = io.TextIOWrapper(io.BytesIO(raw)).readlines()
    # An empty stored password yields no lines at all.
    return lines[0] if lines else ''


def _authenticate():
    """Prompt until the entered password matches the stored one."""
    stored = _read_stored_password()
    while True:
        password = str(input('Enter password: '))
        if password == stored:
            break


def _run_menu():
    """Show the menu and run the chosen action until the user exits."""
    notice = ''
    while True:
        clear()
        # Printed after clearing the screen so the user can actually read it.
        if notice:
            print(notice)
            notice = ''
        try:
            choice = int(input(_MENU_PROMPT))
        except ValueError:
            # Non-numeric input is treated like any other invalid option.
            choice = None
        clear()
        try:
            if choice == 1:
                enc.encrypt_file(str(input('Enter name of file to encrypt: ')))
            elif choice == 2:
                enc.decrypt_file(str(input('Enter name of file to decrypt: ')))
            elif choice == 3:
                enc.encrypt_all_files()
            elif choice == 4:
                enc.decrypt_all_files()
            elif choice == 5:
                sys.exit()
            else:
                notice = 'Plese select a valid option'
        except (OSError, ValueError) as err:
            # A missing/unreadable file (OSError) or input the decrypt path
            # rejects (presumably ValueError) should not end the session.
            notice = 'Error: {}'.format(err)


def _setup_password():
    """First run: ask for a password and store it encrypted."""
    clear()
    while True:
        password = str(input('Setting up stuff. Enter a password that will be used for decryption: '))
        repassword = str(input('Confirm password: '))
        if password == repassword:
            break
        # The screen is not cleared again, so this message stays visible.
        print('Password mismatch!')
    with open(_PASSWORD_FILE, 'w') as f:
        f.write(password)
    enc.encrypt_file(_PASSWORD_FILE)
    print('Please restart the program to finish the startup')


if os.path.isfile(_PASSWORD_FILE_ENC):
    _authenticate()
    _run_menu()
else:
    _setup_password()