Skip to content

Commit

Permalink
Clean up some usage of the TensorType interface in Scan
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Oct 20, 2022
1 parent 41aa8c2 commit 194ef41
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions aesara/scan/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,9 @@ def check_broadcast(v1, v2):
which may wrongly be interpreted as broadcastable.
"""
if not hasattr(v1, "broadcastable") and not hasattr(v2, "broadcastable"):
if not isinstance(v1.type, TensorType) and not isinstance(v2.type, TensorType):
return

msg = (
"The broadcast pattern of the output of scan (%s) is "
"inconsistent with the one provided in `output_info` "
Expand All @@ -173,13 +174,13 @@ def check_broadcast(v1, v2):
"them consistent, e.g. using aesara.tensor."
"{unbroadcast, specify_broadcastable}."
)
size = min(len(v1.broadcastable), len(v2.broadcastable))
size = min(v1.type.ndim, v2.type.ndim)
for n, (b1, b2) in enumerate(
zip(v1.broadcastable[-size:], v2.broadcastable[-size:])
zip(v1.type.broadcastable[-size:], v2.type.broadcastable[-size:])
):
if b1 != b2:
a1 = n + size - len(v1.broadcastable) + 1
a2 = n + size - len(v2.broadcastable) + 1
a1 = n + size - v1.type.ndim + 1
a2 = n + size - v2.type.ndim + 1
raise TypeError(msg % (v1.type, v2.type, a1, b1, b2, a2))


Expand Down Expand Up @@ -628,6 +629,7 @@ def validate_inner_graph(self):
type_input = self.inner_inputs[inner_iidx].type
type_output = self.inner_outputs[inner_oidx].type
if (
# TODO: Use the `Type` interface for this
type_input.dtype != type_output.dtype
or type_input.broadcastable != type_output.broadcastable
):
Expand Down

0 comments on commit 194ef41

Please sign in to comment.