Skip to content

Commit

Permalink
Clean up types
Browse files Browse the repository at this point in the history
  • Loading branch information
theonlypwner committed Sep 21, 2024
1 parent 247daa5 commit 30a9edf
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 53 deletions.
32 changes: 15 additions & 17 deletions crc32.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from crc32 import CRC32, CRC32Reverse, combine, reverse_bits, reciprocal

def get_poly(args):
def get_poly(args) -> int:
poly = parse_dword(args.poly)
if args.msb:
poly = reverse_bits(poly)
Expand All @@ -24,11 +24,11 @@ def get_poly(args):

def get_input(args):
if args.instr:
return tuple(args.instr.encode('utf-8'))
return args.instr.encode('utf-8')
with args.infile as f:
return tuple(f.read())
return f.read()

def parse_dword(x):
def parse_dword(x: str) -> int:
return int(x, 0) & 0xFFFFFFFF


Expand Down Expand Up @@ -191,10 +191,9 @@ def table_callback(args):

def reverse_callback(args):
permitted_characters = set(
map(ord, 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ01234567890_')) # \w
b'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ01234567890_') # \w

crc32 = CRC32(get_poly(args))
crc32_reverse = CRC32Reverse(crc32)
crc32_reverse = CRC32Reverse(get_poly(args))
# find reverse bytes
desired = parse_dword(args.desired)
accum = parse_dword(args.accum)
Expand All @@ -203,34 +202,33 @@ def reverse_callback(args):
for patch in patches:
text = ''
if all(p in permitted_characters for p in patch):
text = '{}{}{}{} '.format(*map(chr, patch))
text = patch.decode() + ' '
print('4 bytes: {}{{0x{:02x}, 0x{:02x}, 0x{:02x}, 0x{:02x}}}'.format(text, *patch), file=args.outfile)
checksum = crc32.calc(patch, accum)
checksum = crc32_reverse.calc(patch, accum)
print('verification checksum: 0x{:08x} ({})'.format(
checksum, 'OK' if checksum == desired else 'ERROR'), file=args.outfile)

def print_permitted_reverse(patch):
patches = crc32_reverse.find_reverse(desired, crc32.calc(patch, accum))
def print_permitted_reverse(patch: bytes):
patches = crc32_reverse.find_reverse(desired, crc32_reverse.calc(patch, accum))
for last_4_bytes in patches:
if all(p in permitted_characters for p in last_4_bytes):
patch2 = patch + last_4_bytes
print('{} bytes: {} ({})'.format(
len(patch2),
''.join(map(chr, patch2)),
'OK' if crc32.calc(patch2, accum) == desired else 'ERROR'), file=args.outfile)
patch2.decode(),
'OK' if crc32_reverse.calc(patch2, accum) == desired else 'ERROR'), file=args.outfile)

# 5-byte alphanumeric patches
for i in permitted_characters:
print_permitted_reverse((i,))
print_permitted_reverse(bytes([i]))
# 6-byte alphanumeric patches
for i in permitted_characters:
for j in permitted_characters:
print_permitted_reverse((i, j))
print_permitted_reverse(bytes([i, j]))


def undo_callback(args):
crc32 = CRC32(get_poly(args))
crc32_reverse = CRC32Reverse(crc32)
crc32_reverse = CRC32Reverse(get_poly(args))
# calculate checksum
accum = parse_dword(args.accum)
maxlen = int(args.len, 0)
Expand Down
59 changes: 30 additions & 29 deletions crc32/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from collections.abc import Iterable

class CRC32:
def __init__(self, poly: int):
def table_value(i: int) -> int:
Expand All @@ -9,24 +7,25 @@ def table_value(i: int) -> int:

self.table = tuple(map(table_value, range(256)))

def calc(self, data: Iterable[int], accum = 0):
def calc(self, data: bytes, accum = 0):
accum = ~accum
for b in data:
accum = self.table[(accum ^ b) & 0xFF] ^ ((accum >> 8) & 0x00FFFFFF)
accum = ~accum
return accum & 0xFFFFFFFF

class CRC32Reverse:
def __init__(self, crc32: CRC32):
self.table = crc32.table
class CRC32Reverse(CRC32):
def __init__(self, poly: int):
super().__init__(poly)

self.table_reverse = tuple(
tuple(j
for j in range(256)
if self.table[j] >> 24 == i)
for i in range(256)
)

def rewind(self, data, accum = 0) -> set[int]:
def rewind(self, data: bytes, accum: int) -> set[int]:
if not data:
return { accum }
stack = [(len(data), ~accum)]
Expand All @@ -43,24 +42,24 @@ def rewind(self, data, accum = 0) -> set[int]:
solutions.add((~prevCRC) & 0xFFFFFFFF)
return solutions

def find_reverse(self, desired: int, accum = 0):
def find_reverse(self, desired: int, accum = 0) -> set[bytes]:
solutions = set()
accum = ~accum
stack = [(~desired,)]
stack = [(~desired, b'')]
while stack:
node = stack.pop()
for j in self.table_reverse[(node[0] >> 24) & 0xFF]:
if len(node) == 4:
v, s = stack.pop()
for j in self.table_reverse[(v >> 24) & 0xFF]:
next_str = s + bytes([j])
if len(next_str) == 4:
a = accum
data = []
node = node[1:] + (j,)
data = bytearray()
for i in range(3, -1, -1):
data.append((a ^ node[i]) & 0xFF)
data.append((a ^ next_str[i]) & 0xFF)
a >>= 8
a ^= self.table[node[i]]
solutions.add(tuple(data))
a ^= self.table[next_str[i]]
solutions.add(bytes(data))
else:
stack.append(((node[0] ^ self.table[j]) << 8,) + node[1:] + (j,))
stack.append(((v ^ self.table[j]) << 8, next_str))
return solutions


Expand All @@ -70,30 +69,33 @@ def __init__(self, matrix):
self.matrix = matrix

@staticmethod
def identity():
def identity() -> 'Matrix':
return Matrix(tuple(1 << i for i in range(32)))

@staticmethod
def zero_operator(poly):
def zero_operator(poly: int) -> 'Matrix':
m = [poly]
n = 1
for _ in range(31):
m.append(n)
n <<= 1
return Matrix(tuple(m))

def multiply_vector(self, v, s = 0):
def multiply_vector(self, v: int, s = 0) -> int:
for c in self.matrix:
s ^= c & -(v & 1)
v >>= 1
if not v:
break
return s

def mul(self, matrix):
def mul(self, matrix: 'Matrix') -> 'Matrix':
return Matrix(tuple(map(self.multiply_vector, matrix.matrix)))

def combine(c1, c2, l2, n, poly):
def sqr(self) -> 'Matrix':
return self.mul(self)

def combine(c1: int, c2: int, l2: int, n: int, poly: int) -> int:
# The effect of feeding zero bits into the CRC32 state machine can be
# represented by matrix multiplication, allowing exponentiation-by-squaring.
#
Expand All @@ -113,12 +115,11 @@ def combine(c1, c2, l2, n, poly):
# after A before B.

m = Matrix.zero_operator(poly)
m = m.mul(m)
m = m.mul(m)
m = m.sqr().sqr()

M = Matrix.identity()
while l2:
m = m.mul(m)
m = m.sqr()
if l2 & 1:
M = m.mul(M)
l2 >>= 1
Expand All @@ -140,13 +141,13 @@ def combine(c1, c2, l2, n, poly):
break

b = M.multiply_vector(b, b)
M = M.mul(M)
M = M.sqr()

return c1

# Tools

def reverse_bits(x):
def reverse_bits(x: int) -> int:
# http://graphics.stanford.edu/~seander/bithacks.html#ReverseParallel
# http://stackoverflow.com/a/20918545
x = ((x & 0x55555555) << 1) | ((x & 0xAAAAAAAA) >> 1)
Expand All @@ -156,6 +157,6 @@ def reverse_bits(x):
x = ((x & 0x0000FFFF) << 16) | ((x & 0xFFFF0000) >> 16)
return x & 0xFFFFFFFF

def reciprocal(poly):
def reciprocal(poly: int) -> int:
''' Return the reciprocal polynomial of a reversed (lsbit-first) polynomial. '''
return poly << 1 & 0xffffffff | 1
9 changes: 4 additions & 5 deletions tests/test_crc32.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,18 @@ def to_bytes(data):

class Calc(unittest.TestCase):
def setUp(self):
self.crc32 = crc32.CRC32(0xedb88320)
self.crc32_reverse = crc32.CRC32Reverse(self.crc32)
self.crc32_reverse = crc32.CRC32Reverse(0xedb88320)

def test(self):
for c in calc:
b = to_bytes(c[2])

checksum = self.crc32.calc(b, c[1])
checksum = self.crc32_reverse.calc(b, c[1])
self.assertEqual(checksum, c[0])

self.assertSetEqual(self.crc32_reverse.rewind(b, c[0]), { c[1] })

if len(b) == 4:
self.assertSetEqual(self.crc32_reverse.find_reverse(*c[:2]), { c[2] })
self.assertSetEqual(self.crc32_reverse.find_reverse(*c[:2]), { bytes(c[2]) })

self.assertEqual(crc32.combine(c[1], checksum, len(b), 1009, 0xedb88320), self.crc32.calc(b * 1009, c[1]))
self.assertEqual(crc32.combine(c[1], checksum, len(b), 1009, 0xedb88320), self.crc32_reverse.calc(b * 1009, c[1]))
3 changes: 1 addition & 2 deletions tests/test_crc32_table_reverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@

class CRC32TableReverse(unittest.TestCase):
def test_table_reverse(self):
c = crc32.CRC32(0xedb88320)
crc32_reverse = crc32.CRC32Reverse(c)
crc32_reverse = crc32.CRC32Reverse(0xedb88320)

# reverse table for 0xedb88320 only
self.assertEqual(
Expand Down

0 comments on commit 30a9edf

Please sign in to comment.