Skip to content

Commit

Permalink
Revert "Clean up some usage of the TensorType interface in Scan"
Browse files Browse the repository at this point in the history
This reverts commit 471657a.
  • Loading branch information
ferrine committed Jan 16, 2023
1 parent 4d8a6a6 commit 5facd79
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions pytensor/scan/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,8 @@ def check_broadcast(v1, v2):
which may wrongly be interpreted as broadcastable.
"""
if not isinstance(v1.type, TensorType) and not isinstance(v2.type, TensorType):
if not hasattr(v1, "broadcastable") and not hasattr(v2, "broadcastable"):
return

msg = (
"The broadcast pattern of the output of scan (%s) is "
"inconsistent with the one provided in `output_info` "
Expand All @@ -169,13 +168,13 @@ def check_broadcast(v1, v2):
"them consistent, e.g. using pytensor.tensor."
"{unbroadcast, specify_broadcastable}."
)
size = min(v1.type.ndim, v2.type.ndim)
size = min(len(v1.broadcastable), len(v2.broadcastable))
for n, (b1, b2) in enumerate(
zip(v1.type.broadcastable[-size:], v2.type.broadcastable[-size:])
zip(v1.broadcastable[-size:], v2.broadcastable[-size:])
):
if b1 != b2:
a1 = n + size - v1.type.ndim + 1
a2 = n + size - v2.type.ndim + 1
a1 = n + size - len(v1.broadcastable) + 1
a2 = n + size - len(v2.broadcastable) + 1
raise TypeError(msg % (v1.type, v2.type, a1, b1, b2, a2))


Expand Down Expand Up @@ -624,7 +623,6 @@ def validate_inner_graph(self):
type_input = self.inner_inputs[inner_iidx].type
type_output = self.inner_outputs[inner_oidx].type
if (
# TODO: Use the `Type` interface for this
type_input.dtype != type_output.dtype
or type_input.broadcastable != type_output.broadcastable
):
Expand Down

0 comments on commit 5facd79

Please sign in to comment.