From 684914de70b7bca6990f28a8650f1d425f3994f9 Mon Sep 17 00:00:00 2001 From: Purna Chandra Mansingh Date: Tue, 14 Mar 2023 20:19:35 +0530 Subject: [PATCH] fix infer_shape_to_gufunc_sig --- aesara/tensor/blockwise.py | 3 ++- tests/tensor/test_blockwise.py | 10 ++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/aesara/tensor/blockwise.py b/aesara/tensor/blockwise.py index 014b25a7ec..fd79b56119 100644 --- a/aesara/tensor/blockwise.py +++ b/aesara/tensor/blockwise.py @@ -149,7 +149,8 @@ def infer_shape_to_gufunc_sig(node: Apply, fgraph: Optional["FunctionGraph"] = N """ op = node.op in_shapes = tuple( - tuple(lscalar(f"i{s}") for s in range(inp.type.ndim)) for inp in node.inputs + tuple(lscalar(f"i{n}{s}") for s in range(inp.type.ndim)) + for n, inp in enumerate(node.inputs) ) out_shapes = op.infer_shape(fgraph, node, in_shapes) diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py index 55f12410a5..d8b95a1587 100644 --- a/tests/tensor/test_blockwise.py +++ b/tests/tensor/test_blockwise.py @@ -231,21 +231,23 @@ def test_blockwise_cholesky_grad(shape): def test_infer_shape_to_gufunc_sig(): y = at.extract_diag(at.matrix("x")) res = infer_shape_to_gufunc_sig(y.owner) - assert res == ((("i0", "i1"),), (("o0",),)) + assert res == ((("i00", "i01"),), (("o0",),)) y = at.diag(at.vector("x")) res = infer_shape_to_gufunc_sig(y.owner) - assert res == ((("i0",),), (("i0", "i0"),)) + assert res == ((("i00",),), (("i00", "i00"),)) # check signature of dot + # '(m,n),(n,p)->(m,p)' y = at.dot(at.matrix("x"), at.matrix("y")) res = infer_shape_to_gufunc_sig(y.owner) - assert res == ((("i0", "i1"), ("i0", "i1")), (("i0", "i1"),)) + assert res == ((("i00", "i01"), ("i10", "i11")), (("i00", "i11"),)) # check signature of det + # '(m,m)->()' y = at.nlinalg.det(at.matrix("x")) res = infer_shape_to_gufunc_sig(y.owner) - assert res == ((("i0", "i1"),), ((),)) + assert res == ((("i00", "i01"),), ((),)) def test_Blockwise_get_output_info():