From 194ef41c0c29c04f4d437e865ccb0c2079420c51 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Thu, 20 Oct 2022 18:55:04 -0500 Subject: [PATCH] Clean up some usage of the TensorType interface in Scan --- aesara/scan/op.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/aesara/scan/op.py b/aesara/scan/op.py index f0d5b2a4da..481eaf971d 100644 --- a/aesara/scan/op.py +++ b/aesara/scan/op.py @@ -161,8 +161,9 @@ def check_broadcast(v1, v2): which may wrongly be interpreted as broadcastable. """ - if not hasattr(v1, "broadcastable") and not hasattr(v2, "broadcastable"): + if not isinstance(v1.type, TensorType) and not isinstance(v2.type, TensorType): return + msg = ( "The broadcast pattern of the output of scan (%s) is " "inconsistent with the one provided in `output_info` " @@ -173,13 +174,13 @@ def check_broadcast(v1, v2): "them consistent, e.g. using aesara.tensor." "{unbroadcast, specify_broadcastable}." ) - size = min(len(v1.broadcastable), len(v2.broadcastable)) + size = min(v1.type.ndim, v2.type.ndim) for n, (b1, b2) in enumerate( - zip(v1.broadcastable[-size:], v2.broadcastable[-size:]) + zip(v1.type.broadcastable[-size:], v2.type.broadcastable[-size:]) ): if b1 != b2: - a1 = n + size - len(v1.broadcastable) + 1 - a2 = n + size - len(v2.broadcastable) + 1 + a1 = n + size - v1.type.ndim + 1 + a2 = n + size - v2.type.ndim + 1 raise TypeError(msg % (v1.type, v2.type, a1, b1, b2, a2)) @@ -628,6 +629,7 @@ def validate_inner_graph(self): type_input = self.inner_inputs[inner_iidx].type type_output = self.inner_outputs[inner_oidx].type if ( + # TODO: Use the `Type` interface for this type_input.dtype != type_output.dtype or type_input.broadcastable != type_output.broadcastable ):