From 838ebbaff8d0610c24fe2ac398a20553091fabb6 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Tue, 19 Nov 2024 15:21:49 +0000 Subject: [PATCH 1/2] core: (irdl) ParamAttrConstraint can infer recursively --- tests/irdl/test_attr_constraint.py | 40 ++++++++++++++++++++++++++++++ xdsl/irdl/constraints.py | 10 ++++++++ 2 files changed, 50 insertions(+) diff --git a/tests/irdl/test_attr_constraint.py b/tests/irdl/test_attr_constraint.py index 2836e8eae2..c2647d1cfd 100644 --- a/tests/irdl/test_attr_constraint.py +++ b/tests/irdl/test_attr_constraint.py @@ -2,16 +2,19 @@ import pytest +from xdsl.dialects.builtin import StringAttr from xdsl.ir import Attribute, ParametrizedAttribute from xdsl.irdl import ( AllOf, AnyAttr, AttrConstraint, BaseAttr, + ConstraintContext, EqAttrConstraint, ParamAttrConstraint, ParameterDef, VarConstraint, + eq, irdl_attr_definition, ) @@ -54,3 +57,40 @@ def test_attr_constraint_get_unique_base( constraint: AttrConstraint, expected: type[Attribute] | None ): assert constraint.get_unique_base() == expected + + +def test_param_attr_constraint_inference(): + @irdl_attr_definition + class WrapAttr(ParametrizedAttribute): + name = "wrap" + + inner: ParameterDef[Attribute] + + constr = ParamAttrConstraint( + WrapAttr, + ( + eq( + StringAttr("Hello"), + ), + ), + ) + + assert constr.can_infer(set()) + assert constr.infer(ConstraintContext()) == WrapAttr((StringAttr("Hello"),)) + + var_constr = ParamAttrConstraint( + WrapAttr, + ( + VarConstraint( + "T", + eq( + StringAttr("Hello"), + ), + ), + ), + ) + + assert var_constr.can_infer({"T"}) + assert var_constr.infer(ConstraintContext({"T": StringAttr("Hello")})) == WrapAttr( + (StringAttr("Hello"),) + ) diff --git a/xdsl/irdl/constraints.py b/xdsl/irdl/constraints.py index 15956c3c78..31e7b08820 100644 --- a/xdsl/irdl/constraints.py +++ b/xdsl/irdl/constraints.py @@ -422,6 +422,16 @@ def verify( for idx, param_constr in enumerate(self.param_constrs): param_constr.verify(attr.parameters[idx], constraint_context) + def can_infer(self, constraint_names: set[str]) -> bool: + return all(constr.can_infer(constraint_names) for constr in self.param_constrs) + + def infer(self, constraint_context: ConstraintContext) -> Attribute: + params = tuple( + constr.infer(constraint_context) for constr in self.param_constrs + ) + attr = self.base_attr.new(params) + return attr + def get_resolved_variables(self) -> set[str]: if not self.param_constrs: return set() From 8afe11eeb53f39e1ad40c50db6e1dcff3b88737d Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Tue, 19 Nov 2024 16:18:38 +0000 Subject: [PATCH 2/2] check for runtime_final --- tests/irdl/test_attr_constraint.py | 16 ++++++++++++++-- xdsl/irdl/constraints.py | 4 +++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/tests/irdl/test_attr_constraint.py b/tests/irdl/test_attr_constraint.py index c2647d1cfd..de5739848c 100644 --- a/tests/irdl/test_attr_constraint.py +++ b/tests/irdl/test_attr_constraint.py @@ -60,12 +60,14 @@ def test_attr_constraint_get_unique_base( def test_param_attr_constraint_inference(): - @irdl_attr_definition - class WrapAttr(ParametrizedAttribute): + class BaseWrapAttr(ParametrizedAttribute): name = "wrap" inner: ParameterDef[Attribute] + @irdl_attr_definition + class WrapAttr(BaseWrapAttr): ... + constr = ParamAttrConstraint( WrapAttr, ( @@ -94,3 +96,13 @@ class WrapAttr(ParametrizedAttribute): assert var_constr.infer(ConstraintContext({"T": StringAttr("Hello")})) == WrapAttr( (StringAttr("Hello"),) ) + + base_constr = ParamAttrConstraint( + BaseWrapAttr, + ( + eq( + StringAttr("Hello"), + ), + ), + ) + assert not base_constr.can_infer(set()) diff --git a/xdsl/irdl/constraints.py b/xdsl/irdl/constraints.py index 31e7b08820..33bfc1062b 100644 --- a/xdsl/irdl/constraints.py +++ b/xdsl/irdl/constraints.py @@ -423,7 +423,9 @@ def verify( param_constr.verify(attr.parameters[idx], constraint_context) def can_infer(self, constraint_names: set[str]) -> bool: - return all(constr.can_infer(constraint_names) for constr in self.param_constrs) + return is_runtime_final(self.base_attr) and all( + constr.can_infer(constraint_names) for constr in self.param_constrs + ) def infer(self, constraint_context: ConstraintContext) -> Attribute: params = tuple(