@@ -417,7 +417,12 @@ def test_correctness(self, b_shape: tuple[int], lower, trans, unit_diagonal):
417
417
unit_diagonal = unit_diagonal ,
418
418
)
419
419
420
- np .testing .assert_allclose (x_pt , x_sp )
420
+ np .testing .assert_allclose (
421
+ x_pt ,
422
+ x_sp ,
423
+ atol = 1e-8 if config .floatX == "float64" else 1e-4 ,
424
+ rtol = 1e-8 if config .floatX == "float64" else 1e-4 ,
425
+ )
421
426
422
427
@pytest .mark .parametrize (
423
428
"b_shape" , [(5 , 1 ), (5 ,), (5 , 5 )], ids = ["b_col_vec" , "b_vec" , "b_matrix" ]
@@ -426,6 +431,9 @@ def test_correctness(self, b_shape: tuple[int], lower, trans, unit_diagonal):
426
431
@pytest .mark .parametrize ("trans" , [0 , 1 ])
427
432
@pytest .mark .parametrize ("unit_diagonal" , [True , False ])
428
433
def test_solve_triangular_grad (self , b_shape , lower , trans , unit_diagonal ):
434
+ if config .floatX == "float32" :
435
+ pytest .skip (reason = "Not enough precision in float32 to get a good gradient" )
436
+
429
437
rng = np .random .default_rng (utt .fetch_seed ())
430
438
A_val = rng .normal (size = (5 , 5 )).astype (config .floatX )
431
439
b_val = rng .normal (size = b_shape ).astype (config .floatX )
0 commit comments