Skip to content

Commit

Permalink
Progress
Browse files Browse the repository at this point in the history
  • Loading branch information
CalebBell committed Oct 26, 2024
1 parent a390bb1 commit 11371df
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions tests/test_numerics_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,12 @@ def get_rtol(matrix):
return min(10 * cond * machine_eps,100*cond * machine_eps if cond > 1e8 else 1e-9)

def check_inv(matrix, rtol=None):
cond = np.linalg.cond(matrix)
just_return = False
try:
# This will fail for bad matrix (inconsistent size) inputs
cond = np.linalg.cond(matrix)
except:
just_return = True
py_fail = False
numpy_fail = False
try:
Expand All @@ -56,7 +61,7 @@ def check_inv(matrix, rtol=None):
except:
numpy_fail = True
if py_fail and not numpy_fail:
if cond > 1e14:
if not just_return and cond > 1e14:
# Let ill conditioned matrices pass
return
raise ValueError(f"Inconsistent failure states: Python Fail: {py_fail}, Numpy Fail: {numpy_fail}")
Expand All @@ -65,7 +70,8 @@ def check_inv(matrix, rtol=None):
if not py_fail and numpy_fail:
# We'll allow our inv to work with numbers closer to
return

if just_return:
return

# Convert result to numpy array if it isn't already
result = np.array(result)
Expand Down Expand Up @@ -98,9 +104,15 @@ def check_inv(matrix, rtol=None):
# to zero out anything too close to "zero" relative to the values used in the matrix
# This is very necessary, and was needed when testing on different CPU architectures
inv_norm = np.max(np.sum(np.abs(result), axis=1))
trivial_relative_to_norm = np.where(np.abs(result)/inv_norm < 100*thresh)
if cond < 1e10:
zero_thresh = 100*thresh
elif cond < 1e14:
zero_thresh = 1000*thresh
else:
zero_thresh = 10000*thresh
trivial_relative_to_norm = np.where(np.abs(result)/inv_norm < zero_thresh)
result[trivial_relative_to_norm] = 0.0
trivial_relative_to_norm = np.where(np.abs(expected)/inv_norm < 100*thresh)
trivial_relative_to_norm = np.where(np.abs(expected)/inv_norm < zero_thresh)
expected[trivial_relative_to_norm] = 0.0

if rtol is None:
Expand Down Expand Up @@ -506,7 +518,7 @@ def matrix_info(matrix):
[[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0],
[13.0, 14.0, 15.0, 16.0]], # Nearly singular
[13.0, 14.0, 15.0, 16.0]], # Singular

[[1.0, 0.1, 0.1, 0.1],
[0.1, 2.0, 0.2, 0.2],
Expand Down

0 comments on commit 11371df

Please sign in to comment.