Skip to content

Commit 4b6e2ad

Browse files
Tune float32 tests
1 parent 7a08ff3 commit 4b6e2ad

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

tests/tensor/test_slinalg.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,12 @@ def test_correctness(self, b_shape: tuple[int], lower, trans, unit_diagonal):
417417
unit_diagonal=unit_diagonal,
418418
)
419419

420-
np.testing.assert_allclose(x_pt, x_sp)
420+
np.testing.assert_allclose(
421+
x_pt,
422+
x_sp,
423+
atol=1e-8 if config.floatX == "float64" else 1e-4,
424+
rtol=1e-8 if config.floatX == "float64" else 1e-4,
425+
)
421426

422427
@pytest.mark.parametrize(
423428
"b_shape", [(5, 1), (5,), (5, 5)], ids=["b_col_vec", "b_vec", "b_matrix"]
@@ -426,6 +431,9 @@ def test_correctness(self, b_shape: tuple[int], lower, trans, unit_diagonal):
426431
@pytest.mark.parametrize("trans", [0, 1])
427432
@pytest.mark.parametrize("unit_diagonal", [True, False])
428433
def test_solve_triangular_grad(self, b_shape, lower, trans, unit_diagonal):
434+
if config.floatX == "float32":
435+
pytest.skip(reason="Not enough precision in float32 to get a good gradient")
436+
429437
rng = np.random.default_rng(utt.fetch_seed())
430438
A_val = rng.normal(size=(5, 5)).astype(config.floatX)
431439
b_val = rng.normal(size=b_shape).astype(config.floatX)

0 commit comments

Comments
 (0)