Skip to content

Commit

Permalink
fix infer_shape_to_gufunc_sig
Browse files Browse the repository at this point in the history
  • Loading branch information
purna135 committed Mar 14, 2023
1 parent 1b04b0e commit 684914d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
3 changes: 2 additions & 1 deletion aesara/tensor/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ def infer_shape_to_gufunc_sig(node: Apply, fgraph: Optional["FunctionGraph"] = N
"""
op = node.op
in_shapes = tuple(
tuple(lscalar(f"i{s}") for s in range(inp.type.ndim)) for inp in node.inputs
tuple(lscalar(f"i{n}{s}") for s in range(inp.type.ndim))
for n, inp in enumerate(node.inputs)
)
out_shapes = op.infer_shape(fgraph, node, in_shapes)

Expand Down
10 changes: 6 additions & 4 deletions tests/tensor/test_blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,21 +231,23 @@ def test_blockwise_cholesky_grad(shape):
def test_infer_shape_to_gufunc_sig():
y = at.extract_diag(at.matrix("x"))
res = infer_shape_to_gufunc_sig(y.owner)
assert res == ((("i0", "i1"),), (("o0",),))
assert res == ((("i00", "i01"),), (("o0",),))

y = at.diag(at.vector("x"))
res = infer_shape_to_gufunc_sig(y.owner)
assert res == ((("i0",),), (("i0", "i0"),))
assert res == ((("i00",),), (("i00", "i00"),))

# check signature of dot
# '(m,n),(n,p)->(m,p)'
y = at.dot(at.matrix("x"), at.matrix("y"))
res = infer_shape_to_gufunc_sig(y.owner)
assert res == ((("i0", "i1"), ("i0", "i1")), (("i0", "i1"),))
assert res == ((("i00", "i01"), ("i10", "i11")), (("i00", "i11"),))

# check signature of det
# '(m,m)->()'
y = at.nlinalg.det(at.matrix("x"))
res = infer_shape_to_gufunc_sig(y.owner)
assert res == ((("i0", "i1"),), ((),))
assert res == ((("i00", "i01"),), ((),))


def test_Blockwise_get_output_info():
Expand Down

0 comments on commit 684914d

Please sign in to comment.