Skip to content

Commit

Permalink
Make CRT support non-coprime moduli
Browse files Browse the repository at this point in the history
Fixes #251
  • Loading branch information
mhostetter committed Feb 10, 2022
1 parent 005d77c commit 91f4651
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 50 deletions.
29 changes: 17 additions & 12 deletions galois/_polymorphic.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,27 +461,32 @@ def crt(remainders, moduli):
raise TypeError(f"Argument `moduli` must be a tuple or list of int or galois.Poly, not {moduli}.")
if not len(remainders) == len(moduli) >= 2:
raise ValueError(f"Arguments `remainders` and `moduli` must be the same length of at least 2, not {len(remainders)} and {len(moduli)}.")
if not are_coprime(*moduli):
raise ValueError(f"Elements of argument `moduli` must be pairwise coprime, {moduli} are not.")

# Iterate through the system of congruences reducing a pair of congruences into a
# single one. The answer to the final congruence solves all the congruences.
a1, m1 = remainders[0], moduli[0]
for a2, m2 in zip(remainders[1:], moduli[1:]):
# Use the Extended Euclidean Algorithm to determine: b1*m1 + b2*m2 = 1,
# where 1 is the GCD(m1, m2) because m1 and m2 are pairwise relatively coprime
_, b1, b2 = egcd(m1, m2)
# Use the Extended Euclidean Algorithm to determine: b1*m1 + b2*m2 = gcd(m1, m2).
d, b1, b2 = egcd(m1, m2)

if d == 1:
# The moduli (m1, m2) are coprime
x = a1*b2*m2 + a2*b1*m1 # Compute x through explicit construction
m1 = m1 * m2 # The new modulus
else:
# The moduli (m1, m2) are not coprime, however if a1 == b2 (mod d)
# then a unique solution still exists.
if not (a1 % d) == (a2 % d):
raise ValueError(f"Moduli {[m1, m2]} are not coprime and their residuals {[a1, a2]} are not equal modulo their GCD {d}, therefore a unique solution does not exist.")
x = (a1*b2*m2 + a2*b1*m1) // d # Compute x through explicit construction
m1 = (m1 * m2) // d # The new modulus

# Compute x through explicit construction
x = a1*b2*m2 + a2*b1*m1

m1 = m1 * m2 # The new modulus
a1 = x % m1 # The new equivalent remainder

# Align x to be within [0, prod(m))
x = x % prod(*moduli)
# At the end of the process x == a1 (mod m1) where a1 and m1 are the new/modified residual
# and remainder.

return x
return a1


###############################################################################
Expand Down
17 changes: 16 additions & 1 deletion scripts/generate_int_test_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import sage
import numpy as np
from sage.all import Integer, xgcd, lcm, prod, isqrt, log
from sage.all import Integer, xgcd, lcm, prod, isqrt, log, crt

PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "tests")
FOLDER = os.path.join(PATH, "data")
Expand Down Expand Up @@ -120,3 +120,18 @@ def save_pickle(d, folder, name):
Z[i] = int(z)
d = {"X": X, "B": B, "Z": Z}
save_pickle(d, FOLDER, "ilog.pkl")

set_seed(SEED + 108)
N = [random.randint(2, 6) for _ in range(40)]
X = [[random.randint(0, 1000) for _ in range(N[i])] for i in range(20)] + [[random.randint(0, 1_000_000) for _ in range(N[20 + i])] for i in range(20)] # Remainder
Y = [[random.randint(10, 1000) for _ in range(N[i])] for i in range(20)] + [[random.randint(1000, 1_000_000) for _ in range(N[20 + i])] for i in range(20)] # Modulus
Z = [0,]*len(X) # The solution
for i in range(len(X)):
X[i] = [X[i][j] % Y[i][j] for j in range(len(X[i]))] # Ensure a is within [0, m)
try:
z = crt(X[i], Y[i])
Z[i] = int(z)
except:
Z[i] = None
d = {"X": X, "Y": Y, "Z": Z}
save_pickle(d, FOLDER, "crt.pkl")
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def power():
return read_pickle("power.pkl")


@pytest.fixture(scope="session")
def crt():
return read_pickle("crt.pkl")


@pytest.fixture(scope="session")
def isqrt():
return read_pickle("isqrt.pkl")
Expand Down
Binary file added tests/data/crt.pkl
Binary file not shown.
26 changes: 26 additions & 0 deletions tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
A pytest module to test the functions in _math.py.
"""
import pytest
import numpy as np

import galois

Expand Down Expand Up @@ -96,6 +97,31 @@ def test_pow(power):
assert galois.pow(X[i], E[i], M[i]) == Z[i]


def test_crt_exceptions():
with pytest.raises(TypeError):
galois.crt(np.array([0, 3, 4]), [3, 4, 5])
with pytest.raises(TypeError):
galois.crt([0, 3, 4], np.array([3, 4, 5]))
with pytest.raises(TypeError):
galois.crt([0, 3.0, 4], [3, 4, 5])
with pytest.raises(TypeError):
galois.crt([0, 3, 4], [3, 4.0, 5])
with pytest.raises(ValueError):
galois.crt([0, 3, 4], [3, 4, 5, 7])
with pytest.raises(ValueError):
galois.crt([0, 3, 4], [3, 4, 6])


def test_crt(crt):
X, Y, Z = crt["X"], crt["Y"], crt["Z"]
for i in range(len(X)):
if Z[i] is not None:
assert galois.crt(X[i], Y[i]) == Z[i]
else:
with pytest.raises(ValueError):
galois.crt(X[i], Y[i])


def test_isqrt_exceptions():
with pytest.raises(TypeError):
galois.isqrt(3.0)
Expand Down
37 changes: 0 additions & 37 deletions tests/test_number_theory.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,43 +37,6 @@ def test_totatives():
assert len(galois.totatives(n)) == phi


def test_crt_exceptions():
with pytest.raises(TypeError):
galois.crt(np.array([0, 3, 4]), [3, 4, 5])
with pytest.raises(TypeError):
galois.crt([0, 3, 4], np.array([3, 4, 5]))
with pytest.raises(TypeError):
galois.crt([0, 3.0, 4], [3, 4, 5])
with pytest.raises(TypeError):
galois.crt([0, 3, 4], [3, 4.0, 5])
with pytest.raises(ValueError):
galois.crt([0, 3, 4], [3, 4, 5, 7])
with pytest.raises(ValueError):
galois.crt([0, 3, 4], [3, 4, 6])


def test_crt():
"""
Sage:
lut = []
for _ in range(20):
N = randint(2, 6)
a = [randint(0, 1_000) for _ in range(N)]
m = []
while len(m) < N:
mi = next_prime(randint(0, 1_000))
if mi not in m:
m.append(mi)
x = crt(a, m)
lut.append((a, m, x))
print(lut)
"""
LUT = [([975, 426, 300, 372, 596, 856], [457, 331, 521, 701, 71, 907], 1139408681764819), ([85, 653, 323, 655], [331, 479, 601, 191], 10711106463), ([589, 538, 501], [347, 541, 947], 155375738), ([788, 821, 414], [673, 331, 149], 20497003), ([269, 703, 436, 641, 616], [929, 293, 541, 467, 853], 39214084831996), ([270, 190], [173, 277], 15148), ([518, 809, 857, 118], [349, 821, 937, 157], 38154123633), ([711, 735, 1000, 426, 522], [149, 281, 293, 97, 37], 43994914384), ([104, 722, 168, 478], [193, 977, 211, 607], 23841886088), ([64, 160, 626, 702, 883], [877, 907, 251, 307, 839], 6612150797141), ([428, 570, 418, 346, 436], [467, 541, 373, 907, 179], 14825927170624), ([904, 14, 690, 585], [577, 907, 223, 967], 1713097171), ([238, 213, 368, 909, 455, 995], [137, 359, 947, 463, 937, 113], 1425637682359276), ([624, 95, 467, 472, 447, 849], [439, 79, 757, 269, 449, 293], 358511203165372), ([692, 245, 191, 101, 992, 267], [197, 367, 419, 139, 233, 593], 528226613934229), ([767, 794, 410], [331, 727, 359], 35029835), ([938, 992], [547, 17], 1485), ([337, 286, 308, 602, 386, 855], [67, 241, 167, 113, 211, 659], 36856037086592), ([681, 418], [997, 739], 630785), ([897, 343, 555, 245], [701, 89, 827, 379], 10082200693)]
for item in LUT:
a, m, x = item
assert galois.crt(a, m) == x


def test_carmichael_lambda_exceptions():
with pytest.raises(TypeError):
galois.carmichael_lambda(20.0)
Expand Down

0 comments on commit 91f4651

Please sign in to comment.