Skip to content

Commit

Permalink
core: (irdl) ParamAttrConstraint can infer recursively (xdslproject#3477
Browse files Browse the repository at this point in the history
)

This seems to work, is there any reason not to do this?
  • Loading branch information
superlopuh authored and EdmundGoodman committed Dec 6, 2024
1 parent 0cf1d16 commit 2eb5338
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 0 deletions.
52 changes: 52 additions & 0 deletions tests/irdl/test_attr_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@

import pytest

from xdsl.dialects.builtin import StringAttr
from xdsl.ir import Attribute, ParametrizedAttribute
from xdsl.irdl import (
AllOf,
AnyAttr,
AttrConstraint,
BaseAttr,
ConstraintContext,
EqAttrConstraint,
ParamAttrConstraint,
ParameterDef,
VarConstraint,
eq,
irdl_attr_definition,
)

Expand Down Expand Up @@ -54,3 +57,52 @@ def test_attr_constraint_get_unique_base(
constraint: AttrConstraint, expected: type[Attribute] | None
):
assert constraint.get_unique_base() == expected


def test_param_attr_constraint_inference():
class BaseWrapAttr(ParametrizedAttribute):
name = "wrap"

inner: ParameterDef[Attribute]

@irdl_attr_definition
class WrapAttr(BaseWrapAttr): ...

constr = ParamAttrConstraint(
WrapAttr,
(
eq(
StringAttr("Hello"),
),
),
)

assert constr.can_infer(set())
assert constr.infer(ConstraintContext()) == WrapAttr((StringAttr("Hello"),))

var_constr = ParamAttrConstraint(
WrapAttr,
(
VarConstraint(
"T",
eq(
StringAttr("Hello"),
),
),
),
)

assert var_constr.can_infer({"T"})
assert var_constr.infer(ConstraintContext({"T": StringAttr("Hello")})) == WrapAttr(
(StringAttr("Hello"),)
)

base_constr = ParamAttrConstraint(
BaseWrapAttr,
(
eq(
StringAttr("Hello"),
),
),
)
assert not base_constr.can_infer(set())
12 changes: 12 additions & 0 deletions xdsl/irdl/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,18 @@ def verify(
for idx, param_constr in enumerate(self.param_constrs):
param_constr.verify(attr.parameters[idx], constraint_context)

def can_infer(self, constraint_names: set[str]) -> bool:
return is_runtime_final(self.base_attr) and all(
constr.can_infer(constraint_names) for constr in self.param_constrs
)

def infer(self, constraint_context: ConstraintContext) -> Attribute:
params = tuple(
constr.infer(constraint_context) for constr in self.param_constrs
)
attr = self.base_attr.new(params)
return attr

def get_resolved_variables(self) -> set[str]:
if not self.param_constrs:
return set()
Expand Down

0 comments on commit 2eb5338

Please sign in to comment.