diff --git a/xdsl/dialects/stencil.py b/xdsl/dialects/stencil.py index a186f12fde..b23c61206b 100644 --- a/xdsl/dialects/stencil.py +++ b/xdsl/dialects/stencil.py @@ -1276,18 +1276,22 @@ def verify_(self) -> None: class TensorIgnoreSizeConstraint(VarConstraint[Attribute]): + + @staticmethod + def matches(attr: TensorType[Attribute], other: Attribute) -> bool: + return ( + isa(other, TensorType[Attribute]) + and len(attr.get_shape()) == len(other.get_shape()) + and attr.get_element_type() == other.get_element_type() + ) + def verify( self, attr: Attribute, constraint_context: ConstraintContext | None = None ) -> None: constraint_context = constraint_context or ConstraintContext() if self.name in constraint_context.variables: - if ( - isa(attr, TensorType[Attribute]) - and isinstance( - other := constraint_context.variables[self.name], TensorType - ) - and len(attr.get_shape()) == len(other.get_shape()) - and attr.get_element_type() == other.get_element_type() + if isa(attr, TensorType[Attribute]) and TensorIgnoreSizeConstraint.matches( + attr, constraint_context.variables[self.name] ): return super().verify(attr, constraint_context) @@ -1454,7 +1458,10 @@ def verify_(self) -> None: for i, res_type in enumerate(res_types): for j in range(unroll_factor * i, unroll_factor * (i + 1)): op_type = types[j] - if op_type != res_type: + if op_type != res_type and not ( + isa(op_type, TensorType[Attribute]) + and TensorIgnoreSizeConstraint.matches(op_type, res_type) + ): raise VerifyException( "stencil.return expected operand types to match the parent " f"stencil.apply result element types. Got {op_type} at index "