diff --git a/aesara/scan/op.py b/aesara/scan/op.py index f0d5b2a4da..481eaf971d 100644 --- a/aesara/scan/op.py +++ b/aesara/scan/op.py @@ -161,8 +161,9 @@ def check_broadcast(v1, v2): which may wrongly be interpreted as broadcastable. """ - if not hasattr(v1, "broadcastable") and not hasattr(v2, "broadcastable"): + if not isinstance(v1.type, TensorType) and not isinstance(v2.type, TensorType): return + msg = ( "The broadcast pattern of the output of scan (%s) is " "inconsistent with the one provided in `output_info` " @@ -173,13 +174,13 @@ def check_broadcast(v1, v2): "them consistent, e.g. using aesara.tensor." "{unbroadcast, specify_broadcastable}." ) - size = min(len(v1.broadcastable), len(v2.broadcastable)) + size = min(v1.type.ndim, v2.type.ndim) for n, (b1, b2) in enumerate( - zip(v1.broadcastable[-size:], v2.broadcastable[-size:]) + zip(v1.type.broadcastable[-size:], v2.type.broadcastable[-size:]) ): if b1 != b2: - a1 = n + size - len(v1.broadcastable) + 1 - a2 = n + size - len(v2.broadcastable) + 1 + a1 = n + size - v1.type.ndim + 1 + a2 = n + size - v2.type.ndim + 1 raise TypeError(msg % (v1.type, v2.type, a1, b1, b2, a2)) @@ -628,6 +629,7 @@ def validate_inner_graph(self): type_input = self.inner_inputs[inner_iidx].type type_output = self.inner_outputs[inner_oidx].type if ( + # TODO: Use the `Type` interface for this type_input.dtype != type_output.dtype or type_input.broadcastable != type_output.broadcastable ):