Skip to content

Commit

Permalink
add test for Blockwise SolveTriangular
Browse files Browse the repository at this point in the history
  • Loading branch information
purna135 committed Mar 14, 2023
1 parent 684914d commit 3ed3497
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
1 change: 0 additions & 1 deletion aesara/tensor/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,6 @@ class SolveTriangular(SolveBase):
"trans",
"unit_diagonal",
"check_finite",
"gufunc_sig",
)

def __init__(
Expand Down
31 changes: 30 additions & 1 deletion tests/tensor/test_blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import aesara
import aesara.tensor as at
from aesara.configdefaults import config
from aesara.tensor.basic import Tri
from aesara.tensor.blockwise import (
Blockwise,
_calculate_shapes,
Expand All @@ -14,7 +15,7 @@
)
from aesara.tensor.math import Dot
from aesara.tensor.nlinalg import Det
from aesara.tensor.slinalg import Cholesky, Solve
from aesara.tensor.slinalg import Cholesky, Solve, SolveTriangular
from aesara.tensor.type import TensorType
from tests import unittest_tools as utt
from tests.unittest_tools import check_infer_shape, verify_grad
Expand Down Expand Up @@ -100,6 +101,12 @@ def test_parse_input_dimensions(args, arg_vals, input_core_dims, output_core_dim
(np.zeros((5, 3, 3)),),
lambda x: np.linalg.det(x),
),
(
Tri(),
(at.scalar(), at.scalar(), at.scalar()),
(3, 4, 0),
lambda n, m, k: np.tri(n, m, k),
),
],
)
def test_Blockwise_perform(op, args, arg_vals, np_fn):
Expand Down Expand Up @@ -261,3 +268,25 @@ def test_Blockwise_get_output_info():
out_dtype, output_shapes, inputs = blk_op.get_output_info(a, b, c)

assert out_dtype == ["float64"]


@pytest.mark.parametrize(
"a_shape, b_shape",
[
(
(3, 3),
(3, 1),
)
],
)
def test_blockwise_SolveTriangular_grad(a_shape, b_shape):
rng = np.random.default_rng(utt.fetch_seed())
A_val = (rng.normal(size=a_shape) * 0.5 + np.eye(a_shape[-1])).astype(config.floatX)
b_val = rng.normal(size=b_shape).astype(config.floatX)

eps = None
if config.floatX == "float64":
eps = 2e-8

solve_op = Blockwise(SolveTriangular())
verify_grad(solve_op, [A_val, b_val], 3, rng, eps=eps)

0 comments on commit 3ed3497

Please sign in to comment.