Skip to content

Commit

Permalink
use Blockwise instead of Elemwise
Browse files Browse the repository at this point in the history
  • Loading branch information
purna135 committed Nov 9, 2022
1 parent 90f6f6e commit 877d04d
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions aesara/tensor/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,29 +281,28 @@ def transform(var: "TensorVariable", client_node: Optional[Apply]) -> Variable:
return atleast_Nd(res, nd)

blocked_inputs = [transform(ipt, node) for ipt in node.inputs]

grad_signature = getattr(node.op, "gufunc_sig", None)
op = node.op

if grad_signature is None:
if isinstance(node.op, DimShuffle):
if isinstance(op, DimShuffle):
# remove the extra dimensions that
# we have added during op creation
new_order = [i for i in node.op.new_order if i != "x"]
new_order = [i for i in op.new_order if i != "x"]

# derive gufunc signature for DimShuffle
input_signature = tuple([f"a{i}" for i in range(len(new_order))])
output_signature = tuple([f"a{i}" for i in new_order])
grad_signature = ((input_signature,), (output_signature,))
elif isinstance(node.op, Elemwise):
elif isinstance(op, Elemwise):
op = op.scalar_op
input_signature = ((),) * len(blocked_inputs)
output_signature = ((),)
grad_signature = (input_signature, output_signature)
else:
raise ValueError(
f"'{node.op}' object has no attribute 'gufunc_sig'"
)
raise ValueError(f"'{op}' object has no attribute 'gufunc_sig'")

new_r = Blockwise(node.op, signature=grad_signature)(*blocked_inputs)
new_r = Blockwise(op, signature=grad_signature)(*blocked_inputs)
assert isinstance(new_r, Variable)
return new_r

Expand Down

0 comments on commit 877d04d

Please sign in to comment.