From e5b5294a6aaf9a043e70a56e6964edf35bb60959 Mon Sep 17 00:00:00 2001 From: Maxim Kochurov Date: Thu, 29 Dec 2022 18:10:18 +0300 Subject: [PATCH] Revert "Clean up some usage of the TensorType interface in Scan" This reverts commit 471657a5ac3e9868797bcd779a4e635915999c5d. --- pytensor/scan/op.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index 189b167c34..d7c5c2886a 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -156,9 +156,8 @@ def check_broadcast(v1, v2): which may wrongly be interpreted as broadcastable. """ - if not isinstance(v1.type, TensorType) and not isinstance(v2.type, TensorType): + if not hasattr(v1, "broadcastable") and not hasattr(v2, "broadcastable"): return - msg = ( "The broadcast pattern of the output of scan (%s) is " "inconsistent with the one provided in `output_info` " @@ -169,13 +168,13 @@ def check_broadcast(v1, v2): "them consistent, e.g. using pytensor.tensor." "{unbroadcast, specify_broadcastable}." ) - size = min(v1.type.ndim, v2.type.ndim) + size = min(len(v1.broadcastable), len(v2.broadcastable)) for n, (b1, b2) in enumerate( - zip(v1.type.broadcastable[-size:], v2.type.broadcastable[-size:]) + zip(v1.broadcastable[-size:], v2.broadcastable[-size:]) ): if b1 != b2: - a1 = n + size - v1.type.ndim + 1 - a2 = n + size - v2.type.ndim + 1 + a1 = n + size - len(v1.broadcastable) + 1 + a2 = n + size - len(v2.broadcastable) + 1 raise TypeError(msg % (v1.type, v2.type, a1, b1, b2, a2)) @@ -624,7 +623,6 @@ def validate_inner_graph(self): type_input = self.inner_inputs[inner_iidx].type type_output = self.inner_outputs[inner_oidx].type if ( - # TODO: Use the `Type` interface for this type_input.dtype != type_output.dtype or type_input.broadcastable != type_output.broadcastable ):