Skip to content

fix(mypy): type annotations for cipher algorithms #4306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
python -m pip install mypy pytest-cov -r requirements.txt
# FIXME: #4052 fix mypy errors in the exclude directories and remove them below
- run: mypy --ignore-missing-imports
--exclude '(ciphers|conversions|data_structures|digital_image_processing|dynamic_programming|graphs|linear_algebra|maths|matrix|other|project_euler|scripts|searches|strings*)/$' .
--exclude '(conversions|data_structures|digital_image_processing|dynamic_programming|graphs|linear_algebra|maths|matrix|other|project_euler|scripts|searches|strings*)/$' .
- name: Run tests
run: pytest --doctest-modules --ignore=project_euler/ --ignore=scripts/ --cov-report=term-missing:skip-covered --cov=. .
- if: ${{ success() }}
Expand Down
8 changes: 2 additions & 6 deletions ciphers/diffie_hellman.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,7 @@ def generate_shared_key(self, other_key_str: str) -> str:
return sha256(str(shared_key).encode()).hexdigest()

@staticmethod
def is_valid_public_key_static(
local_private_key_str: str, remote_public_key_str: str, prime: int
) -> bool:
def is_valid_public_key_static(remote_public_key_str: int, prime: int) -> bool:
# check if the other public key is valid based on NIST SP800-56
if 2 <= remote_public_key_str and remote_public_key_str <= prime - 2:
if pow(remote_public_key_str, (prime - 1) // 2, prime) == 1:
Expand All @@ -257,9 +255,7 @@ def generate_shared_key_static(
local_private_key = int(local_private_key_str, base=16)
remote_public_key = int(remote_public_key_str, base=16)
prime = primes[group]["prime"]
if not DiffieHellman.is_valid_public_key_static(
local_private_key, remote_public_key, prime
):
if not DiffieHellman.is_valid_public_key_static(remote_public_key, prime):
raise ValueError("Invalid public key")
shared_key = pow(remote_public_key, local_private_key, prime)
return sha256(str(shared_key).encode()).hexdigest()
Expand Down
23 changes: 10 additions & 13 deletions ciphers/hill_cipher.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,12 @@ class HillCipher:

to_int = numpy.vectorize(lambda x: round(x))

def __init__(self, encrypt_key: int):
def __init__(self, encrypt_key: numpy.ndarray) -> None:
"""
encrypt_key is an NxN numpy array
"""
self.encrypt_key = self.modulus(encrypt_key) # mod36 calc's on the encrypt key
self.check_determinant() # validate the determinant of the encryption key
self.decrypt_key = None
self.break_key = encrypt_key.shape[0]

def replace_letters(self, letter: str) -> int:
Expand Down Expand Up @@ -139,8 +138,8 @@ def encrypt(self, text: str) -> str:

for i in range(0, len(text) - self.break_key + 1, self.break_key):
batch = text[i : i + self.break_key]
batch_vec = [self.replace_letters(char) for char in batch]
batch_vec = numpy.array([batch_vec]).T
vec = [self.replace_letters(char) for char in batch]
batch_vec = numpy.array([vec]).T
batch_encrypted = self.modulus(self.encrypt_key.dot(batch_vec)).T.tolist()[
0
]
Expand All @@ -151,7 +150,7 @@ def encrypt(self, text: str) -> str:

return encrypted

def make_decrypt_key(self):
def make_decrypt_key(self) -> numpy.ndarray:
"""
>>> hill_cipher = HillCipher(numpy.array([[2, 5], [1, 6]]))
>>> hill_cipher.make_decrypt_key()
Expand Down Expand Up @@ -184,17 +183,15 @@ def decrypt(self, text: str) -> str:
>>> hill_cipher.decrypt('85FF00')
'HELLOO'
"""
self.decrypt_key = self.make_decrypt_key()
decrypt_key = self.make_decrypt_key()
text = self.process_text(text.upper())
decrypted = ""

for i in range(0, len(text) - self.break_key + 1, self.break_key):
batch = text[i : i + self.break_key]
batch_vec = [self.replace_letters(char) for char in batch]
batch_vec = numpy.array([batch_vec]).T
batch_decrypted = self.modulus(self.decrypt_key.dot(batch_vec)).T.tolist()[
0
]
vec = [self.replace_letters(char) for char in batch]
batch_vec = numpy.array([vec]).T
batch_decrypted = self.modulus(decrypt_key.dot(batch_vec)).T.tolist()[0]
decrypted_batch = "".join(
self.replace_digits(num) for num in batch_decrypted
)
Expand All @@ -203,12 +200,12 @@ def decrypt(self, text: str) -> str:
return decrypted


def main():
def main() -> None:
N = int(input("Enter the order of the encryption key: "))
hill_matrix = []

print("Enter each row of the encryption key with space separated integers")
for i in range(N):
for _ in range(N):
row = [int(x) for x in input().split()]
hill_matrix.append(row)

Expand Down
18 changes: 9 additions & 9 deletions ciphers/mixed_keyword_cypher.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,32 +29,32 @@ def mixed_keyword(key: str = "college", pt: str = "UNIVERSITY") -> str:
# print(temp)
alpha = []
modalpha = []
for i in range(65, 91):
t = chr(i)
for j in range(65, 91):
t = chr(j)
alpha.append(t)
if t not in temp:
temp.append(t)
# print(temp)
r = int(26 / 4)
# print(r)
k = 0
for i in range(r):
t = []
for _ in range(r):
s = []
for j in range(len_temp):
t.append(temp[k])
s.append(temp[k])
if not (k < 25):
break
k += 1
modalpha.append(t)
modalpha.append(s)
# print(modalpha)
d = {}
j = 0
k = 0
for j in range(len_temp):
for i in modalpha:
if not (len(i) - 1 >= j):
for m in modalpha:
if not (len(m) - 1 >= j):
break
d[alpha[k]] = i[j]
d[alpha[k]] = m[j]
if not k < 25:
break
k += 1
Expand Down
8 changes: 6 additions & 2 deletions ciphers/mono_alphabetic_ciphers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import Literal

LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"


def translate_message(key, message, mode):
def translate_message(
key: str, message: str, mode: Literal["encrypt", "decrypt"]
) -> str:
"""
>>> translate_message("QWERTYUIOPASDFGHJKLZXCVBNM","Hello World","encrypt")
'Pcssi Bidsm'
Expand Down Expand Up @@ -40,7 +44,7 @@ def decrypt_message(key: str, message: str) -> str:
return translate_message(key, message, "decrypt")


def main():
def main() -> None:
message = "Hello World"
key = "QWERTYUIOPASDFGHJKLZXCVBNM"
mode = "decrypt" # set to 'encrypt' or 'decrypt'
Expand Down
2 changes: 1 addition & 1 deletion ciphers/morse_code_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def decrypt(message: str) -> str:
return decipher


def main():
def main() -> None:
message = "Morse code here"
result = encrypt(message.upper())
print(result)
Expand Down
9 changes: 5 additions & 4 deletions ciphers/onepad_cipher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@


class Onepad:
def encrypt(self, text: str) -> ([str], [int]):
@staticmethod
def encrypt(text: str) -> tuple[list[int], list[int]]:
"""Function to encrypt text using pseudo-random numbers"""
plain = [ord(i) for i in text]
key = []
Expand All @@ -14,14 +15,14 @@ def encrypt(self, text: str) -> ([str], [int]):
key.append(k)
return cipher, key

def decrypt(self, cipher: [str], key: [int]) -> str:
@staticmethod
def decrypt(cipher: list[int], key: list[int]) -> str:
"""Function to decrypt text using pseudo-random numbers."""
plain = []
for i in range(len(key)):
p = int((cipher[i] - (key[i]) ** 2) / key[i])
plain.append(chr(p))
plain = "".join([i for i in plain])
return plain
return "".join([i for i in plain])


if __name__ == "__main__":
Expand Down
5 changes: 3 additions & 2 deletions ciphers/playfair_cipher.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import itertools
import string
from typing import Generator, Iterable


def chunker(seq, size):
def chunker(seq: Iterable[str], size: int) -> Generator[tuple[str, ...], None, None]:
it = iter(seq)
while True:
chunk = tuple(itertools.islice(it, size))
Expand Down Expand Up @@ -37,7 +38,7 @@ def prepare_input(dirty: str) -> str:
return clean


def generate_table(key: str) -> [str]:
def generate_table(key: str) -> list[str]:

# I and J are used interchangeably to allow
# us to use a 5x5 table (25 letters)
Expand Down
49 changes: 21 additions & 28 deletions ciphers/porta_cipher.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
}


def generate_table(key: str) -> [(str, str)]:
def generate_table(key: str) -> list[tuple[str, str]]:
"""
>>> generate_table('marvin') # doctest: +NORMALIZE_WHITESPACE
[('ABCDEFGHIJKLM', 'UVWXYZNOPQRST'), ('ABCDEFGHIJKLM', 'NOPQRSTUVWXYZ'),
Expand Down Expand Up @@ -60,30 +60,21 @@ def decrypt(key: str, words: str) -> str:
return encrypt(key, words)


def get_position(table: [(str, str)], char: str) -> (int, int) or (None, None):
def get_position(table: tuple[str, str], char: str) -> tuple[int, int]:
"""
>>> table = [
... ('ABCDEFGHIJKLM', 'UVWXYZNOPQRST'), ('ABCDEFGHIJKLM', 'NOPQRSTUVWXYZ'),
... ('ABCDEFGHIJKLM', 'STUVWXYZNOPQR'), ('ABCDEFGHIJKLM', 'QRSTUVWXYZNOP'),
... ('ABCDEFGHIJKLM', 'WXYZNOPQRSTUV'), ('ABCDEFGHIJKLM', 'UVWXYZNOPQRST')]
>>> get_position(table, 'A')
(None, None)
>>> get_position(generate_table('marvin')[0], 'M')
(0, 12)
"""
if char in table[0]:
row = 0
else:
row = 1 if char in table[1] else -1
return (None, None) if row == -1 else (row, table[row].index(char))
# `char` is either in the 0th row or the 1st row
row = 0 if char in table[0] else 1
col = table[row].index(char)
return row, col


def get_opponent(table: [(str, str)], char: str) -> str:
def get_opponent(table: tuple[str, str], char: str) -> str:
"""
>>> table = [
... ('ABCDEFGHIJKLM', 'UVWXYZNOPQRST'), ('ABCDEFGHIJKLM', 'NOPQRSTUVWXYZ'),
... ('ABCDEFGHIJKLM', 'STUVWXYZNOPQR'), ('ABCDEFGHIJKLM', 'QRSTUVWXYZNOP'),
... ('ABCDEFGHIJKLM', 'WXYZNOPQRSTUV'), ('ABCDEFGHIJKLM', 'UVWXYZNOPQRST')]
>>> get_opponent(table, 'A')
'A'
>>> get_opponent(generate_table('marvin')[0], 'M')
'T'
"""
row, col = get_position(table, char.upper())
if row == 1:
Expand All @@ -97,14 +88,16 @@ def get_opponent(table: [(str, str)], char: str) -> str:

doctest.testmod() # Fist ensure that all our tests are passing...
"""
ENTER KEY: marvin
ENTER TEXT TO ENCRYPT: jessica
ENCRYPTED: QRACRWU
DECRYPTED WITH KEY: JESSICA
Demo:

Enter key: marvin
Enter text to encrypt: jessica
Encrypted: QRACRWU
Decrypted with key: JESSICA
"""
key = input("ENTER KEY: ").strip()
text = input("ENTER TEXT TO ENCRYPT: ").strip()
key = input("Enter key: ").strip()
text = input("Enter text to encrypt: ").strip()
cipher_text = encrypt(key, text)

print(f"ENCRYPTED: {cipher_text}")
print(f"DECRYPTED WITH KEY: {decrypt(key, cipher_text)}")
print(f"Encrypted: {cipher_text}")
print(f"Decrypted with key: {decrypt(key, cipher_text)}")
10 changes: 5 additions & 5 deletions ciphers/rail_fence_cipher.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def encrypt(input_string: str, key: int) -> str:
...
TypeError: sequence item 0: expected str instance, int found
"""
grid = [[] for _ in range(key)]
temp_grid: list[list[str]] = [[] for _ in range(key)]
lowest = key - 1

if key <= 0:
Expand All @@ -31,8 +31,8 @@ def encrypt(input_string: str, key: int) -> str:
for position, character in enumerate(input_string):
num = position % (lowest * 2) # puts it in bounds
num = min(num, lowest * 2 - num) # creates zigzag pattern
grid[num].append(character)
grid = ["".join(row) for row in grid]
temp_grid[num].append(character)
grid = ["".join(row) for row in temp_grid]
output_string = "".join(grid)

return output_string
Expand Down Expand Up @@ -63,7 +63,7 @@ def decrypt(input_string: str, key: int) -> str:
if key == 1:
return input_string

temp_grid = [[] for _ in range(key)] # generates template
temp_grid: list[list[str]] = [[] for _ in range(key)] # generates template
for position in range(len(input_string)):
num = position % (lowest * 2) # puts it in bounds
num = min(num, lowest * 2 - num) # creates zigzag pattern
Expand All @@ -84,7 +84,7 @@ def decrypt(input_string: str, key: int) -> str:
return output_string


def bruteforce(input_string: str) -> dict:
def bruteforce(input_string: str) -> dict[int, str]:
"""Uses decrypt function by guessing every key

>>> bruteforce("HWe olordll")[4]
Expand Down
2 changes: 1 addition & 1 deletion ciphers/rot13.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def dencrypt(s: str, n: int = 13) -> str:
return out


def main():
def main() -> None:
s0 = input("Enter message: ")

s1 = dencrypt(s0, 13)
Expand Down
Loading