Skip to content

Commit

Permalink
Add more np.linalg.solve() unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mhostetter committed Feb 10, 2022
1 parent edd2e0f commit f27480c
Show file tree
Hide file tree
Showing 19 changed files with 50 additions and 23 deletions.
25 changes: 24 additions & 1 deletion scripts/generate_field_test_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import sage
import numpy as np
from sage.all import GF, PolynomialRing, log, matrix
from sage.all import GF, PolynomialRing, log, matrix, vector

FIELD = None
SPARSE_SIZE = 20
Expand Down Expand Up @@ -439,6 +439,29 @@ def make_luts(field, sub_folder, seed, sparse=False):
d = {"X": X, "Z": Z}
save_pickle(d, folder, "matrix_determinant.pkl")

set_seed(seed + 207)
shapes = [(2,2), (2,2), (2,2), (3,3), (3,3), (3,3), (4,4), (4,4), (4,4), (5,5), (5,5), (5,5), (6,6), (6,6), (6,6)]
X = []
Y = []
Z = []
for i in range(len(shapes)):
while True:
x = randint_matrix(0, order, shapes[i])
x_orig = x.copy()
dtype = x.dtype
x = matrix(FIELD, [[F(e) for e in row] for row in x])
if x.rank() == shapes[i][0]:
break
X.append(x_orig)
y = randint_matrix(0, order, shapes[i][1]) # 1-D vector
Y.append(y)
y = vector(FIELD, [F(e) for e in y])
z = x.solve_right(y)
z = np.array([I(e) for e in z], dtype)
Z.append(z)
d = {"X": X, "Y": Y, "Z": Z}
save_pickle(d, folder, "matrix_solve.pkl")

###############################################################################
# Polynomial arithmetic
###############################################################################
Expand Down
13 changes: 13 additions & 0 deletions tests/fields/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,19 @@ def field_matrix_determinant(field_folder):
return d


@pytest.fixture(scope="session")
def field_matrix_solve(field_folder):
GF, folder = field_folder
with open(os.path.join(folder, "matrix_solve.pkl"), "rb") as f:
print(f"Loading {f}...")
d = pickle.load(f)
d["GF"] = GF
d["X"] = [GF(x) for x in d["X"]]
d["Y"] = [GF(y) for y in d["Y"]]
d["Z"] = [GF(z) for z in d["Z"]]
return d


###############################################################################
# Fixtures for arithmetic methods over finite fields
###############################################################################
Expand Down
Binary file added tests/fields/data/GF(109987^4)/matrix_solve.pkl
Binary file not shown.
Binary file added tests/fields/data/GF(2)/matrix_solve.pkl
Binary file not shown.
Binary file added tests/fields/data/GF(2147483647)/matrix_solve.pkl
Binary file not shown.
Binary file added tests/fields/data/GF(2^100)/matrix_solve.pkl
Binary file not shown.
Binary file added tests/fields/data/GF(2^2)/matrix_solve.pkl
Binary file not shown.
Binary file added tests/fields/data/GF(2^3)/matrix_solve.pkl
Binary file not shown.
Binary file added tests/fields/data/GF(2^32)/matrix_solve.pkl
Binary file not shown.
Binary file added tests/fields/data/GF(2^8)/matrix_solve.pkl
Binary file not shown.
Binary file not shown.
Binary file added tests/fields/data/GF(31)/matrix_solve.pkl
Binary file not shown.
Binary file added tests/fields/data/GF(3191)/matrix_solve.pkl
Binary file not shown.
Binary file not shown.
Binary file added tests/fields/data/GF(5)/matrix_solve.pkl
Binary file not shown.
Binary file added tests/fields/data/GF(7)/matrix_solve.pkl
Binary file not shown.
Binary file added tests/fields/data/GF(7^3)/matrix_solve.pkl
Binary file not shown.
Binary file not shown.
35 changes: 13 additions & 22 deletions tests/fields/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,28 +223,6 @@ def test_matmul_2d_2d(field):
# assert array_equal(A @ B, np.matmul(A, B))


def test_solve_2d_1d(field):
dtype = random.choice(field.dtypes)
A = full_rank_matrix(field, 3, dtype)
b = field.Random(3, dtype=dtype)
x = np.linalg.solve(A, b)
assert type(x) is field
assert x.dtype == dtype
assert array_equal(A @ x, b)
assert x.shape == b.shape


def test_solve_2d_2d(field):
dtype = random.choice(field.dtypes)
A = full_rank_matrix(field, 3, dtype)
b = field.Random((3,5), dtype=dtype)
x = np.linalg.solve(A, b)
assert type(x) is field
assert x.dtype == dtype
assert array_equal(A @ x, b)
assert x.shape == b.shape


def full_rank_matrix(field, n, dtype):
A = field.Identity(n, dtype=dtype)
while True:
Expand Down Expand Up @@ -329,3 +307,16 @@ def test_matrix_determinant(field_matrix_determinant):
z = np.linalg.det(x)
assert z == Z[i]
assert type(z) is GF


def test_matrix_solve(field_matrix_solve):
GF, X, Y, Z = field_matrix_solve["GF"], field_matrix_solve["X"], field_matrix_solve["Y"], field_matrix_solve["Z"]

# np.linalg.solve(x, y) = z corresponds to X @ z = y
for i in range(len(X)):
dtype = random.choice(GF.dtypes)
x = X[i].astype(dtype)
y = Y[i].astype(dtype)
z = np.linalg.solve(x, y)
assert np.array_equal(z, Z[i])
assert type(z) is GF

0 comments on commit f27480c

Please sign in to comment.