From 2d1f865dce160992df01e6de13b365833b3eaab6 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Sat, 9 Apr 2022 05:17:57 +0100 Subject: [PATCH 01/36] Make `Data` generic over the contained type. --- src/xdsl/dialects/builtin.py | 34 ++++++++++++++++----------------- src/xdsl/ir.py | 20 +++++++++++-------- src/xdsl/parser.py | 4 ++-- src/xdsl/printer.py | 2 +- tests/attribute_builder_test.py | 30 +++++++++++++++-------------- tests/operation_builder_test.py | 8 ++++---- 6 files changed, 52 insertions(+), 46 deletions(-) diff --git a/src/xdsl/dialects/builtin.py b/src/xdsl/dialects/builtin.py index 66eb219881..313b1a8d51 100644 --- a/src/xdsl/dialects/builtin.py +++ b/src/xdsl/dialects/builtin.py @@ -36,17 +36,17 @@ def __post_init__(self): @irdl_attr_definition -class StringAttr(Data): +class StringAttr(Data[str]): name = "string" - data: str @staticmethod - def parse(parser: Parser) -> StringAttr: + def parse_parameter(parser: Parser) -> str: data = parser.parse_str_literal() - return StringAttr(data) + return data - def print(self, printer: Printer) -> None: - printer.print_string(f'"{self.data}"') + @staticmethod + def print_parameter(data: str, printer: Printer) -> None: + printer.print_string(f'"{data}"') @staticmethod @builder @@ -87,17 +87,17 @@ def from_string_attr(data: StringAttr) -> FlatSymbolRefAttr: @irdl_attr_definition -class IntAttr(Data): +class IntAttr(Data[int]): name = "int" - data: int @staticmethod - def parse(parser: Parser) -> IntAttr: + def parse_parameter(parser: Parser) -> int: data = parser.parse_int_literal() - return IntAttr(data) + return data - def print(self, printer: Printer) -> None: - printer.print_string(f'{self.data}') + @staticmethod + def print_parameter(data: int, printer: Printer) -> None: + printer.print_string(f'{data}') @staticmethod @builder @@ -155,18 +155,18 @@ def from_params(value: Union[int, IntAttr], @irdl_attr_definition -class ArrayAttr(Data): +class ArrayAttr(Data[List[Attribute]]): name = "array" - data: List[Attribute] @staticmethod - def parse(parser) -> ArrayAttr: + def parse_parameter(parser: Parser) -> List[Attribute]: parser.parse_char("[") data = parser.parse_list(parser.parse_optional_attribute) parser.parse_char("]") - return ArrayAttr.get(data) + return data - def print(self, printer) -> None: + @staticmethod + def print_parameter(data: List[Attribute], printer: Printer) -> None: printer.print_string("[") printer.print_list(self.data, printer.print_attribute) printer.print_string("]") diff --git a/src/xdsl/ir.py b/src/xdsl/ir.py index 29dc80124c..2041e9b79e 100644 --- a/src/xdsl/ir.py +++ b/src/xdsl/ir.py @@ -1,7 +1,7 @@ from __future__ import annotations from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import (Dict, List, Callable, Optional, Any, TYPE_CHECKING, +from typing import (Dict, Generic, List, Callable, Optional, Any, TYPE_CHECKING, TypeVar, Set, Union, Tuple) import typing from frozenlist import FrozenList @@ -192,20 +192,24 @@ def build(cls: typing.Type[AttrClass], *args) -> AttrClass: assert False +DataElement = TypeVar("DataElement") + + @dataclass(frozen=True) -class Data(Attribute): +class Data(Generic[DataElement], Attribute): """An attribute represented by a Python structure.""" + data: DataElement + @staticmethod @abstractmethod - def parse(parser: Parser) -> Data: - """Parse the attribute value.""" - ... + def parse_parameter(parser: Parser) -> DataElement: + """Parse the attribute parameter.""" + @staticmethod @abstractmethod - def print(self, printer: Printer) -> None: - """Print the attribute value.""" - ... + def print_parameter(data: DataElement, printer: Printer) -> None: + """Print the attribute parameter.""" @dataclass(frozen=True) diff --git a/src/xdsl/parser.py b/src/xdsl/parser.py index 7984a4a96b..b55b32f1e6 100644 --- a/src/xdsl/parser.py +++ b/src/xdsl/parser.py @@ -328,9 +328,9 @@ def parse_optional_attribute(self) -> Optional[Attribute]: return attr_def() if issubclass(attr_def, Data): - attr = attr_def.parse(self) + attr = attr_def.parse_parameter(self) self.parse_char(">") - return attr + return attr_def(attr) param_list = self.parse_list(self.parse_optional_attribute) self.parse_char(">") diff --git a/src/xdsl/printer.py b/src/xdsl/printer.py index 99b3706bdd..94c037bf65 100644 --- a/src/xdsl/printer.py +++ b/src/xdsl/printer.py @@ -267,7 +267,7 @@ def print_attribute(self, attribute: Attribute) -> None: if isinstance(attribute, Data): self.print(f'!{attribute.name}<') - attribute.print(self) + attribute.print_parameter(attribute.data, self) self.print(">") return diff --git a/tests/attribute_builder_test.py b/tests/attribute_builder_test.py index 68a0f4d6dd..e0f97f95eb 100644 --- a/tests/attribute_builder_test.py +++ b/tests/attribute_builder_test.py @@ -25,9 +25,8 @@ def test_no_builder_exception(): @irdl_attr_definition -class OneBuilderAttr(Data): +class OneBuilderAttr(Data[str]): name = "test.one_builder_attr" - param: str @staticmethod @builder @@ -35,10 +34,11 @@ def from_int(i: int) -> OneBuilderAttr: return OneBuilderAttr(str(i)) @staticmethod - def parse(parser: Parser) -> Data: + def parse_parameter(parser: Parser) -> str: pass - def print(self, printer: Printer) -> None: + @staticmethod + def print_parameter(data: str, printer: Printer) -> None: pass @@ -58,9 +58,8 @@ def test_one_builder_exception(): @irdl_attr_definition -class TwoBuildersAttr(Data): +class TwoBuildersAttr(Data[str]): name = "test.two_builder_attr" - param: str @staticmethod @builder @@ -73,10 +72,11 @@ def from_str(s: str) -> TwoBuildersAttr: return TwoBuildersAttr(s) @staticmethod - def parse(parser: Parser) -> Data: + def parse_parameter(parser: Parser) -> str: pass - def print(self, printer: Printer) -> None: + @staticmethod + def print_parameter(data: str, printer: Printer) -> None: pass @@ -101,7 +101,7 @@ def test_two_builders_bad_args(): @irdl_attr_definition -class BuilderDefaultArgAttr(Data): +class BuilderDefaultArgAttr(Data[str]): name = "test.builder_default_arg_attr" param: str @@ -111,10 +111,11 @@ def from_int(i: int, j: int = 0) -> BuilderDefaultArgAttr: return BuilderDefaultArgAttr(str(i)) @staticmethod - def parse(parser: Parser) -> Data: + def parse_parameter(parser: Parser) -> str: pass - def print(self, printer: Printer) -> None: + @staticmethod + def print_parameter(data: str, printer: Printer) -> None: pass @@ -129,7 +130,7 @@ def builder_default_arg_arg(): @irdl_attr_definition -class BuilderUnionArgAttr(Data): +class BuilderUnionArgAttr(Data[str]): name = "test.builder_union_arg_attr" param: str @@ -139,10 +140,11 @@ def from_int(i: typing.Union[str, int]) -> BuilderUnionArgAttr: return BuilderUnionArgAttr(str(i)) @staticmethod - def parse(parser: Parser) -> Data: + def parse_parameter(parser: Parser) -> str: pass - def print(self, printer: Printer) -> None: + @staticmethod + def print_parameter(data: str, printer: Printer) -> None: pass diff --git a/tests/operation_builder_test.py b/tests/operation_builder_test.py index 15557ea509..5fdb32819e 100644 --- a/tests/operation_builder_test.py +++ b/tests/operation_builder_test.py @@ -11,9 +11,8 @@ @irdl_attr_definition -class StringAttr(Data): +class StringAttr(Data[str]): name = "test.string_attr" - param: str @staticmethod @builder @@ -21,10 +20,11 @@ def from_int(i: int) -> StringAttr: return StringAttr(str(i)) @staticmethod - def parse(parser: Parser) -> Data: + def parse_parameter(parser: Parser) -> str: pass - def print(self, printer: Printer) -> None: + @staticmethod + def print_parameter(data: str, printer: Printer) -> None: pass From 12ad8a97cfdf6466830fe00bfdab55617b104562 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Wed, 13 Apr 2022 03:31:07 +0100 Subject: [PATCH 02/36] Add .vscode to .gitignore --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 46a0e78db4..4ad1accde0 100644 --- a/.gitignore +++ b/.gitignore @@ -138,4 +138,5 @@ dmypy.json cython_debug/ # PyCharm IDE -*.idea/ \ No newline at end of file +*.idea/ +.vscode/settings.json From 1139c8cefb188639b326d4172a6372ecf0a13678 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Wed, 13 Apr 2022 03:28:41 +0100 Subject: [PATCH 03/36] Improve ParameterDef to have better type hints --- src/xdsl/dialects/builtin.py | 28 ++++---- src/xdsl/dialects/memref.py | 6 +- src/xdsl/ir.py | 9 +-- src/xdsl/irdl.py | 136 ++++++++++++++++++++++++++++------- src/xdsl/util.py | 6 +- 5 files changed, 137 insertions(+), 48 deletions(-) diff --git a/src/xdsl/dialects/builtin.py b/src/xdsl/dialects/builtin.py index 313b1a8d51..e3384c1573 100644 --- a/src/xdsl/dialects/builtin.py +++ b/src/xdsl/dialects/builtin.py @@ -3,7 +3,7 @@ from xdsl.irdl import * from xdsl.ir import * -from typing import overload +from typing import Type if TYPE_CHECKING: from xdsl.parser import Parser @@ -57,7 +57,7 @@ def from_str(data: str) -> StringAttr: @irdl_attr_definition class SymbolNameAttr(ParametrizedAttribute): name = "symbol_name" - data = ParameterDef(StringAttr) + data: ParameterDef[StringAttr] @staticmethod @builder @@ -73,7 +73,7 @@ def from_string_attr(data: StringAttr) -> SymbolNameAttr: @irdl_attr_definition class FlatSymbolRefAttr(ParametrizedAttribute): name = "flat_symbol_ref" - data = ParameterDef(StringAttr) + data: ParameterDef[StringAttr] @staticmethod @builder @@ -108,7 +108,7 @@ def from_int(data: int) -> IntAttr: @irdl_attr_definition class IntegerType(ParametrizedAttribute): name = "integer_type" - width = ParameterDef(IntAttr) + width: ParameterDef[IntAttr] @staticmethod @builder @@ -129,8 +129,8 @@ class IndexType(ParametrizedAttribute): @irdl_attr_definition class IntegerAttr(ParametrizedAttribute): name = "integer" - value = ParameterDef(IntAttr) - typ = ParameterDef(AnyOf([IntegerType, IndexType])) + value: ParameterDef[IntAttr] + typ: ParameterDef[IntegerType | IndexType] @staticmethod @builder @@ -213,8 +213,8 @@ def from_type_list(types: List[Attribute]) -> TupleType: class VectorType(ParametrizedAttribute): name = "vector" - shape = ParameterDef(ArrayOfConstraint(IntegerAttr)) - element_type = ParameterDef(AnyAttr()) + shape: ParameterDef[Annotated[ArrayAttr, ArrayOfConstraint(IntegerAttr)]] + element_type: ParameterDef[Attribute] def get_num_dims(self) -> int: return len(self.parameters[0].data) @@ -248,8 +248,8 @@ def from_params( class TensorType(ParametrizedAttribute): name = "tensor" - shape = ParameterDef(ArrayOfConstraint(IntegerAttr)) - element_type = ParameterDef(AnyAttr()) + shape: ParameterDef[Annotated[ArrayAttr, ArrayOfConstraint(IntegerAttr)]] + element_type: ParameterDef[Attribute] def get_num_dims(self) -> int: return len(self.parameters[0].data) @@ -283,9 +283,9 @@ def from_params( class DenseIntOrFPElementsAttr(ParametrizedAttribute): name = "dense" # TODO add support for FPElements - type = ParameterDef(AnyOf([VectorType, TensorType])) + type: ParameterDef[VectorType | TensorType] # TODO add support for multi-dimensional data - data = ParameterDef(ArrayOfConstraint(IntegerAttr)) + data: ParameterDef[Annotated[ArrayAttr, ArrayOfConstraint(IntegerAttr)]] @staticmethod @builder @@ -339,8 +339,8 @@ class UnitAttr(ParametrizedAttribute): class FunctionType(ParametrizedAttribute): name = "fun" - inputs = ParameterDef(ArrayOfConstraint(AnyAttr())) - outputs = ParameterDef(ArrayOfConstraint(AnyAttr())) + inputs: ParameterDef[Annotated[ArrayAttr, ArrayOfConstraint(AnyAttr())]] + outputs: ParameterDef[Annotated[ArrayAttr, ArrayOfConstraint(AnyAttr())]] @staticmethod @builder diff --git a/src/xdsl/dialects/memref.py b/src/xdsl/dialects/memref.py index 55ca935725..fe9ea2c4f9 100644 --- a/src/xdsl/dialects/memref.py +++ b/src/xdsl/dialects/memref.py @@ -26,11 +26,11 @@ def __post_init__(self): class MemRefType(ParametrizedAttribute): name = "memref" - shape = ParameterDef(ArrayOfConstraint(IntegerAttr)) - element_type = ParameterDef(AnyAttr()) + shape: ParameterDef[Annotated[ArrayAttr, ArrayOfConstraint(IntegerAttr)]] + element_type: ParameterDef[Attribute] def get_num_dims(self) -> int: - return len(self.parameters[0].data) + return len(self.shape.data) def get_shape(self) -> List[int]: return [i.parameters[0].data for i in self.shape.data] diff --git a/src/xdsl/ir.py b/src/xdsl/ir.py index 2041e9b79e..61d7156510 100644 --- a/src/xdsl/ir.py +++ b/src/xdsl/ir.py @@ -175,7 +175,6 @@ def __hash__(self): AttrClass = TypeVar('AttrClass', bound='Attribute') -@dataclass(frozen=True) class Attribute(ABC): """ A compile-time value. @@ -183,8 +182,11 @@ class Attribute(ABC): on operations to give extra information. """ - name: str = field(default="", init=False) - """The attribute name should be a static field in the attribute classes.""" + @property + @abstractmethod + def name(self): + """The attribute name should be a static field in the attribute classes.""" + pass @classmethod def build(cls: typing.Type[AttrClass], *args) -> AttrClass: @@ -216,7 +218,6 @@ def print_parameter(data: DataElement, printer: Printer) -> None: class ParametrizedAttribute(Attribute): """An attribute parametrized by other attributes.""" - name: str = field(default="", init=False) parameters: List[Attribute] = field(default_factory=list) def __post_init__(self): diff --git a/src/xdsl/irdl.py b/src/xdsl/irdl.py index 73a69f2f12..f5698e51c5 100644 --- a/src/xdsl/irdl.py +++ b/src/xdsl/irdl.py @@ -1,11 +1,13 @@ from __future__ import annotations +from enum import Enum import inspect from dataclasses import dataclass, field from abc import ABC, abstractmethod -from typing import List, Tuple, Optional, Union, TypeVar +from typing import Annotated, List, Tuple, Optional, TypeAlias, Union, TypeVar, Any from inspect import isclass import typing +import types from xdsl.ir import Operation, Attribute, ParametrizedAttribute, SSAValue, Data, Region, Block from xdsl import util @@ -23,6 +25,10 @@ class VerifyException(DiagnosticException): ... +class IRDLAnnotations(Enum): + ParamDefAnnot = 1 + + @dataclass class AttrConstraint(ABC): """Constrain an attribute to a certain value.""" @@ -111,6 +117,18 @@ def verify(self, attr: Attribute) -> None: raise VerifyException(f"Unexpected attribute {attr}") +@dataclass() +class AndConstraint(AttrConstraint): + """Ensure that an attribute satisfies all the given constraints.""" + + attr_constrs: List[AttrConstraint] + """The list of constraints that are checked.""" + + def verify(self, attr: Attribute) -> None: + for attr_constr in self.attr_constrs: + attr_constr.verify(attr) + + @dataclass(init=False) class ParamAttrConstraint(AttrConstraint): """ @@ -146,6 +164,46 @@ def verify(self, attr: Attribute) -> None: param_constr.verify(attr.parameters[idx]) +def irdl_to_attr_constraint(irdl: Any) -> AttrConstraint: + if isinstance(irdl, AttrConstraint): + return irdl + + # Annotated case + # Each argument of the Annotated type correspond to a constraint to satisfy. + if typing.get_origin(irdl) == Annotated: + constraints = [] + for arg in typing.get_args(irdl): + # We should not try to convert IRDL annotations, which do not + # correspond to constraints + if isinstance(arg, IRDLAnnotations): + continue + constraints.append(irdl_to_attr_constraint(arg)) + if len(constraints) > 1: + return AndConstraint(constraints) + return constraints[0] + + # Attribute class case + # This is a coercion for an `BaseAttr`. + if isclass(irdl) and issubclass(irdl, Attribute): + return BaseAttr(irdl) + + # Union case + # This is a coercion for an `AnyOf` constraint. + if typing.get_origin(irdl) == types.UnionType: + constraints = [] + for arg in typing.get_args(irdl): + # We should not try to convert IRDL annotations, which do not + # correspond to constraints + if isinstance(arg, IRDLAnnotations): + continue + constraints.append(irdl_to_attr_constraint(arg)) + if len(constraints) > 1: + return AnyOf(constraints) + return constraints[0] + + raise ValueError(f"Unexpected irdl constraint: {irdl}") + + @dataclass class IRDLOption(ABC): """Additional option used in IRDL.""" @@ -645,18 +703,14 @@ def builder(cls, return type(cls.__name__, cls.__mro__, {**cls.__dict__, **new_attrs}) -@dataclass -class ParameterDef: - """An IRDL definition of an attribute parameter.""" - constr: AttrConstraint +_ParameterDefT = TypeVar("_ParameterDefT", bound=Attribute) - def __init__(self, typ: Union[Attribute, typing.Type[Attribute], - AttrConstraint]): - self.constr = attr_constr_coercion(typ) +ParameterDef: TypeAlias = Annotated[_ParameterDefT, + IRDLAnnotations.ParamDefAnnot] def irdl_attr_verify(attr: ParametrizedAttribute, - parameters: List[ParameterDef]): + parameters: List[AttrConstraint]): """Given an IRDL definition, verify that an attribute satisfies its invariants.""" if len(attr.parameters) != len(parameters): @@ -664,7 +718,16 @@ def irdl_attr_verify(attr: ParametrizedAttribute, f"{len(parameters)} parameters expected, got {len(attr.parameters)}" ) for idx, param_def in enumerate(parameters): - param_def.constr.verify(attr.parameters[idx]) + param = attr.parameters[idx] + assert isinstance(param, Attribute) + for arg in typing.get_args(parameters): + if isinstance(param, IRDLAnnotations): + continue + if not isinstance(arg, AttrConstraint): + raise Exception( + "Unexpected attribute constraint given to IRDL definition: {arg}" + ) + arg.verify(param) C = TypeVar('C', bound='Callable') @@ -740,39 +803,60 @@ def irdl_param_attr_definition( cls: typing.Type[AttributeType]) -> typing.Type[AttributeType]: """Decorator used on classes to define a new attribute definition.""" + # Get the fields from the class and its parents + clsdict = dict() + for parent_cls in cls.mro()[::-1]: + clsdict = {**clsdict, **parent_cls.__dict__} + + # IRDL parameters definitions parameters = [] - new_attrs = dict() - for field_name in cls.__dict__: - field_ = cls.__dict__[field_name] - if isinstance(field_, ParameterDef): - new_attrs[field_name] = property( - (lambda idx: lambda self: self.parameters[idx])( - len(parameters))) - parameters.append(field_) + # New fields and methods added to the attribute + new_fields = dict() + + # TODO(math-fehr): Check that "name" and "parameters" are not redefined + for field_name, field_type in typing.get_type_hints( + cls, include_extras=True).items(): + # name and parameters are reserved fields in IRDL + if field_name == "name" or field_name == "parameters": + continue + + # Throw an error if a parameter is not definied using `ParameterDef` + origin = typing.get_origin(field_type) + args = typing.get_args(field_type) + if origin != Annotated or IRDLAnnotations.ParamDefAnnot not in args: + raise ValueError( + f"In attribute {cls.__name__} definition: Parameter " + + f"definition {field_name} should be defined with " + + f"type `ParameterDef`, got type {field_type}.") - new_attrs["verify"] = lambda typ: irdl_attr_verify(typ, parameters) + # Add the accessors for the definition + new_fields[field_name] = property( + (lambda idx: lambda self: self.parameters[idx])(len(parameters))) + parameters.append(irdl_to_attr_constraint(field_type)) - if "verify" in cls.__dict__: - custom_verifier = cls.__dict__["verify"] + new_fields["verify"] = lambda typ: irdl_attr_verify(typ, parameters) + + if "verify" in clsdict: + custom_verifier = clsdict["verify"] def new_verifier(verifier, op): verifier(op) custom_verifier(op) - new_attrs["verify"] = ( + new_fields["verify"] = ( lambda verifier: lambda op: new_verifier(verifier, op))( - new_attrs["verify"]) + new_fields["verify"]) builders = irdl_get_builders(cls) if "build" in cls.__dict__: raise Exception( f'"build" method for {cls.__name__} is reserved for IRDL, and should not be defined.' ) - new_attrs["build"] = lambda *args: irdl_attr_builder(cls, builders, *args) + new_fields["build"] = lambda *args: irdl_attr_builder(cls, builders, *args) - return dataclass(frozen=True)(type(cls.__name__, (cls, ), { + return dataclass(frozen=True, init=False)(type(cls.__name__, (cls, ), { **cls.__dict__, - **new_attrs + **new_fields })) diff --git a/src/xdsl/util.py b/src/xdsl/util.py index 7992bda31c..ff11daa3af 100644 --- a/src/xdsl/util.py +++ b/src/xdsl/util.py @@ -1,6 +1,6 @@ import inspect from inspect import signature -from typing import Union, NoReturn, Callable, List, overload +from typing import Annotated, Any, TypeGuard, Union, NoReturn, Callable, List import typing from dataclasses import dataclass from xdsl.ir import Operation, SSAValue, BlockArgument, Block, Region, Attribute @@ -65,3 +65,7 @@ def is_satisfying_hint(arg, hint) -> bool: return False raise ValueError(f"is_satisfying_hint: unsupported type hint '{hint}'") + + +annotated_type = type(Annotated[int, 0]) +"""This is the type of an Annotated object.""" \ No newline at end of file From 1f308d4acbb298e175df4fffa3a620c404323186 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Thu, 14 Apr 2022 22:46:43 +0200 Subject: [PATCH 04/36] Set python version to 3.10 in the CI --- .github/workflows/python-app.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index f2a03be440..7b56545b66 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -12,10 +12,10 @@ jobs: steps: - uses: actions/checkout@v2 - - name: Set up Python 3.9 + - name: Set up Python 3.10 uses: actions/setup-python@v2 with: - python-version: 3.9 + python-version: "3.10" - name: Install dependencies run: | python -m pip install --upgrade pip From 30cb3f49b3f692491cc9a309e6f199263b2f7978 Mon Sep 17 00:00:00 2001 From: Michel Steuwer Date: Mon, 25 Apr 2022 18:02:47 +0100 Subject: [PATCH 05/36] Moved to ArrayAttr(Data[List[A]]) --- pyproject.toml | 6 +- src/xdsl/diagnostic.py | 15 ++-- src/xdsl/dialects/builtin.py | 90 +++++++++---------- src/xdsl/dialects/memref.py | 28 +++--- src/xdsl/ir.py | 134 +++++++++++++++------------- src/xdsl/irdl.py | 163 ++++++++++++++++++++++------------- src/xdsl/parser.py | 2 +- src/xdsl/printer.py | 3 + 8 files changed, 249 insertions(+), 192 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b5a3c468d9..f7cc972f0b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,4 +3,8 @@ requires = [ "setuptools>=42", "wheel" ] -build-backend = "setuptools.build_meta" \ No newline at end of file +build-backend = "setuptools.build_meta" + +[tool.pyright] +reportImportCycles = false +reportUnnecessaryIsInstance = false \ No newline at end of file diff --git a/src/xdsl/diagnostic.py b/src/xdsl/diagnostic.py index 2376b030b2..ad1adda67e 100644 --- a/src/xdsl/diagnostic.py +++ b/src/xdsl/diagnostic.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Dict, List +from typing import Dict, List, Union, Type from dataclasses import dataclass, field from xdsl.ir import Block, Operation, Region @@ -18,10 +18,11 @@ def add_message(self, op: Operation, message: str) -> None: """Add a message to an operation.""" self.op_messages.setdefault(op, []).append(message) - def raise_exception(self, - message, - ir: Union[Operation, Block, Region], - exception_type=DiagnosticException) -> None: + def raise_exception( + self, + message: str, + ir: Union[Operation, Block, Region], + exception_type: Type[Exception] = DiagnosticException) -> None: """Raise an exception, that will also print all messages in the IR.""" from xdsl.printer import Printer f = StringIO() @@ -30,9 +31,9 @@ def raise_exception(self, if isinstance(toplevel, Operation): p.print_op(toplevel) elif isinstance(toplevel, Block): - p._print_named_block(toplevel) + p._print_named_block(toplevel) # type: ignore elif isinstance(toplevel, Region): - p._print_region(toplevel) + p._print_region(toplevel) # type: ignore else: assert "xDSL internal error: get_toplevel_object returned unknown construct" diff --git a/src/xdsl/dialects/builtin.py b/src/xdsl/dialects/builtin.py index e3384c1573..892161c7e2 100644 --- a/src/xdsl/dialects/builtin.py +++ b/src/xdsl/dialects/builtin.py @@ -3,7 +3,6 @@ from xdsl.irdl import * from xdsl.ir import * -from typing import Type if TYPE_CHECKING: from xdsl.parser import Parser @@ -146,8 +145,7 @@ def from_index_int_value(value: int) -> IntegerAttr: @staticmethod @builder - def from_params(value: Union[int, IntAttr], - typ: Union[int, Attribute]) -> IntegerAttr: + def from_params(value: int | IntAttr, typ: int | Attribute) -> IntegerAttr: value = IntAttr.build(value) if not isinstance(typ, IndexType): typ = IntegerType.build(typ) @@ -155,46 +153,47 @@ def from_params(value: Union[int, IntAttr], @irdl_attr_definition -class ArrayAttr(Data[List[Attribute]]): +class ArrayAttr(Data[List[A]]): name = "array" @staticmethod - def parse_parameter(parser: Parser) -> List[Attribute]: + def parse_parameter(parser: Parser) -> List[A]: parser.parse_char("[") data = parser.parse_list(parser.parse_optional_attribute) parser.parse_char("]") - return data + # the type system can't ensure that the elements are of type A + # and not just of type Attribute, therefore, the following cast + return cast(List[A], data) @staticmethod - def print_parameter(data: List[Attribute], printer: Printer) -> None: + def print_parameter(data: List[A], printer: Printer) -> None: printer.print_string("[") - printer.print_list(self.data, printer.print_attribute) + printer.print_list(data, printer.print_attribute) printer.print_string("]") @staticmethod @builder - def from_list(data: List[Attribute]) -> ArrayAttr: + def from_list(data: List[A]) -> ArrayAttr[A]: return ArrayAttr(data) -@dataclass -class ArrayOfConstraint(AttrConstraint): - """ - A constraint that enforces an ArrayData whose elements all satisfy - the elem_constr. - """ - elem_constr: AttrConstraint +# @dataclass +# class ArrayOfConstraint(AttrConstraint): +# """ +# A constraint that enforces an ArrayData whose elements all satisfy +# the elem_constr. +# """ +# elem_constr: AttrConstraint - def __init__(self, constr: Union[Attribute, typing.Type[Attribute], - AttrConstraint]): - self.elem_constr = attr_constr_coercion(constr) +# def __init__(self, constr: Attribute | Type[Attribute] | AttrConstraint): +# self.elem_constr = attr_constr_coercion(constr) - def verify(self, data: Data) -> None: - if not isinstance(data, ArrayAttr): - raise Exception(f"expected data ArrayData but got {data}") +# def verify(self, attr: Attribute) -> None: +# if not isinstance(attr, Data): +# raise Exception(f"expected data ArrayData but got {attr}") - for e in data.data: - self.elem_constr.verify(e) +# for e in cast(ArrayAttr[Attribute], attr).data: +# self.elem_constr.verify(e) @irdl_attr_definition @@ -213,20 +212,20 @@ def from_type_list(types: List[Attribute]) -> TupleType: class VectorType(ParametrizedAttribute): name = "vector" - shape: ParameterDef[Annotated[ArrayAttr, ArrayOfConstraint(IntegerAttr)]] + shape: ParameterDef[ArrayAttr[IntegerAttr]] element_type: ParameterDef[Attribute] def get_num_dims(self) -> int: - return len(self.parameters[0].data) + return len(self.shape.data) def get_shape(self) -> List[int]: - return [i.parameters[0].data for i in self.shape.data] + return [i.value.data for i in self.shape.data] @staticmethod @builder def from_type_and_list( referenced_type: Attribute, - shape: List[Union[int, IntegerAttr]] = None) -> VectorType: + shape: Optional[List[int | IntegerAttr]] = None) -> VectorType: if shape is None: shape = [1] return VectorType([ @@ -238,7 +237,7 @@ def from_type_and_list( @builder def from_params( referenced_type: Attribute, - shape: ArrayAttr = ArrayAttr.from_list( + shape: ArrayAttr[IntegerAttr] = ArrayAttr.from_list( [IntegerAttr.from_int_and_width(1, 64)]) ) -> VectorType: return VectorType([shape, referenced_type]) @@ -248,20 +247,20 @@ def from_params( class TensorType(ParametrizedAttribute): name = "tensor" - shape: ParameterDef[Annotated[ArrayAttr, ArrayOfConstraint(IntegerAttr)]] + shape: ParameterDef[ArrayAttr[IntegerAttr]] element_type: ParameterDef[Attribute] def get_num_dims(self) -> int: - return len(self.parameters[0].data) + return len(self.shape.data) def get_shape(self) -> List[int]: - return [i.parameters[0].data for i in self.shape.data] + return [i.value.data for i in self.shape.data] @staticmethod @builder def from_type_and_list( referenced_type: Attribute, - shape: List[Union[int, IntegerAttr]] = None) -> TensorType: + shape: Optional[Sequence[int | IntegerAttr]] = None) -> TensorType: if shape is None: shape = [1] return TensorType([ @@ -273,7 +272,7 @@ def from_type_and_list( @builder def from_params( referenced_type: Attribute, - shape: ArrayAttr = ArrayAttr.from_list( + shape: ArrayAttr[IntegerAttr] = ArrayAttr.from_list( [IntegerAttr.from_int_and_width(1, 64)]) ) -> TensorType: return TensorType([shape, referenced_type]) @@ -285,20 +284,20 @@ class DenseIntOrFPElementsAttr(ParametrizedAttribute): # TODO add support for FPElements type: ParameterDef[VectorType | TensorType] # TODO add support for multi-dimensional data - data: ParameterDef[Annotated[ArrayAttr, ArrayOfConstraint(IntegerAttr)]] + data: ParameterDef[ArrayAttr[IntegerAttr]] @staticmethod @builder - def from_int_list(type: Union[VectorType, TensorType], data: List[int], - bitwidth) -> DenseIntOrFPElementsAttr: + def from_int_list(type: VectorType | TensorType, data: List[int], + bitwidth: int) -> DenseIntOrFPElementsAttr: data_attr = [IntegerAttr.from_int_and_width(d, bitwidth) for d in data] return DenseIntOrFPElementsAttr([type, ArrayAttr.from_list(data_attr)]) @staticmethod @builder def from_list( - type: Union[VectorType, TensorType], - data: List[Union[int, IntegerAttr]]) -> DenseIntOrFPElementsAttr: + type: VectorType | TensorType, + data: List[int] | List[IntegerAttr]) -> DenseIntOrFPElementsAttr: element_type = type.element_type # Only use the element_type if the passed data is an int, o/w use the IntegerAttr data_attr = [(IntegerAttr.from_params(d, element_type) if isinstance( @@ -309,7 +308,7 @@ def from_list( @builder def vector_from_list( data: List[int], - typ: Union[IntegerType, IndexType]) -> DenseIntOrFPElementsAttr: + typ: IntegerType | IndexType) -> DenseIntOrFPElementsAttr: t = VectorType.from_type_and_list(typ, [len(data)]) return DenseIntOrFPElementsAttr.from_list(t, data) @@ -317,7 +316,7 @@ def vector_from_list( @builder def tensor_from_list( data: List[int], - typ: Union[IntegerType, IndexType]) -> DenseIntOrFPElementsAttr: + typ: IntegerType | IndexType) -> DenseIntOrFPElementsAttr: t = TensorType.from_type_and_list(typ, [len(data)]) return DenseIntOrFPElementsAttr.from_list(t, data) @@ -339,8 +338,8 @@ class UnitAttr(ParametrizedAttribute): class FunctionType(ParametrizedAttribute): name = "fun" - inputs: ParameterDef[Annotated[ArrayAttr, ArrayOfConstraint(AnyAttr())]] - outputs: ParameterDef[Annotated[ArrayAttr, ArrayOfConstraint(AnyAttr())]] + inputs: ParameterDef[ArrayAttr[Attribute]] + outputs: ParameterDef[ArrayAttr[Attribute]] @staticmethod @builder @@ -352,7 +351,8 @@ def from_lists(inputs: List[Attribute], @staticmethod @builder - def from_attrs(inputs: ArrayAttr, outputs: ArrayAttr) -> Attribute: + def from_attrs(inputs: ArrayAttr[Attribute], + outputs: ArrayAttr[Attribute]) -> Attribute: return FunctionType([inputs, outputs]) @@ -367,7 +367,7 @@ def ops(self) -> List[Operation]: return self.regions[0].blocks[0].ops @staticmethod - def from_region_or_ops(ops: Union[List[Operation], Region]) -> ModuleOp: + def from_region_or_ops(ops: List[Operation] | Region) -> ModuleOp: if isinstance(ops, list): region = Region.from_operation_list(ops) elif isinstance(ops, Region): diff --git a/src/xdsl/dialects/memref.py b/src/xdsl/dialects/memref.py index fe9ea2c4f9..ae6f4a7628 100644 --- a/src/xdsl/dialects/memref.py +++ b/src/xdsl/dialects/memref.py @@ -26,20 +26,20 @@ def __post_init__(self): class MemRefType(ParametrizedAttribute): name = "memref" - shape: ParameterDef[Annotated[ArrayAttr, ArrayOfConstraint(IntegerAttr)]] + shape: ParameterDef[ArrayAttr[IntegerAttr]] element_type: ParameterDef[Attribute] def get_num_dims(self) -> int: return len(self.shape.data) def get_shape(self) -> List[int]: - return [i.parameters[0].data for i in self.shape.data] + return [i.value.data for i in self.shape.data] @staticmethod @builder def from_type_and_list( referenced_type: Attribute, - shape: List[Union[int, IntegerAttr]] = None) -> MemRefType: + shape: Optional[List[int | IntegerAttr]] = None) -> MemRefType: if shape is None: shape = [1] return MemRefType([ @@ -51,7 +51,7 @@ def from_type_and_list( @builder def from_params( referenced_type: Attribute, - shape: ArrayAttr = ArrayAttr.from_list( + shape: ArrayAttr[IntegerAttr] = ArrayAttr.from_list( [IntegerAttr.from_int_and_width(1, 64)]) ) -> MemRefType: return MemRefType([shape, referenced_type]) @@ -76,8 +76,8 @@ def verify_(self): raise Exception("expected an index for each dimension") @staticmethod - def get(ref: Union[SSAValue, Operation], - indices: List[Union[SSAValue, Operation]]) -> Load: + def get(ref: SSAValue | Operation, + indices: List[SSAValue | Operation]) -> Load: return Load.build(operands=[ref, indices], result_types=[SSAValue.get(ref).typ.element_type]) @@ -98,8 +98,8 @@ def verify_(self): raise Exception("expected an index for each dimension") @staticmethod - def get(value: Union[Operation, SSAValue], ref: Union[Operation, SSAValue], - indices: List[Union[Operation, SSAValue]]) -> Store: + def get(value: Operation | SSAValue, ref: Operation | SSAValue, + indices: List[Operation | SSAValue]) -> Store: return Store.build(operands=[value, ref, indices]) @@ -120,7 +120,7 @@ class Alloc(Operation): @staticmethod def get(return_type: Attribute, alignment: int, - shape: List[Union[int, IntegerAttr]] = None) -> Alloc: + shape: Optional[List[int | IntegerAttr]] = None) -> Alloc: if shape is None: shape = [1] return Alloc.build( @@ -148,7 +148,7 @@ class Alloca(Operation): @staticmethod def get(return_type: Attribute, alignment: int, - shape: List[Union[int, IntegerAttr]] = None) -> Alloca: + shape: Optional[List[int | IntegerAttr]] = None) -> Alloca: if shape is None: shape = [1] return Alloca.build( @@ -165,7 +165,7 @@ class Dealloc(Operation): memref = OperandDef(MemRefType) @staticmethod - def get(operand: Union[Operation, SSAValue]) -> Dealloc: + def get(operand: Operation | SSAValue) -> Dealloc: return Dealloc.build(operands=[operand]) @@ -185,7 +185,7 @@ def verify_(self) -> None: "expected 'name' attribute to be a FlatSymbolRefAttr") @staticmethod - def get(name, return_type: Attribute) -> GetGlobal: + def get(name: str, return_type: Attribute) -> GetGlobal: return GetGlobal.build( result_types=[return_type], attributes={"name": FlatSymbolRefAttr.build(name)}) @@ -217,10 +217,10 @@ def verify_(self) -> None: ) @staticmethod - def get(sym_name: Union[str, StringAttr], + def get(sym_name: str | StringAttr, typ: Attribute, initial_value: Optional[Attribute], - sym_visibility="private") -> Global: + sym_visibility: str = "private") -> Global: return Global.build( attributes={ "sym_name": sym_name, diff --git a/src/xdsl/ir.py b/src/xdsl/ir.py index 61d7156510..411c797ad8 100644 --- a/src/xdsl/ir.py +++ b/src/xdsl/ir.py @@ -1,8 +1,10 @@ from __future__ import annotations + from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import (Dict, Generic, List, Callable, Optional, Any, TYPE_CHECKING, - TypeVar, Set, Union, Tuple) +from typing import (TYPE_CHECKING, Any, Callable, Dict, Generic, List, + Protocol, Optional, Sequence, Set, Type, TypeVar, Union, + cast) import typing from frozenlist import FrozenList @@ -11,37 +13,35 @@ from xdsl.parser import Parser from xdsl.printer import Printer -OperationType = TypeVar('OperationType', bound='Operation', covariant=True) +OpT = TypeVar('OpT', bound='Operation') @dataclass class MLContext: """Contains structures for operations/attributes registration.""" - _registeredOps: Dict[str, - typing.Type[Operation]] = field(default_factory=dict) - _registeredAttrs: Dict[str, typing.Type[Attribute]] = field( - default_factory=dict) + _registeredOps: Dict[str, Type[Operation]] = field(default_factory=dict) + _registeredAttrs: Dict[str, Type[Attribute]] = field(default_factory=dict) - def register_op(self, op: typing.Type[Operation]) -> None: + def register_op(self, op: Type[Operation]) -> None: """Register an operation definition. Operation names should be unique.""" if op.name in self._registeredOps: raise Exception(f"Operation {op.name} has already been registered") self._registeredOps[op.name] = op - def register_attr(self, attr: typing.Type[Attribute]) -> None: + def register_attr(self, attr: Type[Attribute]) -> None: """Register an attribute definition. Attribute names should be unique.""" if attr.name in self._registeredAttrs: raise Exception( f"Attribute {attr.name} has already been registered") self._registeredAttrs[attr.name] = attr - def get_op(self, name: str) -> typing.Type[Operation]: + def get_op(self, name: str) -> Type[Operation]: """Get an operation class from its name.""" if name not in self._registeredOps: raise Exception(f"Operation {name} is not registered") return self._registeredOps[name] - def get_attr(self, name: str) -> typing.Type[Attribute]: + def get_attr(self, name: str) -> Type[Attribute]: """Get an attribute class from its name.""" if name not in self._registeredAttrs: raise Exception(f"Attribute {name} is not registered") @@ -126,10 +126,11 @@ def __repr__(self) -> str: return f"OpResult(typ={repr(self.typ)}, num_uses={repr(len(self.uses))}" + \ f", op_name={repr(self.op.name)}, result_index={repr(self.result_index)}, name={repr(self.name)})" - def __eq__(self, other): + def __eq__(self, other: OpResult) -> bool: return self is other - def __hash__(self): + # This might be problematic, as the superclass is not hashable ... + def __hash__(self) -> int: return id(self) @@ -152,10 +153,10 @@ def __repr__(self) -> str: f", block={block_repr}," \ " index={repr(self.index)}" - def __eq__(self, other): + def __eq__(self, other: BlockArgument) -> bool: return self is other - def __hash__(self): + def __hash__(self) -> int: return id(self) @@ -168,11 +169,11 @@ class ErasedSSAValue(SSAValue): old_value: SSAValue - def __hash__(self): + def __hash__(self) -> int: return hash(id(self)) -AttrClass = TypeVar('AttrClass', bound='Attribute') +A = TypeVar('A', bound='Attribute') class Attribute(ABC): @@ -182,14 +183,11 @@ class Attribute(ABC): on operations to give extra information. """ - @property - @abstractmethod - def name(self): - """The attribute name should be a static field in the attribute classes.""" - pass + name: str = field(default="", init=False) + """The attribute name should be a static field in the attribute classes.""" @classmethod - def build(cls: typing.Type[AttrClass], *args) -> AttrClass: + def build(cls: Type[A], *args: Any) -> A: """Create a new attribute using one of the builder defined in IRDL.""" assert False @@ -284,12 +282,13 @@ def __post_init__(self): assert (isinstance(self.name, str)) @staticmethod - def with_result_types(op: Any, - operands: Optional[List[SSAValue]] = None, - result_types: Optional[List[Attribute]] = None, - attributes: Optional[Dict[str, Attribute]] = None, - successors: Optional[List[Block]] = None, - regions: Optional[List[Region]] = None) -> Operation: + def with_result_types( + op: Any, + operands: Optional[Sequence[SSAValue]] = None, + result_types: Optional[Sequence[Attribute]] = None, + attributes: Optional[Dict[str, Attribute]] = None, + successors: Optional[Sequence[Block]] = None, + regions: Optional[Sequence[Region]] = None) -> Operation: operation = op() if operands is not None: @@ -312,22 +311,23 @@ def with_result_types(op: Any, return operation @classmethod - def create(cls: typing.Type[OperationType], - operands: Optional[List[SSAValue]] = None, - result_types: Optional[List[Attribute]] = None, + def create(cls: Type[OpT], + operands: Optional[Sequence[SSAValue]] = None, + result_types: Optional[Sequence[Attribute]] = None, attributes: Optional[Dict[str, Attribute]] = None, - successors: Optional[List[Block]] = None, - regions: Optional[List[Region]] = None) -> OperationType: - return Operation.with_result_types(cls, operands, result_types, - attributes, successors, regions) + successors: Optional[Sequence[Block]] = None, + regions: Optional[Sequence[Region]] = None) -> OpT: + op = Operation.with_result_types(cls, operands, result_types, + attributes, successors, regions) + return cast(OpT, op) @classmethod - def build(cls: typing.Type[OperationType], - operands: List[Any] = None, - result_types: List[Any] = None, - attributes: Dict[str, Any] = None, - successors: List[Any] = None, - regions: List[Any] = None) -> OperationType: + def build(cls: Type[OpT], + operands: Optional[List[Any]] = None, + result_types: Optional[List[Any]] = None, + attributes: Optional[Dict[str, Any]] = None, + successors: Optional[List[Any]] = None, + regions: Optional[List[Any]] = None) -> OpT: """Create a new operation using builders.""" ... @@ -383,10 +383,9 @@ def print(self, printer: Printer): return printer.print_op_with_default_format(self) def clone_without_regions( - self: OperationType, + self: OpT, value_mapper: Optional[Dict[SSAValue, SSAValue]] = None, - block_mapper: Optional[Dict[Block, - Block]] = None) -> OperationType: + block_mapper: Optional[Dict[Block, Block]] = None) -> OpT: """Clone an operation, with empty regions instead.""" if value_mapper is None: value_mapper = {} @@ -411,11 +410,9 @@ def clone_without_regions( value_mapper[self.results[idx]] = result return cloned_op - def clone( - self: OperationType, - value_mapper: Optional[Dict[SSAValue, SSAValue]] = None, - block_mapper: Optional[Dict[Block, - Block]] = None) -> OperationType: + def clone(self: OpT, + value_mapper: Optional[Dict[SSAValue, SSAValue]] = None, + block_mapper: Optional[Dict[Block, Block]] = None) -> OpT: """Clone an operation with all its regions and operations in them.""" if value_mapper is None: value_mapper = {} @@ -426,7 +423,9 @@ def clone( region.clone_into(op.regions[idx], 0, value_mapper, block_mapper) return op - def erase(self, safe_erase=True, drop_references=True) -> None: + def erase(self, + safe_erase: bool = True, + drop_references: bool = True) -> None: """ Erase the operation, and remove all its references to other operations. If safe_erase is specified, check that the operation results are not used. @@ -468,7 +467,8 @@ def __hash__(self) -> int: class Block: """A sequence of operations""" - _args: FrozenList[BlockArgument] = field(default_factory=list, init=False) + _args: FrozenList[BlockArgument] = field(default_factory=FrozenList, + init=False) """The basic block arguments.""" ops: List[Operation] = field(default_factory=list, init=False) @@ -504,19 +504,25 @@ def from_arg_types(arg_types: List[Attribute]) -> Block: return b @staticmethod - def from_ops(ops: List[Operation], arg_types: List[Attribute] = None): + def from_ops(ops: List[Operation], + arg_types: Optional[List[Attribute]] = None): b = Block() if arg_types is not None: - b._args = [ + b._args = FrozenList([ BlockArgument(typ, b, index) for index, typ in enumerate(arg_types) - ] + ]) + b._args.freeze() b.add_ops(ops) return b + class BlockCallback(Protocol): + + def __call__(self, *args: BlockArgument) -> List[Operation]: + ... + @staticmethod - def from_callable(block_arg_types: List[Attribute], - f: Callable[[BlockArgument, ...], List[Operation]]): + def from_callable(block_arg_types: List[Attribute], f: BlockCallback): b = Block.from_arg_types(block_arg_types) b.add_ops(f(*b.args)) return b @@ -634,7 +640,9 @@ def detach_op(self, op: Union[int, Operation]) -> Operation: self.ops = self.ops[:op_idx] + self.ops[op_idx + 1:] return op - def erase_op(self, op: Union[int, Operation], safe_erase=True) -> None: + def erase_op(self, + op: Union[int, Operation], + safe_erase: bool = True) -> None: """ Erase an operation from the block. If safe_erase is True, check that the operation has no uses. @@ -664,7 +672,7 @@ def drop_all_references(self) -> None: for op in self.ops: op.drop_all_references() - def erase(self, safe_erase=True) -> None: + def erase(self, safe_erase: bool = True) -> None: """ Erase the block, and remove all its references to other operations. If safe_erase is specified, check that no operation results are used outside the block. @@ -731,9 +739,9 @@ def get(arg: Region | List[Block] | List[Operation]) -> Region: if len(arg) == 0: return Region.from_operation_list([]) if isinstance(arg[0], Block): - return Region.from_block_list(arg) + return Region.from_block_list(cast(List[Block], arg)) if isinstance(arg[0], Operation): - return Region.from_operation_list(arg) + return Region.from_operation_list(cast(List[Operation], arg)) raise TypeError(f"Can't build a region with argument {arg}") @property @@ -815,7 +823,9 @@ def detach_block(self, block: Union[int, Block]) -> Block: self.blocks = self.blocks[:block_idx] + self.blocks[block_idx + 1:] return block - def erase_block(self, block: Union[int, Block], safe_erase=True) -> None: + def erase_block(self, + block: Union[int, Block], + safe_erase: bool = True) -> None: """ Erase a block from the region. If safe_erase is True, check that the block has no uses. diff --git a/src/xdsl/irdl.py b/src/xdsl/irdl.py index f5698e51c5..a68bb49626 100644 --- a/src/xdsl/irdl.py +++ b/src/xdsl/irdl.py @@ -1,18 +1,19 @@ from __future__ import annotations -from enum import Enum import inspect -from dataclasses import dataclass, field +import types from abc import ABC, abstractmethod -from typing import Annotated, List, Tuple, Optional, TypeAlias, Union, TypeVar, Any +from dataclasses import dataclass, field +from enum import Enum from inspect import isclass -import typing -import types +from typing import (Annotated, Any, Callable, Dict, List, Optional, Sequence, + Tuple, Type, TypeAlias, TypeVar, TypeGuard, Union, cast, + get_args, get_origin, get_type_hints, ForwardRef) -from xdsl.ir import Operation, Attribute, ParametrizedAttribute, SSAValue, Data, Region, Block from xdsl import util - from xdsl.diagnostic import Diagnostic, DiagnosticException +from xdsl.ir import (Attribute, Block, Data, Operation, ParametrizedAttribute, + Region, SSAValue) def error(op: Operation, msg: str): @@ -59,7 +60,7 @@ def verify(self, attr: Attribute) -> None: class BaseAttr(AttrConstraint): """Constrain an attribute to be of a given base type.""" - attr: typing.Type[Attribute] + attr: Type[Attribute] """The expected attribute base type.""" def verify(self, attr: Attribute) -> None: @@ -69,8 +70,8 @@ def verify(self, attr: Attribute) -> None: def attr_constr_coercion( - attr: Union[Attribute, typing.Type[Attribute], AttrConstraint] -) -> AttrConstraint: + attr: Union[Attribute, Type[Attribute], + AttrConstraint]) -> AttrConstraint: """ Attributes are coerced into EqAttrConstraints, and Attribute types are coerced into BaseAttr. @@ -79,6 +80,7 @@ def attr_constr_coercion( return EqAttrConstraint(attr) if isclass(attr) and issubclass(attr, Attribute): return BaseAttr(attr) + assert (isinstance(attr, AttrConstraint)) return attr @@ -98,9 +100,8 @@ class AnyOf(AttrConstraint): attr_constrs: List[AttrConstraint] """The list of constraints that are checked.""" - def __init__(self, - attr_constrs: List[Union[Attribute, typing.Type[Attribute], - AttrConstraint]]): + def __init__(self, attr_constrs: Sequence[Attribute | Type[Attribute] + | AttrConstraint]): self.attr_constrs = [ attr_constr_coercion(constr) for constr in attr_constrs ] @@ -118,7 +119,7 @@ def verify(self, attr: Attribute) -> None: @dataclass() -class AndConstraint(AttrConstraint): +class AllOf(AttrConstraint): """Ensure that an attribute satisfies all the given constraints.""" attr_constrs: List[AttrConstraint] @@ -129,6 +130,35 @@ def verify(self, attr: Attribute) -> None: attr_constr.verify(attr) +@dataclass +class DataListAttr(AttrConstraint): + """ + A constraint that enforces that the elements of an attribute with a list attributes all satisfy the elem_constr. + + """ + elem_constr: AttrConstraint + + def __init__(self, constr: Attribute | Type[Attribute] | AttrConstraint): + self.elem_constr = attr_constr_coercion(constr) + + def verify(self, attr: Attribute) -> None: + + def is_list_data(val: Attribute) -> TypeGuard[Data[List[Attribute]]]: + if not isinstance(val, Data): + return False + list: Any = val.data # type: ignore + if not isinstance(list, List): + return False + return all(isinstance(a, Attribute) for a in list) # type: ignore + + if not is_list_data(attr): + raise Exception( + f"expected data Data[List[Attribute]] but got {attr}") + + for e in attr.data: + self.elem_constr.verify(e) + + @dataclass(init=False) class ParamAttrConstraint(AttrConstraint): """ @@ -136,14 +166,14 @@ class ParamAttrConstraint(AttrConstraint): and also constrain its parameters with additional constraints. """ - base_attr: typing.Type[Attribute] + base_attr: Type[Attribute] """The base attribute type.""" param_constrs: List[AttrConstraint] """The attribute parameter constraints""" - def __init__(self, base_attr: typing.Type[Attribute], - param_constrs: List[Union[Attribute, typing.Type[Attribute], + def __init__(self, base_attr: Type[Attribute], + param_constrs: List[Union[Attribute, Type[Attribute], AttrConstraint]]): self.base_attr = base_attr self.param_constrs = [ @@ -153,8 +183,10 @@ def __init__(self, base_attr: typing.Type[Attribute], def verify(self, attr: Attribute) -> None: assert isinstance(attr, ParametrizedAttribute) if not isinstance(attr, self.base_attr): + # the type checker concludes that attr has type 'Never', therefore the cast + name = cast(Attribute, attr).name raise VerifyException( - f"Base attribute {self.base_attr.name} expected, but got {attr.name}" + f"Base attribute {self.base_attr.name} expected, but got {name}" ) if len(self.param_constrs) != len(attr.parameters): raise VerifyException( @@ -170,16 +202,16 @@ def irdl_to_attr_constraint(irdl: Any) -> AttrConstraint: # Annotated case # Each argument of the Annotated type correspond to a constraint to satisfy. - if typing.get_origin(irdl) == Annotated: - constraints = [] - for arg in typing.get_args(irdl): + if get_origin(irdl) == Annotated: + constraints: List[AttrConstraint] = [] + for arg in get_args(irdl): # We should not try to convert IRDL annotations, which do not # correspond to constraints if isinstance(arg, IRDLAnnotations): continue constraints.append(irdl_to_attr_constraint(arg)) if len(constraints) > 1: - return AndConstraint(constraints) + return AllOf(constraints) return constraints[0] # Attribute class case @@ -187,11 +219,26 @@ def irdl_to_attr_constraint(irdl: Any) -> AttrConstraint: if isclass(irdl) and issubclass(irdl, Attribute): return BaseAttr(irdl) + # yapf: disable + if (origin := get_origin(irdl)) and ( # we deal with a generic class + issubclass(origin, Data)) and ( # that is a subclass of Data + len(origin.__orig_bases__) == 1) and ( # with one superclass + data := origin.__orig_bases__[0]) and ( # called `data' + arg := get_args(data)[0]) and ( # whose argument is `arg' + get_origin(arg) == list) and ( # which is a list + get_args(arg)[0].__bound__ == + ForwardRef("Attribute") ): # and the element are attributes + args = get_args(irdl) + assert (len(args) == 1) + elem_constr = irdl_to_attr_constraint(args[0]) + return DataListAttr(elem_constr) + # yapf: enable + # Union case # This is a coercion for an `AnyOf` constraint. - if typing.get_origin(irdl) == types.UnionType: - constraints = [] - for arg in typing.get_args(irdl): + if get_origin(irdl) == types.UnionType: + constraints: List[AttrConstraint] = [] + for arg in get_args(irdl): # We should not try to convert IRDL annotations, which do not # correspond to constraints if isinstance(arg, IRDLAnnotations): @@ -251,8 +298,7 @@ class OperandDef(OperandOrResultDef): constr: AttrConstraint """The operand constraint.""" - def __init__(self, typ: Union[Attribute, typing.Type[Attribute], - AttrConstraint]): + def __init__(self, typ: Attribute | Type[Attribute] | AttrConstraint): self.constr = attr_constr_coercion(typ) @@ -273,8 +319,7 @@ class ResultDef(OperandOrResultDef): constr: AttrConstraint """The result constraint.""" - def __init__(self, typ: Union[Attribute, typing.Type[Attribute], - AttrConstraint]): + def __init__(self, typ: Attribute | Type[Attribute] | AttrConstraint): self.constr = attr_constr_coercion(typ) @@ -311,10 +356,9 @@ class AttributeDef: constr: AttrConstraint """The attribute constraint.""" - data: typing.Any + data: Any - def __init__(self, typ: Union[Attribute, typing.Type[Attribute], - AttrConstraint]): + def __init__(self, typ: Union[Attribute, Type[Attribute], AttrConstraint]): self.constr = attr_constr_coercion(typ) @@ -322,8 +366,7 @@ def __init__(self, typ: Union[Attribute, typing.Type[Attribute], class OptAttributeDef(AttributeDef): """An IRDL attribute definition for an optional attribute.""" - def __init__(self, typ: Union[Attribute, typing.Type[Attribute], - AttrConstraint]): + def __init__(self, typ: Union[Attribute, Type[Attribute], AttrConstraint]): super().__init__(typ) @@ -510,21 +553,23 @@ def irdl_build_attribute(irdl_def: AttrConstraint, result) -> Attribute: raise Exception(f"builder expected an attribute, got {result}") -OpT = TypeVar('OpT', bound='Operation') +OpT = TypeVar('OpT', bound=Operation) -def irdl_op_builder(cls: typing.Type[OpT], operands: List, +def irdl_op_builder(cls: Type[OpT], operands: List[Any], operand_defs: List[Tuple[str, OperandDef]], - res_types: List, res_defs: List[Tuple[str, ResultDef]], - attributes: typing.Dict[str, typing.Any], - attr_defs: typing.Dict[str, AttributeDef], successors, - regions, options) -> OpT: + res_types: List[Any], res_defs: List[Tuple[str, + ResultDef]], + attributes: Dict[str, Any], attr_defs: Dict[str, + AttributeDef], + successors, regions, options) -> OpT: """Builder for an irdl operation.""" # We need irdl to define DenseIntOrFPElementsAttr, but here we need # DenseIntOrFPElementsAttr. # So we have a circular dependency that we solve by importing in this function. - from xdsl.dialects.builtin import DenseIntOrFPElementsAttr, IntegerAttr, VectorType, IntegerType, i32 + from xdsl.dialects.builtin import (DenseIntOrFPElementsAttr, IntegerAttr, + IntegerType, VectorType, i32) # Build operands by forwarding the values to SSAValue.get if len(operand_defs) != len(operands): @@ -600,11 +645,7 @@ def irdl_op_builder(cls: typing.Type[OpT], operands: List, regions=regions) -OperationType = TypeVar("OperationType", bound=Operation) - - -def irdl_op_definition( - cls: typing.Type[OperationType]) -> typing.Type[OperationType]: +def irdl_op_definition(cls: Type[OpT]) -> Type[OpT]: """Decorator used on classes to define a new operation definition.""" assert issubclass( @@ -703,10 +744,9 @@ def builder(cls, return type(cls.__name__, cls.__mro__, {**cls.__dict__, **new_attrs}) -_ParameterDefT = TypeVar("_ParameterDefT", bound=Attribute) +_A = TypeVar("_A", bound=Attribute) -ParameterDef: TypeAlias = Annotated[_ParameterDefT, - IRDLAnnotations.ParamDefAnnot] +ParameterDef: TypeAlias = Annotated[_A, IRDLAnnotations.ParamDefAnnot] def irdl_attr_verify(attr: ParametrizedAttribute, @@ -720,7 +760,7 @@ def irdl_attr_verify(attr: ParametrizedAttribute, for idx, param_def in enumerate(parameters): param = attr.parameters[idx] assert isinstance(param, Attribute) - for arg in typing.get_args(parameters): + for arg in get_args(parameters): if isinstance(param, IRDLAnnotations): continue if not isinstance(arg, AttrConstraint): @@ -730,7 +770,7 @@ def irdl_attr_verify(attr: ParametrizedAttribute, arg.verify(param) -C = TypeVar('C', bound='Callable') +C = TypeVar('C', bound=Callable[..., Any]) def builder(f: C) -> C: @@ -742,7 +782,7 @@ def builder(f: C) -> C: return f -def irdl_get_builders(cls) -> List[typing.Callable]: +def irdl_get_builders(cls) -> List[Callable[..., Any]]: builders = [] for field_name in cls.__dict__: field_ = cls.__dict__[field_name] @@ -754,7 +794,7 @@ def irdl_get_builders(cls) -> List[typing.Callable]: def irdl_attr_try_builder(builder, *args): - params_dict = typing.get_type_hints(builder) + params_dict = get_type_hints(builder) builder_params = inspect.signature(builder).parameters params = [params_dict[param.name] for param in builder_params.values()] defaults = [param.default for param in builder_params.values()] @@ -782,7 +822,7 @@ def irdl_attr_builder(cls, builders, *args): T = TypeVar('T') -def irdl_data_definition(cls: typing.Type[T]) -> typing.Type[T]: +def irdl_data_definition(cls: Type[T]) -> Type[T]: builders = irdl_get_builders(cls) if "build" in cls.__dict__: raise Exception( @@ -796,11 +836,10 @@ def irdl_data_definition(cls: typing.Type[T]) -> typing.Type[T]: })) -AttributeType = TypeVar("AttributeType", bound=ParametrizedAttribute) +PA = TypeVar("PA", bound=ParametrizedAttribute) -def irdl_param_attr_definition( - cls: typing.Type[AttributeType]) -> typing.Type[AttributeType]: +def irdl_param_attr_definition(cls: Type[PA]) -> Type[PA]: """Decorator used on classes to define a new attribute definition.""" # Get the fields from the class and its parents @@ -814,15 +853,15 @@ def irdl_param_attr_definition( new_fields = dict() # TODO(math-fehr): Check that "name" and "parameters" are not redefined - for field_name, field_type in typing.get_type_hints( - cls, include_extras=True).items(): + for field_name, field_type in get_type_hints(cls, + include_extras=True).items(): # name and parameters are reserved fields in IRDL if field_name == "name" or field_name == "parameters": continue # Throw an error if a parameter is not definied using `ParameterDef` - origin = typing.get_origin(field_type) - args = typing.get_args(field_type) + origin = get_origin(field_type) + args = get_args(field_type) if origin != Annotated or IRDLAnnotations.ParamDefAnnot not in args: raise ValueError( f"In attribute {cls.__name__} definition: Parameter " + @@ -860,7 +899,7 @@ def new_verifier(verifier, op): })) -def irdl_attr_definition(cls: typing.Type[T]) -> typing.Type[T]: +def irdl_attr_definition(cls: Type[T]) -> Type[T]: if issubclass(cls, ParametrizedAttribute): return irdl_param_attr_definition(cls) if issubclass(cls, Data): diff --git a/src/xdsl/parser.py b/src/xdsl/parser.py index b55b32f1e6..3a4e21a66f 100644 --- a/src/xdsl/parser.py +++ b/src/xdsl/parser.py @@ -147,7 +147,7 @@ def parse_string(self, contents: List[str]) -> bool: def parse_list(self, parse_optional_one: Callable[[], Optional[T]], - delimiter=",") -> List[T]: + delimiter: str = ",") -> List[T]: assert (len(delimiter) <= 1) res = [] one = parse_optional_one() diff --git a/src/xdsl/printer.py b/src/xdsl/printer.py index 94c037bf65..55a5f11847 100644 --- a/src/xdsl/printer.py +++ b/src/xdsl/printer.py @@ -55,6 +55,9 @@ def print_string(self, text) -> None: self._current_column += len(lines[-1]) print(text, end='', file=self.stream) + def print_string(self, string: str) -> None: + self.print(string) + def _add_message_on_next_line(self, message: str, begin_pos: int, end_pos: int): """Add a message that will be displayed on the next line.""" From 192b3ccad53b83c35a982efa48676b929eba4643 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Tue, 26 Apr 2022 05:01:36 +0100 Subject: [PATCH 06/36] Remove commented code --- src/xdsl/dialects/builtin.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/src/xdsl/dialects/builtin.py b/src/xdsl/dialects/builtin.py index 892161c7e2..8e71a9d655 100644 --- a/src/xdsl/dialects/builtin.py +++ b/src/xdsl/dialects/builtin.py @@ -177,25 +177,6 @@ def from_list(data: List[A]) -> ArrayAttr[A]: return ArrayAttr(data) -# @dataclass -# class ArrayOfConstraint(AttrConstraint): -# """ -# A constraint that enforces an ArrayData whose elements all satisfy -# the elem_constr. -# """ -# elem_constr: AttrConstraint - -# def __init__(self, constr: Attribute | Type[Attribute] | AttrConstraint): -# self.elem_constr = attr_constr_coercion(constr) - -# def verify(self, attr: Attribute) -> None: -# if not isinstance(attr, Data): -# raise Exception(f"expected data ArrayData but got {attr}") - -# for e in cast(ArrayAttr[Attribute], attr).data: -# self.elem_constr.verify(e) - - @irdl_attr_definition class TupleType(ParametrizedAttribute): name = "tuple" From 47c768817c30ae12cd90c73c2f45b61a43e99b40 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Wed, 27 Apr 2022 03:22:56 +0100 Subject: [PATCH 07/36] Make Data an ABC class --- src/xdsl/ir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/xdsl/ir.py b/src/xdsl/ir.py index 411c797ad8..708fde0f20 100644 --- a/src/xdsl/ir.py +++ b/src/xdsl/ir.py @@ -196,7 +196,7 @@ def build(cls: Type[A], *args: Any) -> A: @dataclass(frozen=True) -class Data(Generic[DataElement], Attribute): +class Data(Generic[DataElement], Attribute, ABC): """An attribute represented by a Python structure.""" data: DataElement From 2b6b3648ff4b3843dc1e5a9a836787a0b6dcd6e5 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Wed, 27 Apr 2022 05:31:37 +0100 Subject: [PATCH 08/36] Make IRDL extensible with new generic coercions --- src/xdsl/dialects/builtin.py | 29 ++++++++++++++++- src/xdsl/irdl.py | 60 ++++++++++-------------------------- 2 files changed, 45 insertions(+), 44 deletions(-) diff --git a/src/xdsl/dialects/builtin.py b/src/xdsl/dialects/builtin.py index 8e71a9d655..abcd03c4ad 100644 --- a/src/xdsl/dialects/builtin.py +++ b/src/xdsl/dialects/builtin.py @@ -152,8 +152,26 @@ def from_params(value: int | IntAttr, typ: int | Attribute) -> IntegerAttr: return IntegerAttr([value, typ]) +@dataclass +class ArrayOfConstraint(AttrConstraint): + """ + A constraint that enforces an ArrayData whose elements all satisfy + the elem_constr. + """ + elem_constr: AttrConstraint + + def __init__(self, constr: Attribute | Type[Attribute] | AttrConstraint): + self.elem_constr = attr_constr_coercion(constr) + + def verify(self, attr: Attribute) -> None: + if not isinstance(attr, Data): + raise Exception(f"expected data ArrayData but got {attr}") + for e in cast(ArrayAttr[Attribute], attr).data: + self.elem_constr.verify(e) + + @irdl_attr_definition -class ArrayAttr(Data[List[A]]): +class ArrayAttr(Data[List[A]], IRDLGenericCoercion): name = "array" @staticmethod @@ -171,6 +189,15 @@ def print_parameter(data: List[A], printer: Printer) -> None: printer.print_list(data, printer.print_attribute) printer.print_string("]") + @staticmethod + def generic_constraint_coercion(args: tuple[Any]) -> AttrConstraint: + if len(args) == 0: + return ArrayOfConstraint(irdl_to_attr_constraint(args[0])) + if len(args) == 1: + return ArrayOfConstraint(AnyAttr()) + raise TypeError(f"Attribute ArrayAttr expects at most type" + f" parameter, but {len(args)} were given") + @staticmethod @builder def from_list(data: List[A]) -> ArrayAttr[A]: diff --git a/src/xdsl/irdl.py b/src/xdsl/irdl.py index a68bb49626..c206879299 100644 --- a/src/xdsl/irdl.py +++ b/src/xdsl/irdl.py @@ -130,35 +130,6 @@ def verify(self, attr: Attribute) -> None: attr_constr.verify(attr) -@dataclass -class DataListAttr(AttrConstraint): - """ - A constraint that enforces that the elements of an attribute with a list attributes all satisfy the elem_constr. - - """ - elem_constr: AttrConstraint - - def __init__(self, constr: Attribute | Type[Attribute] | AttrConstraint): - self.elem_constr = attr_constr_coercion(constr) - - def verify(self, attr: Attribute) -> None: - - def is_list_data(val: Attribute) -> TypeGuard[Data[List[Attribute]]]: - if not isinstance(val, Data): - return False - list: Any = val.data # type: ignore - if not isinstance(list, List): - return False - return all(isinstance(a, Attribute) for a in list) # type: ignore - - if not is_list_data(attr): - raise Exception( - f"expected data Data[List[Attribute]] but got {attr}") - - for e in attr.data: - self.elem_constr.verify(e) - - @dataclass(init=False) class ParamAttrConstraint(AttrConstraint): """ @@ -196,6 +167,20 @@ def verify(self, attr: Attribute) -> None: param_constr.verify(attr.parameters[idx]) +class IRDLGenericCoercion(ABC): + """ + Defines a coercion between a generic Attribute type and an attribute constraint. + """ + + @staticmethod + @abstractmethod + def generic_constraint_coercion(args: tuple[Any]) -> AttrConstraint: + """ + Given the generic parameters passed to the generic attribute type, + return the corresponding attribute constraint. + """ + + def irdl_to_attr_constraint(irdl: Any) -> AttrConstraint: if isinstance(irdl, AttrConstraint): return irdl @@ -219,20 +204,9 @@ def irdl_to_attr_constraint(irdl: Any) -> AttrConstraint: if isclass(irdl) and issubclass(irdl, Attribute): return BaseAttr(irdl) - # yapf: disable - if (origin := get_origin(irdl)) and ( # we deal with a generic class - issubclass(origin, Data)) and ( # that is a subclass of Data - len(origin.__orig_bases__) == 1) and ( # with one superclass - data := origin.__orig_bases__[0]) and ( # called `data' - arg := get_args(data)[0]) and ( # whose argument is `arg' - get_origin(arg) == list) and ( # which is a list - get_args(arg)[0].__bound__ == - ForwardRef("Attribute") ): # and the element are attributes - args = get_args(irdl) - assert (len(args) == 1) - elem_constr = irdl_to_attr_constraint(args[0]) - return DataListAttr(elem_constr) - # yapf: enable + origin = get_origin(irdl) + if issubclass(origin, IRDLGenericCoercion): + return origin.generic_constraint_coercion(get_args(irdl)) # Union case # This is a coercion for an `AnyOf` constraint. From 38d73ccb92d8de977697d783fdc2225fb566817d Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Wed, 27 Apr 2022 18:24:52 +0100 Subject: [PATCH 09/36] Fix ArrayAttr generic coercion --- src/xdsl/dialects/builtin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/xdsl/dialects/builtin.py b/src/xdsl/dialects/builtin.py index abcd03c4ad..b9eb189c36 100644 --- a/src/xdsl/dialects/builtin.py +++ b/src/xdsl/dialects/builtin.py @@ -191,9 +191,9 @@ def print_parameter(data: List[A], printer: Printer) -> None: @staticmethod def generic_constraint_coercion(args: tuple[Any]) -> AttrConstraint: - if len(args) == 0: - return ArrayOfConstraint(irdl_to_attr_constraint(args[0])) if len(args) == 1: + return ArrayOfConstraint(irdl_to_attr_constraint(args[0])) + if len(args) == 0: return ArrayOfConstraint(AnyAttr()) raise TypeError(f"Attribute ArrayAttr expects at most type" f" parameter, but {len(args)} were given") From 65ee8f05197f5a01b4d1f656ce8933c3e2327d28 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Wed, 27 Apr 2022 18:38:37 +0100 Subject: [PATCH 10/36] Rename IRDLGenericCoercion --- src/xdsl/dialects/builtin.py | 2 +- src/xdsl/irdl.py | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/xdsl/dialects/builtin.py b/src/xdsl/dialects/builtin.py index b9eb189c36..318ef76586 100644 --- a/src/xdsl/dialects/builtin.py +++ b/src/xdsl/dialects/builtin.py @@ -171,7 +171,7 @@ def verify(self, attr: Attribute) -> None: @irdl_attr_definition -class ArrayAttr(Data[List[A]], IRDLGenericCoercion): +class ArrayAttr(GenericData[List[A]]): name = "array" @staticmethod diff --git a/src/xdsl/irdl.py b/src/xdsl/irdl.py index c206879299..369eed7fa5 100644 --- a/src/xdsl/irdl.py +++ b/src/xdsl/irdl.py @@ -167,9 +167,13 @@ def verify(self, attr: Attribute) -> None: param_constr.verify(attr.parameters[idx]) -class IRDLGenericCoercion(ABC): +_DataElement = TypeVar("_DataElement") + + +@dataclass(frozen=True) +class GenericData(Data[_DataElement], ABC): """ - Defines a coercion between a generic Attribute type and an attribute constraint. + A Data with type parameters. """ @staticmethod @@ -205,7 +209,7 @@ def irdl_to_attr_constraint(irdl: Any) -> AttrConstraint: return BaseAttr(irdl) origin = get_origin(irdl) - if issubclass(origin, IRDLGenericCoercion): + if issubclass(origin, GenericData): return origin.generic_constraint_coercion(get_args(irdl)) # Union case From 90895b6a6d2a0b8786d13fec378e3f9f4019bbac Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Wed, 27 Apr 2022 19:08:35 +0100 Subject: [PATCH 11/36] Add error message when GenericData should be used --- src/xdsl/irdl.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/xdsl/irdl.py b/src/xdsl/irdl.py index 369eed7fa5..b8dd10d78d 100644 --- a/src/xdsl/irdl.py +++ b/src/xdsl/irdl.py @@ -212,6 +212,12 @@ def irdl_to_attr_constraint(irdl: Any) -> AttrConstraint: if issubclass(origin, GenericData): return origin.generic_constraint_coercion(get_args(irdl)) + if issubclass(origin, Data): + raise ValueError( + f"Generic `Data` type '{origin.name}' cannot be converted to " + "an attribute constraint. Consider making it inherit from " + "`GenericData` instead of `Data`") + # Union case # This is a coercion for an `AnyOf` constraint. if get_origin(irdl) == types.UnionType: From a0cd4aec9b58a301242f77fb706fbd72f94ae43e Mon Sep 17 00:00:00 2001 From: Fehr Mathieu Date: Wed, 4 May 2022 21:01:02 +0100 Subject: [PATCH 12/36] Fix typo Co-authored-by: Michel Weber <55622065+webmiche@users.noreply.github.com> --- src/xdsl/irdl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/xdsl/irdl.py b/src/xdsl/irdl.py index b8dd10d78d..b5ca877529 100644 --- a/src/xdsl/irdl.py +++ b/src/xdsl/irdl.py @@ -843,7 +843,7 @@ def irdl_param_attr_definition(cls: Type[PA]) -> Type[PA]: if field_name == "name" or field_name == "parameters": continue - # Throw an error if a parameter is not definied using `ParameterDef` + # Throw an error if a parameter is not defined using `ParameterDef` origin = get_origin(field_type) args = get_args(field_type) if origin != Annotated or IRDLAnnotations.ParamDefAnnot not in args: From 9f440ee9dd81e067b60ed63e6c47a721d35b6217 Mon Sep 17 00:00:00 2001 From: Fehr Mathieu Date: Wed, 4 May 2022 21:01:14 +0100 Subject: [PATCH 13/36] Fix typo Co-authored-by: Michel Weber <55622065+webmiche@users.noreply.github.com> --- src/xdsl/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/xdsl/util.py b/src/xdsl/util.py index ff11daa3af..91d7d95de0 100644 --- a/src/xdsl/util.py +++ b/src/xdsl/util.py @@ -68,4 +68,4 @@ def is_satisfying_hint(arg, hint) -> bool: annotated_type = type(Annotated[int, 0]) -"""This is the type of an Annotated object.""" \ No newline at end of file +"""This is the type of an Annotated object.""" From 20c21c1f0bdb8b6bff0ddfd1d1c20e77b4c5e4a5 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Sat, 30 Apr 2022 08:33:04 +0100 Subject: [PATCH 14/36] Move verify method to base Attribute class --- src/xdsl/ir.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/xdsl/ir.py b/src/xdsl/ir.py index 708fde0f20..4ddb009f90 100644 --- a/src/xdsl/ir.py +++ b/src/xdsl/ir.py @@ -191,6 +191,16 @@ def build(cls: Type[A], *args: Any) -> A: """Create a new attribute using one of the builder defined in IRDL.""" assert False + def __post_init__(self): + self.verify() + + def verify(self) -> None: + """ + Check that the attribute parameters satisfy the expected invariants. + Raise an exception otherwise. + """ + pass + DataElement = TypeVar("DataElement") @@ -218,12 +228,6 @@ class ParametrizedAttribute(Attribute): parameters: List[Attribute] = field(default_factory=list) - def __post_init__(self): - self.verify() - - def verify(self) -> None: - ... - @dataclass class Operation: From 97349f4003a4e077ed9c8553086f501ee221451d Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Sat, 30 Apr 2022 08:33:36 +0100 Subject: [PATCH 15/36] Derive verifiers for Data definitions --- src/xdsl/irdl.py | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/src/xdsl/irdl.py b/src/xdsl/irdl.py index b5ca877529..2593bb7d19 100644 --- a/src/xdsl/irdl.py +++ b/src/xdsl/irdl.py @@ -803,17 +803,47 @@ def irdl_attr_builder(cls, builders, *args): f"No available {cls.__name__} builders for arguments {args}") +def irdl_data_verify(data: Data, typ: Type) -> None: + """Check that the Data has the expected type.""" + if isinstance(data.data, typ): + return + raise VerifyException( + f"{data.name} data attribute expected type {typ}, but {type(data.data)} given." + ) + + T = TypeVar('T') def irdl_data_definition(cls: Type[T]) -> Type[T]: - builders = irdl_get_builders(cls) + new_attrs = dict() + + # Build method is added for all definitions. if "build" in cls.__dict__: raise Exception( f'"build" method for {cls.__name__} is reserved for IRDL, and should not be defined.' ) - new_attrs = dict() + builders = irdl_get_builders(cls) new_attrs["build"] = lambda *args: irdl_attr_builder(cls, builders, *args) + + # Verify method is added if not redefined by the user. + if "verify" not in cls.__dict__: + for parent in cls.__orig_bases__: + if get_origin(parent) != Data: + continue + if len(get_args(parent)) != 1: + raise Exception(f"In {cls.__name__} definition: Data expects " + "a single type parameter") + expected_type = get_args(parent)[0] + new_attrs[ + "verify"] = lambda self, expected_type=expected_type: irdl_data_verify( + self, expected_type) + break + else: + raise Exception(f'Missing method "verify" in {cls.__name__} data ' + 'attribute definition: the "verify" method cannot ' + 'be automatically derived for this definition.') + return dataclass(frozen=True)(type(cls.__name__, (cls, ), { **cls.__dict__, **new_attrs From de79c20c8eadab7a247ed0e179dd55b2bc785da9 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Wed, 4 May 2022 16:45:35 -0500 Subject: [PATCH 16/36] Add verifier in ArrayAttr --- src/xdsl/dialects/builtin.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/xdsl/dialects/builtin.py b/src/xdsl/dialects/builtin.py index 318ef76586..382437c4e6 100644 --- a/src/xdsl/dialects/builtin.py +++ b/src/xdsl/dialects/builtin.py @@ -198,6 +198,18 @@ def generic_constraint_coercion(args: tuple[Any]) -> AttrConstraint: raise TypeError(f"Attribute ArrayAttr expects at most type" f" parameter, but {len(args)} were given") + def verify(self) -> None: + if not isinstance(self.data, list): + raise VerifyException( + f"Wrong type given to attribute {self.name}: got" + f" {type(self.data)}, but expected list of" + " attributes") + for idx, val in enumerate(self.data): + if not isinstance(val, Attribute): + raise VerifyException( + f"{self.name} data expects attribute list, but {idx} " + f"element is of type {type(val)}") + @staticmethod @builder def from_list(data: List[A]) -> ArrayAttr[A]: From f0df0c3df6d89746818403fc8a35385df98f12ea Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Wed, 4 May 2022 16:54:45 -0500 Subject: [PATCH 17/36] Add check that Data parameter is a class --- src/xdsl/irdl.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/xdsl/irdl.py b/src/xdsl/irdl.py index 2593bb7d19..def99b9563 100644 --- a/src/xdsl/irdl.py +++ b/src/xdsl/irdl.py @@ -835,6 +835,10 @@ def irdl_data_definition(cls: Type[T]) -> Type[T]: raise Exception(f"In {cls.__name__} definition: Data expects " "a single type parameter") expected_type = get_args(parent)[0] + if not isclass(expected_type): + raise Exception(f'In {cls.__name__} definition: Cannot infer ' + f'"verify" method. Type parameter of Data is ' + f'not a class.') new_attrs[ "verify"] = lambda self, expected_type=expected_type: irdl_data_verify( self, expected_type) From c81c0a774b45ab3c8c79a631294d3514b765c7e5 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Wed, 4 May 2022 17:45:32 -0500 Subject: [PATCH 18/36] Fix punctuation --- src/xdsl/irdl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/xdsl/irdl.py b/src/xdsl/irdl.py index def99b9563..3d8ffd4a15 100644 --- a/src/xdsl/irdl.py +++ b/src/xdsl/irdl.py @@ -216,7 +216,7 @@ def irdl_to_attr_constraint(irdl: Any) -> AttrConstraint: raise ValueError( f"Generic `Data` type '{origin.name}' cannot be converted to " "an attribute constraint. Consider making it inherit from " - "`GenericData` instead of `Data`") + "`GenericData` instead of `Data`.") # Union case # This is a coercion for an `AnyOf` constraint. From 20d6211c6aac0d25c0e4121ed1995ac1823424ab Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Thu, 5 May 2022 00:03:58 -0500 Subject: [PATCH 19/36] Fix verify function in ParametrizedAttribute --- src/xdsl/irdl.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/src/xdsl/irdl.py b/src/xdsl/irdl.py index 3d8ffd4a15..7be1c1dc10 100644 --- a/src/xdsl/irdl.py +++ b/src/xdsl/irdl.py @@ -7,8 +7,8 @@ from enum import Enum from inspect import isclass from typing import (Annotated, Any, Callable, Dict, List, Optional, Sequence, - Tuple, Type, TypeAlias, TypeVar, TypeGuard, Union, cast, - get_args, get_origin, get_type_hints, ForwardRef) + Tuple, Type, TypeAlias, TypeVar, Union, cast, get_args, + get_origin, get_type_hints) from xdsl import util from xdsl.diagnostic import Diagnostic, DiagnosticException @@ -210,7 +210,10 @@ def irdl_to_attr_constraint(irdl: Any) -> AttrConstraint: origin = get_origin(irdl) if issubclass(origin, GenericData): - return origin.generic_constraint_coercion(get_args(irdl)) + return AllOf([ + BaseAttr(origin), + origin.generic_constraint_coercion(get_args(irdl)) + ]) if issubclass(origin, Data): raise ValueError( @@ -743,15 +746,7 @@ def irdl_attr_verify(attr: ParametrizedAttribute, ) for idx, param_def in enumerate(parameters): param = attr.parameters[idx] - assert isinstance(param, Attribute) - for arg in get_args(parameters): - if isinstance(param, IRDLAnnotations): - continue - if not isinstance(arg, AttrConstraint): - raise Exception( - "Unexpected attribute constraint given to IRDL definition: {arg}" - ) - arg.verify(param) + param_def.verify(param) C = TypeVar('C', bound=Callable[..., Any]) From dc1bff5cfb59c6aa835866d76c200e891fddab87 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Thu, 5 May 2022 00:04:27 -0500 Subject: [PATCH 20/36] Add base file for testing attribute definition --- tests/attribute_definition_test.py | 251 +++++++++++++++++++++++++++++ 1 file changed, 251 insertions(+) create mode 100644 tests/attribute_definition_test.py diff --git a/tests/attribute_definition_test.py b/tests/attribute_definition_test.py new file mode 100644 index 0000000000..3ac0cddeae --- /dev/null +++ b/tests/attribute_definition_test.py @@ -0,0 +1,251 @@ +""" +Test the definition of attributes and their constraints. +""" + +from __future__ import annotations +from dataclasses import dataclass +from io import StringIO +from typing import Any, List, Type, TypeGuard, TypeVar, cast + +import pytest + +from xdsl.ir import Attribute, Data, ParametrizedAttribute +from xdsl.irdl import AllOf, AttrConstraint, BaseAttr, GenericData, ParamAttrConstraint, ParameterDef, VerifyException, attr_constr_coercion, irdl_attr_definition, builder, irdl_to_attr_constraint +from xdsl.parser import Parser +from xdsl.printer import Printer + +# ____ _ +# | _ \ __ _| |_ __ _ +# | | | |/ _` | __/ _` | +# | |_| | (_| | || (_| | +# |____/ \__,_|\__\__,_| +# + + +@irdl_attr_definition +class BoolData(Data[bool]): + """An attribute holding a boolean value.""" + name = "bool" + + @staticmethod + def parse_parameter(parser: Parser) -> bool: + val = parser.parse_optional_ident() + if val == "True": + return True + elif val == "False": + return False + else: + raise Exception("Wrong argument passed to BoolAttr.") + + @staticmethod + def print_parameter(data: bool, printer: Printer): + printer.print_string(str(data)) + + +def test_simple_data(): + b = BoolData(True) + stream = StringIO() + p = Printer(stream=stream) + p.print_attribute(b) + assert stream.getvalue() == "!bool" + + +def test_simple_data_verifier_failure(): + with pytest.raises(VerifyException) as e: + BoolData(2) + assert e.value.args[0] == ("bool data attribute expected type " + ", but given.") + + +class IntListMissingVerifierData(Data[List[int]]): + """ + An attribute holding a list of integers. + The definition should fail, since no verifier is provided, and the Data + type parameter is not a class. + """ + name = "missing_verifier_data" + + @staticmethod + def parse_parameter(parser: Parser) -> List[int]: + raise NotImplementedError() + + @staticmethod + def print_parameter(data: List[int], printer: Printer) -> None: + raise NotImplementedError() + + +def test_data_with_non_class_param_missing_verifier_failure(): + with pytest.raises(Exception) as e: + irdl_attr_definition(IntListMissingVerifierData) + assert e.value.args[0] == ( + 'In IntListMissingVerifierData definition: ' + 'Cannot infer "verify" method. Type parameter of Data is not a class.') + + +@irdl_attr_definition +class IntListData(Data[List[int]]): + """ + An attribute holding a list of integers. + """ + name = "int_list" + + @staticmethod + def parse_parameter(parser: Parser) -> List[int]: + raise NotImplementedError() + + @staticmethod + def print_parameter(data: List[int], printer: Printer) -> None: + printer.print_string("[") + printer.print_list(data, lambda x: printer.print_string(str(x))) + printer.print_string("]") + + def verify(self) -> None: + if not isinstance(self.data, list): + raise VerifyException("int_list data should hold a list.") + for elem in self.data: + if not isinstance(elem, int): + raise VerifyException( + "int_list list elements should be integers.") + + +def test_non_class_data(): + attr = IntListData([0, 1, 42]) + stream = StringIO() + p = Printer(stream=stream) + p.print_attribute(attr) + assert stream.getvalue() == "!int_list<[0, 1, 42]>" + + +def test_simple_data_constructor_failure(): + with pytest.raises(VerifyException) as e: + IntListData([0, 1, 42, ""]) + assert e.value.args[0] == "int_list list elements should be integers." + + +# ____ _ ____ _ +# / ___| ___ _ __ ___ _ __(_) ___| _ \ __ _| |_ __ _ +# | | _ / _ \ '_ \ / _ \ '__| |/ __| | | |/ _` | __/ _` | +# | |_| | __/ | | | __/ | | | (__| |_| | (_| | || (_| | +# \____|\___|_| |_|\___|_| |_|\___|____/ \__,_|\__\__,_| +# + +_MissingGenericDataData = TypeVar("_MissingGenericDataData") + + +class MissingGenericDataData(Data[_MissingGenericDataData]): + name = "array" + + @staticmethod + def parse_parameter(parser: Parser) -> _MissingGenericDataData: + raise NotImplementedError() + + @staticmethod + def print_parameter(data: _MissingGenericDataData, + printer: Printer) -> None: + raise NotImplementedError() + + def verify(self) -> None: + return + + +def test_data_with_generic_missing_generic_data_failure(): + with pytest.raises(Exception) as e: + irdl_to_attr_constraint(MissingGenericDataData[int]) + assert e.value.args[0] == ( + "Generic `Data` type 'array' cannot be converted to an attribute " + "constraint. Consider making it inherit from `GenericData` " + "instead of `Data`.") + + +A = TypeVar("A", bound=Attribute) + + +@dataclass +class DataListAttr(AttrConstraint): + """ + A constraint that enforces that the elements of a ListData all respect + a constraint. + """ + elem_constr: AttrConstraint + + def verify(self, attr: Attribute) -> None: + attr = cast(ListData, attr) + for e in attr.data: + self.elem_constr.verify(e) + + +@irdl_attr_definition +class ListData(GenericData[List[A]]): + name = "list" + + @staticmethod + def parse_parameter(parser: Parser) -> List[A]: + raise NotImplementedError() + + @staticmethod + def print_parameter(data: List[A], printer: Printer) -> None: + printer.print_string("[") + printer.print_list(data, printer.print_attribute) + printer.print_string("]") + + @staticmethod + def generic_constraint_coercion(args: tuple[Any]) -> AttrConstraint: + assert len(args) == 1 + return DataListAttr(irdl_to_attr_constraint(args[0])) + + @staticmethod + @builder + def from_list(data: List[A]) -> ListData[A]: + return ListData(data) + + def verify(self) -> None: + if not isinstance(self.data, list): + raise VerifyException( + f"Wrong type given to attribute {self.name}: got" + f" {type(self.data)}, but expected list of" + " attributes.") + for idx, val in enumerate(self.data): + if not isinstance(val, Attribute): + raise VerifyException( + f"{self.name} data expects attribute list, but element " + f"{idx} is of type {type(val)}.") + + +def test_generic_data_verifier(): + attr = ListData([BoolData(True), ListData([BoolData(False)])]) + stream = StringIO() + p = Printer(stream=stream) + p.print_attribute(attr) + assert stream.getvalue() == "!list<[!bool, !list<[!bool]>]>" + + +def test_generic_data_verifier_fail(): + with pytest.raises(VerifyException) as e: + ListData([0]) + assert e.value.args[0] == ("list data expects attribute list, but" + " element 0 is of type .") + + +@irdl_attr_definition +class ListDataWrapper(ParametrizedAttribute): + name = "list_wrapper" + + val: ParameterDef[ListData[BoolData]] + + +def test_generic_data_wrapper_verifier(): + attr = ListDataWrapper([ListData([BoolData(True), BoolData(False)])]) + stream = StringIO() + p = Printer(stream=stream) + p.print_attribute(attr) + assert stream.getvalue( + ) == "!list_wrapper, !bool]>>" + + +def test_generic_data_wrapper_verifier_failure(): + with pytest.raises(VerifyException) as e: + ListDataWrapper( + [ListData([BoolData(True), + ListData([BoolData(False)])])]) + assert e.value.args[ + 0] == "ListData(data=[BoolData(data=False)]) should be of base attribute bool" From c33560e7aafaabdb00e3208ea75771292e4f6948 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Thu, 5 May 2022 17:26:50 -0500 Subject: [PATCH 21/36] Fix TupleType implementation --- src/xdsl/dialects/builtin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/xdsl/dialects/builtin.py b/src/xdsl/dialects/builtin.py index 382437c4e6..c417f99bb5 100644 --- a/src/xdsl/dialects/builtin.py +++ b/src/xdsl/dialects/builtin.py @@ -220,12 +220,12 @@ def from_list(data: List[A]) -> ArrayAttr[A]: class TupleType(ParametrizedAttribute): name = "tuple" - types = ParameterDef(ArrayOfConstraint(Attribute)) + types: ParameterDef[ArrayAttr] @staticmethod @builder def from_type_list(types: List[Attribute]) -> TupleType: - return TupleType([ArrayAttr.from_list(types)]) #type: ignore + return TupleType([ArrayAttr.from_list(types)]) @irdl_attr_definition From e8649e645132316216ff14f699b850f887a971cd Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Thu, 5 May 2022 17:28:16 -0500 Subject: [PATCH 22/36] Remove type errors in attribute definition test --- tests/attribute_definition_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/attribute_definition_test.py b/tests/attribute_definition_test.py index 3ac0cddeae..f994446a18 100644 --- a/tests/attribute_definition_test.py +++ b/tests/attribute_definition_test.py @@ -52,7 +52,7 @@ def test_simple_data(): def test_simple_data_verifier_failure(): with pytest.raises(VerifyException) as e: - BoolData(2) + BoolData(2) # type: ignore assert e.value.args[0] == ("bool data attribute expected type " ", but given.") @@ -118,7 +118,7 @@ def test_non_class_data(): def test_simple_data_constructor_failure(): with pytest.raises(VerifyException) as e: - IntListData([0, 1, 42, ""]) + IntListData([0, 1, 42, ""]) # type: ignore assert e.value.args[0] == "int_list list elements should be integers." @@ -221,7 +221,7 @@ def test_generic_data_verifier(): def test_generic_data_verifier_fail(): with pytest.raises(VerifyException) as e: - ListData([0]) + ListData([0]) # type: ignore assert e.value.args[0] == ("list data expects attribute list, but" " element 0 is of type .") From d1b5030bc1eeeb675d6d742d9967df0ab1f604d4 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Thu, 5 May 2022 17:38:41 -0500 Subject: [PATCH 23/36] Add new test --- tests/attribute_definition_test.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/attribute_definition_test.py b/tests/attribute_definition_test.py index f994446a18..689c71f3bf 100644 --- a/tests/attribute_definition_test.py +++ b/tests/attribute_definition_test.py @@ -249,3 +249,20 @@ def test_generic_data_wrapper_verifier_failure(): ListData([BoolData(False)])])]) assert e.value.args[ 0] == "ListData(data=[BoolData(data=False)]) should be of base attribute bool" + + +@irdl_attr_definition +class ListDataNoGenericsWrapper(ParametrizedAttribute): + name = "list_no_generics_wrapper" + + val: ParameterDef[ListData] + + +def test_generic_data_no_generics_wrapper_verifier(): + attr = ListDataNoGenericsWrapper( + [ListData([BoolData(True), ListData([BoolData(False)])])]) + stream = StringIO() + p = Printer(stream=stream) + p.print_attribute(attr) + assert stream.getvalue( + ) == "!list_no_generics_wrapper, !list<[!bool]>]>>" From a0553a8765148b66757545d10dd4a23d16749e15 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Thu, 5 May 2022 17:47:46 -0500 Subject: [PATCH 24/36] Add documentation in attribute definition tests --- tests/attribute_definition_test.py | 50 +++++++++++++++++++++++++++--- 1 file changed, 45 insertions(+), 5 deletions(-) diff --git a/tests/attribute_definition_test.py b/tests/attribute_definition_test.py index 689c71f3bf..2bd2c49124 100644 --- a/tests/attribute_definition_test.py +++ b/tests/attribute_definition_test.py @@ -43,6 +43,7 @@ def print_parameter(data: bool, printer: Printer): def test_simple_data(): + """Test that the definition of a data with a class parameter.""" b = BoolData(True) stream = StringIO() p = Printer(stream=stream) @@ -51,6 +52,10 @@ def test_simple_data(): def test_simple_data_verifier_failure(): + """ + Test that the verifier of a data with a class parameter fails when given + a parameter of the wrong type. + """ with pytest.raises(VerifyException) as e: BoolData(2) # type: ignore assert e.value.args[0] == ("bool data attribute expected type " @@ -75,6 +80,9 @@ def print_parameter(data: List[int], printer: Printer) -> None: def test_data_with_non_class_param_missing_verifier_failure(): + """ + Test that a non-class Data parameter requires the definition of a verifier. + """ with pytest.raises(Exception) as e: irdl_attr_definition(IntListMissingVerifierData) assert e.value.args[0] == ( @@ -109,6 +117,7 @@ def verify(self) -> None: def test_non_class_data(): + """Test the definition of a Data with a non-class parameter.""" attr = IntListData([0, 1, 42]) stream = StringIO() p = Printer(stream=stream) @@ -117,6 +126,10 @@ def test_non_class_data(): def test_simple_data_constructor_failure(): + """ + Test that the verifier of a Data with a non-class parameter fails when + given wrong arguments. + """ with pytest.raises(VerifyException) as e: IntListData([0, 1, 42, ""]) # type: ignore assert e.value.args[0] == "int_list list elements should be integers." @@ -132,8 +145,9 @@ def test_simple_data_constructor_failure(): _MissingGenericDataData = TypeVar("_MissingGenericDataData") +@irdl_attr_definition class MissingGenericDataData(Data[_MissingGenericDataData]): - name = "array" + name = "missing_genericdata" @staticmethod def parse_parameter(parser: Parser) -> _MissingGenericDataData: @@ -148,13 +162,23 @@ def verify(self) -> None: return +class MissingGenericDataDataWrapper(ParametrizedAttribute): + name = "missing_genericdata_wrapper" + + param: ParameterDef[MissingGenericDataData[int]] + + def test_data_with_generic_missing_generic_data_failure(): + """ + Test error message when a generic data is used in constraints + without implementing GenericData. + """ with pytest.raises(Exception) as e: - irdl_to_attr_constraint(MissingGenericDataData[int]) + irdl_attr_definition(MissingGenericDataDataWrapper) assert e.value.args[0] == ( - "Generic `Data` type 'array' cannot be converted to an attribute " - "constraint. Consider making it inherit from `GenericData` " - "instead of `Data`.") + "Generic `Data` type 'missing_genericdata' cannot be converted to " + "an attribute constraint. Consider making it inherit from " + "`GenericData` instead of `Data`.") A = TypeVar("A", bound=Attribute) @@ -212,6 +236,9 @@ def verify(self) -> None: def test_generic_data_verifier(): + """ + Test that a GenericData can be created. + """ attr = ListData([BoolData(True), ListData([BoolData(False)])]) stream = StringIO() p = Printer(stream=stream) @@ -220,6 +247,9 @@ def test_generic_data_verifier(): def test_generic_data_verifier_fail(): + """ + Test that a GenericData verifier fails when given wrong parameters. + """ with pytest.raises(VerifyException) as e: ListData([0]) # type: ignore assert e.value.args[0] == ("list data expects attribute list, but" @@ -234,6 +264,9 @@ class ListDataWrapper(ParametrizedAttribute): def test_generic_data_wrapper_verifier(): + """ + Test that a GenericData used in constraints pass the verifier when correct. + """ attr = ListDataWrapper([ListData([BoolData(True), BoolData(False)])]) stream = StringIO() p = Printer(stream=stream) @@ -243,6 +276,10 @@ def test_generic_data_wrapper_verifier(): def test_generic_data_wrapper_verifier_failure(): + """ + Test that a GenericData used in constraints fails + the verifier when constraints are not satisfied. + """ with pytest.raises(VerifyException) as e: ListDataWrapper( [ListData([BoolData(True), @@ -259,6 +296,9 @@ class ListDataNoGenericsWrapper(ParametrizedAttribute): def test_generic_data_no_generics_wrapper_verifier(): + """ + Test that GenericType can be used in constraints without a parameter. + """ attr = ListDataNoGenericsWrapper( [ListData([BoolData(True), ListData([BoolData(False)])])]) stream = StringIO() From 0c9a10e9a178bf56855e7039e76b63c143d707d6 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Thu, 5 May 2022 18:26:23 -0500 Subject: [PATCH 25/36] Fix bug in irdl --- src/xdsl/irdl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/xdsl/irdl.py b/src/xdsl/irdl.py index 7be1c1dc10..1da1f72175 100644 --- a/src/xdsl/irdl.py +++ b/src/xdsl/irdl.py @@ -209,13 +209,13 @@ def irdl_to_attr_constraint(irdl: Any) -> AttrConstraint: return BaseAttr(irdl) origin = get_origin(irdl) - if issubclass(origin, GenericData): + if isclass(origin) and issubclass(origin, GenericData): return AllOf([ BaseAttr(origin), origin.generic_constraint_coercion(get_args(irdl)) ]) - if issubclass(origin, Data): + if isclass(origin) and issubclass(origin, Data): raise ValueError( f"Generic `Data` type '{origin.name}' cannot be converted to " "an attribute constraint. Consider making it inherit from " @@ -223,7 +223,7 @@ def irdl_to_attr_constraint(irdl: Any) -> AttrConstraint: # Union case # This is a coercion for an `AnyOf` constraint. - if get_origin(irdl) == types.UnionType: + if origin == types.UnionType: constraints: List[AttrConstraint] = [] for arg in get_args(irdl): # We should not try to convert IRDL annotations, which do not From e306d5752fd77797c57905a01e62c380b68d5f40 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Thu, 5 May 2022 18:28:20 -0500 Subject: [PATCH 26/36] Add test cases for all IRDL constraints --- tests/attribute_definition_test.py | 145 ++++++++++++++++++++++++++++- 1 file changed, 143 insertions(+), 2 deletions(-) diff --git a/tests/attribute_definition_test.py b/tests/attribute_definition_test.py index 2bd2c49124..6b59308177 100644 --- a/tests/attribute_definition_test.py +++ b/tests/attribute_definition_test.py @@ -5,12 +5,12 @@ from __future__ import annotations from dataclasses import dataclass from io import StringIO -from typing import Any, List, Type, TypeGuard, TypeVar, cast +from typing import Any, List, TypeVar, cast, Annotated import pytest from xdsl.ir import Attribute, Data, ParametrizedAttribute -from xdsl.irdl import AllOf, AttrConstraint, BaseAttr, GenericData, ParamAttrConstraint, ParameterDef, VerifyException, attr_constr_coercion, irdl_attr_definition, builder, irdl_to_attr_constraint +from xdsl.irdl import AttrConstraint, GenericData, ParameterDef, VerifyException, irdl_attr_definition, builder, irdl_to_attr_constraint from xdsl.parser import Parser from xdsl.printer import Printer @@ -42,6 +42,34 @@ def print_parameter(data: bool, printer: Printer): printer.print_string(str(data)) +@irdl_attr_definition +class IntData(Data[int]): + """An attribute holding an integer value.""" + name = "int" + + @staticmethod + def parse_parameter(parser: Parser) -> int: + return parser.parse_int_literal() + + @staticmethod + def print_parameter(data: int, printer: Printer): + printer.print_string(str(data)) + + +@irdl_attr_definition +class StringData(Data[str]): + """An attribute holding a string value.""" + name = "str" + + @staticmethod + def parse_parameter(parser: Parser) -> str: + return parser.parse_str_literal() + + @staticmethod + def print_parameter(data: str, printer: Printer): + printer.print_string(data) + + def test_simple_data(): """Test that the definition of a data with a class parameter.""" b = BoolData(True) @@ -135,6 +163,119 @@ def test_simple_data_constructor_failure(): assert e.value.args[0] == "int_list list elements should be integers." +# ____ ____ _ _ _ +# | __ ) __ _ ___ ___ / ___|___ _ __ ___| |_ _ __ __ _(_)_ __ | |_ +# | _ \ / _` / __|/ _ \ | / _ \| '_ \/ __| __| '__/ _` | | '_ \| __| +# | |_) | (_| \__ \ __/ |__| (_) | | | \__ \ |_| | | (_| | | | | | |_ +# |____/ \__,_|___/\___|\____\___/|_| |_|___/\__|_| \__,_|_|_| |_|\__| +# + + +@irdl_attr_definition +class BoolWrapperAttr(ParametrizedAttribute): + name = "bool_wrapper" + + param: ParameterDef[BoolData] + + +def test_bose_constraint(): + """Test the verifier of a base attribute type constraint.""" + attr = BoolWrapperAttr([BoolData(True)]) + stream = StringIO() + p = Printer(stream=stream) + p.print_attribute(attr) + assert stream.getvalue() == "!bool_wrapper>" + + +def test_base_constraint_fail(): + """Test the verifier of a union constraint.""" + with pytest.raises(Exception) as e: + BoolWrapperAttr([StringData("foo")]) + assert e.value.args[ + 0] == "StringData(data='foo') should be of base attribute bool" + + +# _ _ _ ____ _ _ _ +# | | | |_ __ (_) ___ _ __ / ___|___ _ __ ___| |_ _ __ __ _(_)_ __ | |_ +# | | | | '_ \| |/ _ \| '_ \| | / _ \| '_ \/ __| __| '__/ _` | | '_ \| __| +# | |_| | | | | | (_) | | | | |__| (_) | | | \__ \ |_| | | (_| | | | | | |_ +# \___/|_| |_|_|\___/|_| |_|\____\___/|_| |_|___/\__|_| \__,_|_|_| |_|\__| +# + + +@irdl_attr_definition +class BoolOrIntParamAttr(ParametrizedAttribute): + name = "bool_or_int" + + param: ParameterDef[BoolData | IntData] + + +def test_union_constraint_left(): + """Test the verifier of a union constraint.""" + attr = BoolOrIntParamAttr([BoolData(True)]) + stream = StringIO() + p = Printer(stream=stream) + p.print_attribute(attr) + assert stream.getvalue() == "!bool_or_int>" + + +def test_union_constraint_right(): + """Test the verifier of a union constraint.""" + attr = BoolOrIntParamAttr([IntData(42)]) + stream = StringIO() + p = Printer(stream=stream) + p.print_attribute(attr) + assert stream.getvalue() == "!bool_or_int>" + + +def test_union_constraint_fail(): + """Test the verifier of a union constraint.""" + with pytest.raises(Exception) as e: + BoolOrIntParamAttr([StringData("foo")]) + assert e.value.args[0] == "Unexpected attribute StringData(data='foo')" + + +# _ _ ____ _ +# / \ _ __ _ __ ___ | |_ / ___|___ _ __ ___| |_ _ __ +# / _ \ | '_ \| '_ \ / _ \| __| | / _ \| '_ \/ __| __| '__| +# / ___ \| | | | | | | (_) | |_| |__| (_) | | | \__ \ |_| | +# /_/ \_\_| |_|_| |_|\___/ \__|\____\___/|_| |_|___/\__|_| + + +class PositiveIntConstr(AttrConstraint): + + def verify(self, attr: Attribute) -> None: + if not isinstance(attr, IntData): + raise VerifyException( + f"Expected {IntData.name} attribute, but got {attr.name}.") + if attr.data <= 0: + raise VerifyException( + f"Expected positive integer, got {attr.data}.") + + +@irdl_attr_definition +class PositiveIntAttr(ParametrizedAttribute): + name = "positive_int" + + param: ParameterDef[Annotated[IntData, PositiveIntConstr()]] + + +def test_annotated_constraint(): + """Test the verifier of an annotated constraint.""" + attr = PositiveIntAttr([IntData(42)]) + stream = StringIO() + p = Printer(stream=stream) + p.print_attribute(attr) + assert stream.getvalue() == "!positive_int>" + + +def test_annotated_constraint_fail(): + """Test that the verifier of an annotated constraint can fail.""" + with pytest.raises(Exception) as e: + PositiveIntAttr([IntData(-42)]) + assert e.value.args[0] == "Expected positive integer, got -42." + + # ____ _ ____ _ # / ___| ___ _ __ ___ _ __(_) ___| _ \ __ _| |_ __ _ # | | _ / _ \ '_ \ / _ \ '__| |/ __| | | |/ _` | __/ _` | From 26c8f3d436fb131700557cddc3a6e691571a29e3 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Fri, 6 May 2022 12:59:48 -0500 Subject: [PATCH 27/36] Add support for TypeVar in ParametrizedAttribute --- src/xdsl/irdl.py | 139 ++++++++++++++++++++++------ tests/attribute_definition_test.py | 141 ++++++++++++++++++++++++++++- 2 files changed, 248 insertions(+), 32 deletions(-) diff --git a/src/xdsl/irdl.py b/src/xdsl/irdl.py index 1da1f72175..ccbfc1de8c 100644 --- a/src/xdsl/irdl.py +++ b/src/xdsl/irdl.py @@ -6,9 +6,9 @@ from dataclasses import dataclass, field from enum import Enum from inspect import isclass -from typing import (Annotated, Any, Callable, Dict, List, Optional, Sequence, - Tuple, Type, TypeAlias, TypeVar, Union, cast, get_args, - get_origin, get_type_hints) +from typing import (Annotated, Any, Callable, Dict, Generic, List, Optional, + Sequence, Tuple, Type, TypeAlias, TypeVar, Union, cast, + get_args, get_origin, get_type_hints) from xdsl import util from xdsl.diagnostic import Diagnostic, DiagnosticException @@ -185,7 +185,12 @@ def generic_constraint_coercion(args: tuple[Any]) -> AttrConstraint: """ -def irdl_to_attr_constraint(irdl: Any) -> AttrConstraint: +def irdl_to_attr_constraint( + irdl: Any, + *, + allow_type_var: bool = False, + type_var_mapping: Optional[Dict[TypeVar, AttrConstraint]] = None +) -> AttrConstraint: if isinstance(irdl, AttrConstraint): return irdl @@ -198,7 +203,10 @@ def irdl_to_attr_constraint(irdl: Any) -> AttrConstraint: # correspond to constraints if isinstance(arg, IRDLAnnotations): continue - constraints.append(irdl_to_attr_constraint(arg)) + constraints.append( + irdl_to_attr_constraint(arg, + allow_type_var=allow_type_var, + type_var_mapping=type_var_mapping)) if len(constraints) > 1: return AllOf(constraints) return constraints[0] @@ -208,18 +216,71 @@ def irdl_to_attr_constraint(irdl: Any) -> AttrConstraint: if isclass(irdl) and issubclass(irdl, Attribute): return BaseAttr(irdl) + # Type variable case + # We take the type variable bound constraint. + if isinstance(irdl, TypeVar): + if not allow_type_var: + raise Exception("TypeVar in unexpected context.") + if type_var_mapping is not None: + if irdl in type_var_mapping: + return type_var_mapping[irdl] + if irdl.__bound__ is None: + raise Exception("Type variables used in IRDL are expected to" + " be bound.") + # We do not allow nested type variables. + return irdl_to_attr_constraint(irdl.__bound__) + origin = get_origin(irdl) + + # GenericData case if isclass(origin) and issubclass(origin, GenericData): return AllOf([ BaseAttr(origin), origin.generic_constraint_coercion(get_args(irdl)) ]) - if isclass(origin) and issubclass(origin, Data): - raise ValueError( - f"Generic `Data` type '{origin.name}' cannot be converted to " - "an attribute constraint. Consider making it inherit from " - "`GenericData` instead of `Data`.") + # Generic ParametrizedAttributes case + # We translate it to constraints over the attribute parameters. + if isclass(origin) and issubclass( + origin, ParametrizedAttribute) and issubclass(origin, Generic): + args = [ + irdl_to_attr_constraint(arg, + allow_type_var=allow_type_var, + type_var_mapping=type_var_mapping) + for arg in get_args(irdl) + ] + generic_args = () + + # Get the Generic parent class to get the TypeVar parameters + for parent in origin.__orig_bases__: # type: ignore + if get_origin(parent) == Generic: + generic_args = get_args(parent) + break + else: + raise Exception( + f"Cannot parametrized non-generic {origin.name} attribute.") + + # Check that we have the right number of parameters + if len(args) != len(generic_args): + raise Exception(f"{origin.name} expects {len(generic_args)}" + f" parameters, got {len(args)}.") + + type_var_mapping = { + parameter: arg + for parameter, arg in zip(generic_args, args) + } + + origin_parameters = irdl_param_attr_get_param_type_hints(origin) + origin_constraints: List[Attribute | Type[Attribute] + | AttrConstraint] = [ + irdl_to_attr_constraint( + param, + allow_type_var=True, + type_var_mapping=type_var_mapping) + for _, param in origin_parameters + ] + print(origin_constraints) + return ParamAttrConstraint(origin, origin_constraints) # Union case # This is a coercion for an `AnyOf` constraint. @@ -230,11 +291,21 @@ def irdl_to_attr_constraint(irdl: Any) -> AttrConstraint: # correspond to constraints if isinstance(arg, IRDLAnnotations): continue - constraints.append(irdl_to_attr_constraint(arg)) + constraints.append( + irdl_to_attr_constraint(arg, + allow_type_var=allow_type_var, + type_var_mapping=type_var_mapping)) if len(constraints) > 1: return AnyOf(constraints) return constraints[0] + # Better error messages for missing GenericData in Data definitions + if isclass(origin) and issubclass(origin, Data): + raise ValueError( + f"Generic `Data` type '{origin.name}' cannot be converted to " + "an attribute constraint. Consider making it inherit from " + "`GenericData` instead of `Data`.") + raise ValueError(f"Unexpected irdl constraint: {irdl}") @@ -849,6 +920,27 @@ def irdl_data_definition(cls: Type[T]) -> Type[T]: })) +def irdl_param_attr_get_param_type_hints( + cls: Type[ParametrizedAttribute]) -> List[Tuple[str, Any]]: + """Get the type hints of an IRDL parameter definitions.""" + res = [] + for field_name, field_type in get_type_hints(cls, + include_extras=True).items(): + if field_name == "name" or field_name == "parameters": + continue + + origin = get_origin(field_type) + args = get_args(field_type) + if origin != Annotated or IRDLAnnotations.ParamDefAnnot not in args: + raise ValueError( + f"In attribute {cls.__name__} definition: Parameter " + + f"definition {field_name} should be defined with " + + f"type `ParameterDef`, got type {field_type}.") + + res.append((field_name, field_type)) + return res + + PA = TypeVar("PA", bound=ParametrizedAttribute) @@ -860,31 +952,18 @@ def irdl_param_attr_definition(cls: Type[PA]) -> Type[PA]: for parent_cls in cls.mro()[::-1]: clsdict = {**clsdict, **parent_cls.__dict__} + param_hints = irdl_param_attr_get_param_type_hints(cls) + # IRDL parameters definitions parameters = [] # New fields and methods added to the attribute new_fields = dict() - # TODO(math-fehr): Check that "name" and "parameters" are not redefined - for field_name, field_type in get_type_hints(cls, - include_extras=True).items(): - # name and parameters are reserved fields in IRDL - if field_name == "name" or field_name == "parameters": - continue - - # Throw an error if a parameter is not defined using `ParameterDef` - origin = get_origin(field_type) - args = get_args(field_type) - if origin != Annotated or IRDLAnnotations.ParamDefAnnot not in args: - raise ValueError( - f"In attribute {cls.__name__} definition: Parameter " + - f"definition {field_name} should be defined with " + - f"type `ParameterDef`, got type {field_type}.") - - # Add the accessors for the definition - new_fields[field_name] = property( + for param_name, param_type in param_hints: + new_fields[param_name] = property( (lambda idx: lambda self: self.parameters[idx])(len(parameters))) - parameters.append(irdl_to_attr_constraint(field_type)) + parameters.append( + irdl_to_attr_constraint(param_type, allow_type_var=True)) new_fields["verify"] = lambda typ: irdl_attr_verify(typ, parameters) diff --git a/tests/attribute_definition_test.py b/tests/attribute_definition_test.py index 6b59308177..d8272ecff4 100644 --- a/tests/attribute_definition_test.py +++ b/tests/attribute_definition_test.py @@ -5,12 +5,14 @@ from __future__ import annotations from dataclasses import dataclass from io import StringIO -from typing import Any, List, TypeVar, cast, Annotated +from typing import Any, List, TypeVar, cast, Annotated, Generic import pytest from xdsl.ir import Attribute, Data, ParametrizedAttribute -from xdsl.irdl import AttrConstraint, GenericData, ParameterDef, VerifyException, irdl_attr_definition, builder, irdl_to_attr_constraint +from xdsl.irdl import (AttrConstraint, GenericData, ParameterDef, + VerifyException, irdl_attr_definition, builder, + irdl_to_attr_constraint) from xdsl.parser import Parser from xdsl.printer import Printer @@ -276,6 +278,141 @@ def test_annotated_constraint_fail(): assert e.value.args[0] == "Expected positive integer, got -42." +# _____ __ __ ____ _ +# |_ _| _ _ __ __\ \ / /_ _ _ __ / ___|___ _ __ ___| |_ _ __ +# | || | | | '_ \ / _ \ \ / / _` | '__| | / _ \| '_ \/ __| __| '__| +# | || |_| | |_) | __/\ V / (_| | | | |__| (_) | | | \__ \ |_| | +# |_| \__, | .__/ \___| \_/ \__,_|_| \____\___/|_| |_|___/\__|_| +# |___/|_| +# + +_T = TypeVar("_T", bound=BoolData | IntData) + + +@irdl_attr_definition +class ParamWrapperAttr(Generic[_T], ParametrizedAttribute): + name = "int_or_bool_generic" + + param: ParameterDef[_T] + + +def test_typevar_attribute_int(): + """Test the verifier of a generic attribute.""" + attr = ParamWrapperAttr([IntData(42)]) + stream = StringIO() + p = Printer(stream=stream) + p.print_attribute(attr) + assert stream.getvalue() == "!int_or_bool_generic>" + + +def test_typevar_attribute_bool(): + """Test the verifier of a generic attribute.""" + attr = ParamWrapperAttr([BoolData(True)]) + stream = StringIO() + p = Printer(stream=stream) + p.print_attribute(attr) + assert stream.getvalue() == "!int_or_bool_generic>" + + +def test_typevar_attribute_fail(): + """Test that the verifier of an generic attribute can fail.""" + with pytest.raises(Exception) as e: + ParamWrapperAttr([StringData("foo")]) + assert e.value.args[0] == "Unexpected attribute StringData(data='foo')" + + +@irdl_attr_definition +class ParamConstrAttr(ParametrizedAttribute): + name = "param_constr" + + param: ParameterDef[ParamWrapperAttr[IntData]] + + +def test_param_attr_constraint(): + """Test the verifier of an attribute with a parametric constraint.""" + attr = ParamConstrAttr([ParamWrapperAttr([IntData(42)])]) + stream = StringIO() + p = Printer(stream=stream) + p.print_attribute(attr) + assert stream.getvalue() == "!param_constr>>" + + +def test_param_attr_constraint_fail(): + """ + Test that the verifier of an attribute with + a parametric constraint can fail. + """ + with pytest.raises(Exception) as e: + ParamConstrAttr([ParamWrapperAttr([BoolData(True)])]) + assert e.value.args[ + 0] == "BoolData(data=True) should be of base attribute int" + + +_U = TypeVar("_U", bound=IntData) + + +@irdl_attr_definition +class NestedParamWrapperAttr(Generic[_U], ParametrizedAttribute): + name = "nested_param_wrapper" + + param: ParameterDef[ParamWrapperAttr[_U]] + + +def test_nested_generic_constraint(): + """ + Test the verifier of an attribute with a generic + constraint used in a parametric constraint. + """ + attr = NestedParamWrapperAttr([ParamWrapperAttr([IntData(42)])]) + stream = StringIO() + p = Printer(stream=stream) + p.print_attribute(attr) + assert stream.getvalue( + ) == "!nested_param_wrapper>>" + + +def test_nested_generic_constraint_fail(): + """ + Test that the verifier of an attribute with + a parametric constraint can fail. + """ + with pytest.raises(Exception) as e: + NestedParamWrapperAttr([ParamWrapperAttr([BoolData(True)])]) + assert e.value.args[ + 0] == "BoolData(data=True) should be of base attribute int" + + +@irdl_attr_definition +class NestedParamConstrAttr(ParametrizedAttribute): + name = "nested_param_constr" + + param: ParameterDef[NestedParamWrapperAttr[Annotated[IntData, + PositiveIntConstr()]]] + + +def test_nested_param_attr_constraint(): + """ + Test the verifier of a nested parametric constraint. + """ + attr = NestedParamConstrAttr( + [NestedParamWrapperAttr([ParamWrapperAttr([IntData(42)])])]) + stream = StringIO() + p = Printer(stream=stream) + p.print_attribute(attr) + assert stream.getvalue( + ) == "!nested_param_constr>>>" + + +def test_nested_param_attr_constraint_fail(): + """ + Test that the verifier of a nested parametric constraint can fail. + """ + with pytest.raises(Exception) as e: + NestedParamConstrAttr( + [NestedParamWrapperAttr([ParamWrapperAttr([IntData(-42)])])]) + assert e.value.args[0] == "Expected positive integer, got -42." + + # ____ _ ____ _ # / ___| ___ _ __ ___ _ __(_) ___| _ \ __ _| |_ __ _ # | | _ / _ \ '_ \ / _ \ '__| |/ __| | | |/ _` | __/ _` | From a783d4d07ec792e090c349d9727ba7da9a9f1df2 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Thu, 12 May 2022 17:45:34 -0500 Subject: [PATCH 28/36] Cleanup a bit affine dialect --- src/xdsl/dialects/affine.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/xdsl/dialects/affine.py b/src/xdsl/dialects/affine.py index 7e481617ef..2cd1d9898b 100644 --- a/src/xdsl/dialects/affine.py +++ b/src/xdsl/dialects/affine.py @@ -1,10 +1,11 @@ from __future__ import annotations -from typing import Union -from xdsl.ir import Operation, SSAValue +from typing import Union, List from dataclasses import dataclass -from xdsl.dialects.builtin import IntegerAttr, IndexType -from xdsl.irdl import irdl_op_definition, AttributeDef, OperandDef, RegionDef, VarResultDef, VarOperandDef, AnyAttr + +from xdsl.ir import Operation, SSAValue, MLContext, Block, Region +from xdsl.dialects.builtin import IntegerAttr, IndexType, IndexType +from xdsl.irdl import irdl_op_definition, AttributeDef, RegionDef, VarResultDef, VarOperandDef, AnyAttr @dataclass @@ -68,7 +69,7 @@ def from_region(operands: List[Union[Operation, SSAValue]], def from_callable(operands: List[Union[Operation, SSAValue]], lower_bound: Union[int, IntegerAttr], upper_bound: Union[int, IntegerAttr], - body: Callable[[BlockArgument, ...], List[Operation]], + body: Block.BlockCallback, step: Union[int, IntegerAttr] = 1) -> For: arg_types = [IndexType()] + [SSAValue.get(op).typ for op in operands] return For.from_region( @@ -83,5 +84,6 @@ class Yield(Operation): arguments = VarOperandDef(AnyAttr()) @staticmethod - def get(*operands: Union[Operation, SSAValue]) -> Yield: - return Yield.create(operands=[operand for operand in operands]) + def get(*operands: SSAValue | Operation) -> Yield: + return Yield.create( + operands=[SSAValue.get(operand) for operand in operands]) From a990fd994b4ce8bea890517b8f9bce5a6057c260 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Thu, 12 May 2022 17:50:12 -0500 Subject: [PATCH 29/36] Cleanup a bit arith dialect --- src/xdsl/dialects/arith.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/xdsl/dialects/arith.py b/src/xdsl/dialects/arith.py index 0f96590f88..62d675edc4 100644 --- a/src/xdsl/dialects/arith.py +++ b/src/xdsl/dialects/arith.py @@ -2,7 +2,7 @@ from xdsl.ir import * from xdsl.irdl import * from xdsl.util import * -from xdsl.dialects.builtin import IntegerType, Float32Type, IntegerAttr, FlatSymbolRefAttr +from xdsl.dialects.builtin import IntegerType, Float32Type, IntegerAttr @dataclass @@ -125,9 +125,9 @@ def verify_(self) -> None: @staticmethod def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> FloordiviSigned: - return FloordiviSigned.build(operands=[operand1, operand2], - result_types=[IntegerType.from_width(32)]) + operand2: Union[Operation, SSAValue]) -> FloorDiviSI: + return FloorDiviSI.build(operands=[operand1, operand2], + result_types=[IntegerType.from_width(32)]) @irdl_op_definition @@ -145,9 +145,9 @@ def verify_(self) -> None: @staticmethod def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> RemiSigned: - return RemiSigned.build(operands=[operand1, operand2], - result_types=[IntegerType.from_width(32)]) + operand2: Union[Operation, SSAValue]) -> RemSI: + return RemSI.build(operands=[operand1, operand2], + result_types=[IntegerType.from_width(32)]) @irdl_op_definition @@ -185,10 +185,10 @@ def verify_(self) -> None: @staticmethod def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> Or: + operand2: Union[Operation, SSAValue]) -> OrI: - return Or.build(operands=[operand1, operand2], - result_types=[SSAValue.get(operand1).typ]) + return OrI.build(operands=[operand1, operand2], + result_types=[SSAValue.get(operand1).typ]) @irdl_op_definition @@ -206,9 +206,9 @@ def verify_(self) -> None: @staticmethod def get(operand1: Union[Operation, SSAValue], - operand2: Union[Operation, SSAValue]) -> Xor: - return Xor.build(operands=[operand1, operand2], - result_types=[SSAValue.get(operand1).typ]) + operand2: Union[Operation, SSAValue]) -> XOrI: + return XOrI.build(operands=[operand1, operand2], + result_types=[SSAValue.get(operand1).typ]) @irdl_op_definition From 66814a1a6345a8f3cc8cf4584b6887bc5cf4e84e Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Thu, 12 May 2022 18:01:15 -0500 Subject: [PATCH 30/36] Make some attribute generic --- src/xdsl/dialects/builtin.py | 39 +++++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/src/xdsl/dialects/builtin.py b/src/xdsl/dialects/builtin.py index c417f99bb5..8926c8fe7d 100644 --- a/src/xdsl/dialects/builtin.py +++ b/src/xdsl/dialects/builtin.py @@ -125,11 +125,14 @@ class IndexType(ParametrizedAttribute): name = "index" +_IntegerAttrTyp = TypeVar("_IntegerAttrTyp", bound=IntegerType | IndexType) + + @irdl_attr_definition -class IntegerAttr(ParametrizedAttribute): +class IntegerAttr(Generic[_IntegerAttrTyp], ParametrizedAttribute): name = "integer" value: ParameterDef[IntAttr] - typ: ParameterDef[IntegerType | IndexType] + typ: ParameterDef[_IntegerAttrTyp] @staticmethod @builder @@ -228,12 +231,15 @@ def from_type_list(types: List[Attribute]) -> TupleType: return TupleType([ArrayAttr.from_list(types)]) +_VectorTypeElems = TypeVar("_VectorTypeElems", bound=Attribute) + + @irdl_attr_definition -class VectorType(ParametrizedAttribute): +class VectorType(Generic[_VectorTypeElems], ParametrizedAttribute): name = "vector" shape: ParameterDef[ArrayAttr[IntegerAttr]] - element_type: ParameterDef[Attribute] + element_type: ParameterDef[_VectorTypeElems] def get_num_dims(self) -> int: return len(self.shape.data) @@ -244,8 +250,9 @@ def get_shape(self) -> List[int]: @staticmethod @builder def from_type_and_list( - referenced_type: Attribute, - shape: Optional[List[int | IntegerAttr]] = None) -> VectorType: + referenced_type: _VectorTypeElems, + shape: Optional[List[int | IntegerAttr]] = None + ) -> VectorType[_VectorTypeElems]: if shape is None: shape = [1] return VectorType([ @@ -256,19 +263,22 @@ def from_type_and_list( @staticmethod @builder def from_params( - referenced_type: Attribute, + referenced_type: _VectorTypeElems, shape: ArrayAttr[IntegerAttr] = ArrayAttr.from_list( [IntegerAttr.from_int_and_width(1, 64)]) - ) -> VectorType: + ) -> VectorType[_VectorTypeElems]: return VectorType([shape, referenced_type]) +_VectorTypeElems = TypeVar("_VectorTypeElems", bound=Attribute) + + @irdl_attr_definition -class TensorType(ParametrizedAttribute): +class TensorType(Generic[_VectorTypeElems], ParametrizedAttribute): name = "tensor" shape: ParameterDef[ArrayAttr[IntegerAttr]] - element_type: ParameterDef[Attribute] + element_type: ParameterDef[_VectorTypeElems] def get_num_dims(self) -> int: return len(self.shape.data) @@ -279,8 +289,9 @@ def get_shape(self) -> List[int]: @staticmethod @builder def from_type_and_list( - referenced_type: Attribute, - shape: Optional[Sequence[int | IntegerAttr]] = None) -> TensorType: + referenced_type: _VectorTypeElems, + shape: Optional[Sequence[int | IntegerAttr]] = None + ) -> TensorType[_VectorTypeElems]: if shape is None: shape = [1] return TensorType([ @@ -291,10 +302,10 @@ def from_type_and_list( @staticmethod @builder def from_params( - referenced_type: Attribute, + referenced_type: _VectorTypeElems, shape: ArrayAttr[IntegerAttr] = ArrayAttr.from_list( [IntegerAttr.from_int_and_width(1, 64)]) - ) -> TensorType: + ) -> TensorType[_VectorTypeElems]: return TensorType([shape, referenced_type]) From 1b2c7c84b11290cb885a40662e5a45158721059a Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Thu, 12 May 2022 18:02:02 -0500 Subject: [PATCH 31/36] Cleanup a bit cf --- src/xdsl/dialects/cf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/xdsl/dialects/cf.py b/src/xdsl/dialects/cf.py index 3fe91f3348..dfbc470596 100644 --- a/src/xdsl/dialects/cf.py +++ b/src/xdsl/dialects/cf.py @@ -2,7 +2,7 @@ from xdsl.ir import * from xdsl.irdl import * from xdsl.util import * -from xdsl.dialects.builtin import IntegerType, Float32Type, IntegerAttr, FlatSymbolRefAttr +from xdsl.dialects.builtin import IntegerType @dataclass From fecbc7179ef1c9699ed3592f2326622ee025d089 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Thu, 12 May 2022 18:03:50 -0500 Subject: [PATCH 32/36] Cleanup a bit func dialect --- src/xdsl/dialects/func.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/xdsl/dialects/func.py b/src/xdsl/dialects/func.py index 5f15942551..25366cd77f 100644 --- a/src/xdsl/dialects/func.py +++ b/src/xdsl/dialects/func.py @@ -26,10 +26,9 @@ class FuncOp(Operation): sym_visibility = AttributeDef(StringAttr) @staticmethod - def from_callable( - name: str, input_types: List[Attribute], - return_types: List[Attribute], - func: Callable[[BlockArgument, ...], List[Operation]]) -> FuncOp: + def from_callable(name: str, input_types: List[Attribute], + return_types: List[Attribute], + func: Block.BlockCallback) -> FuncOp: type_attr = FunctionType.from_lists(input_types, return_types) op = FuncOp.build(attributes={ "sym_name": name, From 3383ffdc2074f76d7ec55ad86cd0bb670c740b01 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Thu, 12 May 2022 18:10:04 -0500 Subject: [PATCH 33/36] Make MemRefTypeElement generic --- src/xdsl/dialects/memref.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/xdsl/dialects/memref.py b/src/xdsl/dialects/memref.py index ae6f4a7628..b1b0ba43d0 100644 --- a/src/xdsl/dialects/memref.py +++ b/src/xdsl/dialects/memref.py @@ -22,12 +22,15 @@ def __post_init__(self): self.ctx.register_op(Global) +_MemRefTypeElement = TypeVar("_MemRefTypeElement", bound=Attribute) + + @irdl_attr_definition -class MemRefType(ParametrizedAttribute): +class MemRefType(Generic[_MemRefTypeElement], ParametrizedAttribute): name = "memref" shape: ParameterDef[ArrayAttr[IntegerAttr]] - element_type: ParameterDef[Attribute] + element_type: ParameterDef[_MemRefTypeElement] def get_num_dims(self) -> int: return len(self.shape.data) @@ -38,8 +41,9 @@ def get_shape(self) -> List[int]: @staticmethod @builder def from_type_and_list( - referenced_type: Attribute, - shape: Optional[List[int | IntegerAttr]] = None) -> MemRefType: + referenced_type: _MemRefTypeElement, + shape: Optional[List[int | IntegerAttr]] = None + ) -> MemRefType[_MemRefTypeElement]: if shape is None: shape = [1] return MemRefType([ @@ -50,10 +54,10 @@ def from_type_and_list( @staticmethod @builder def from_params( - referenced_type: Attribute, + referenced_type: _MemRefTypeElement, shape: ArrayAttr[IntegerAttr] = ArrayAttr.from_list( [IntegerAttr.from_int_and_width(1, 64)]) - ) -> MemRefType: + ) -> MemRefType[_MemRefTypeElement]: return MemRefType([shape, referenced_type]) From a17b7c256ed9539dcca4256ab8778f1e1a63b96e Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Thu, 12 May 2022 18:14:11 -0500 Subject: [PATCH 34/36] Cleanup a bit scf dialect --- src/xdsl/dialects/builtin.py | 82 +++++++++++++++++++----------- src/xdsl/dialects/memref.py | 14 ++--- src/xdsl/dialects/scf.py | 23 +++++---- tests/attribute_definition_test.py | 10 ++-- 4 files changed, 76 insertions(+), 53 deletions(-) diff --git a/src/xdsl/dialects/builtin.py b/src/xdsl/dialects/builtin.py index 8926c8fe7d..6951135e16 100644 --- a/src/xdsl/dialects/builtin.py +++ b/src/xdsl/dialects/builtin.py @@ -3,6 +3,7 @@ from xdsl.irdl import * from xdsl.ir import * +from typing import TypeAlias if TYPE_CHECKING: from xdsl.parser import Parser @@ -125,7 +126,9 @@ class IndexType(ParametrizedAttribute): name = "index" -_IntegerAttrTyp = TypeVar("_IntegerAttrTyp", bound=IntegerType | IndexType) +_IntegerAttrTyp = TypeVar("_IntegerAttrTyp", + bound=IntegerType | IndexType, + covariant=True) @irdl_attr_definition @@ -136,25 +139,30 @@ class IntegerAttr(Generic[_IntegerAttrTyp], ParametrizedAttribute): @staticmethod @builder - def from_int_and_width(value: int, width: int) -> IntegerAttr: + def from_int_and_width(value: int, width: int) -> IntegerAttr[IntegerType]: return IntegerAttr( [IntAttr.from_int(value), IntegerType.from_width(width)]) @staticmethod @builder - def from_index_int_value(value: int) -> IntegerAttr: + def from_index_int_value(value: int) -> IntegerAttr[IndexType]: return IntegerAttr([IntAttr.from_int(value), IndexType()]) @staticmethod @builder - def from_params(value: int | IntAttr, typ: int | Attribute) -> IntegerAttr: + def from_params( + value: int | IntAttr, + typ: int | Attribute) -> IntegerAttr[IntegerType | IndexType]: value = IntAttr.build(value) if not isinstance(typ, IndexType): typ = IntegerType.build(typ) return IntegerAttr([value, typ]) +AnyIntegerAttr: TypeAlias = IntegerAttr[IntegerType | IndexType] + + @dataclass class ArrayOfConstraint(AttrConstraint): """ @@ -173,21 +181,24 @@ def verify(self, attr: Attribute) -> None: self.elem_constr.verify(e) +_ArrayAttrT = TypeVar("_ArrayAttrT", bound=Attribute, covariant=True) + + @irdl_attr_definition -class ArrayAttr(GenericData[List[A]]): +class ArrayAttr(GenericData[List[_ArrayAttrT]]): name = "array" @staticmethod - def parse_parameter(parser: Parser) -> List[A]: + def parse_parameter(parser: Parser) -> List[_ArrayAttrT]: parser.parse_char("[") data = parser.parse_list(parser.parse_optional_attribute) parser.parse_char("]") # the type system can't ensure that the elements are of type A # and not just of type Attribute, therefore, the following cast - return cast(List[A], data) + return cast(List[_ArrayAttrT], data) @staticmethod - def print_parameter(data: List[A], printer: Printer) -> None: + def print_parameter(data: List[_ArrayAttrT], printer: Printer) -> None: printer.print_string("[") printer.print_list(data, printer.print_attribute) printer.print_string("]") @@ -215,15 +226,18 @@ def verify(self) -> None: @staticmethod @builder - def from_list(data: List[A]) -> ArrayAttr[A]: + def from_list(data: List[_ArrayAttrT]) -> ArrayAttr[_ArrayAttrT]: return ArrayAttr(data) +AnyArrayAttr: TypeAlias = ArrayAttr[Attribute] + + @irdl_attr_definition class TupleType(ParametrizedAttribute): name = "tuple" - types: ParameterDef[ArrayAttr] + types: ParameterDef[ArrayAttr[Attribute]] @staticmethod @builder @@ -238,7 +252,7 @@ def from_type_list(types: List[Attribute]) -> TupleType: class VectorType(Generic[_VectorTypeElems], ParametrizedAttribute): name = "vector" - shape: ParameterDef[ArrayAttr[IntegerAttr]] + shape: ParameterDef[ArrayAttr[AnyIntegerAttr]] element_type: ParameterDef[_VectorTypeElems] def get_num_dims(self) -> int: @@ -251,12 +265,13 @@ def get_shape(self) -> List[int]: @builder def from_type_and_list( referenced_type: _VectorTypeElems, - shape: Optional[List[int | IntegerAttr]] = None + shape: Optional[List[int | IntegerAttr[IndexType]]] = None ) -> VectorType[_VectorTypeElems]: if shape is None: shape = [1] return VectorType([ - ArrayAttr.from_list([IntegerAttr.build(d) for d in shape]), + ArrayAttr.from_list( + [IntegerAttr[IntegerType].build(d) for d in shape]), referenced_type ]) @@ -264,21 +279,23 @@ def from_type_and_list( @builder def from_params( referenced_type: _VectorTypeElems, - shape: ArrayAttr[IntegerAttr] = ArrayAttr.from_list( + shape: ArrayAttr[IntegerAttr[IntegerType]] = ArrayAttr.from_list( [IntegerAttr.from_int_and_width(1, 64)]) ) -> VectorType[_VectorTypeElems]: return VectorType([shape, referenced_type]) -_VectorTypeElems = TypeVar("_VectorTypeElems", bound=Attribute) +AnyVectorType: TypeAlias = VectorType[Attribute] + +_TensorTypeElems = TypeVar("_TensorTypeElems", bound=Attribute, covariant=True) @irdl_attr_definition -class TensorType(Generic[_VectorTypeElems], ParametrizedAttribute): +class TensorType(Generic[_TensorTypeElems], ParametrizedAttribute): name = "tensor" - shape: ParameterDef[ArrayAttr[IntegerAttr]] - element_type: ParameterDef[_VectorTypeElems] + shape: ParameterDef[ArrayAttr[AnyIntegerAttr]] + element_type: ParameterDef[_TensorTypeElems] def get_num_dims(self) -> int: return len(self.shape.data) @@ -289,13 +306,14 @@ def get_shape(self) -> List[int]: @staticmethod @builder def from_type_and_list( - referenced_type: _VectorTypeElems, - shape: Optional[Sequence[int | IntegerAttr]] = None - ) -> TensorType[_VectorTypeElems]: + referenced_type: _TensorTypeElems, + shape: Optional[Sequence[int | IntegerAttr[IndexType]]] = None + ) -> TensorType[_TensorTypeElems]: if shape is None: shape = [1] return TensorType([ - ArrayAttr.from_list([IntegerAttr.build(d) for d in shape]), + ArrayAttr.from_list( + [IntegerAttr[IndexType].build(d) for d in shape]), referenced_type ]) @@ -303,23 +321,26 @@ def from_type_and_list( @builder def from_params( referenced_type: _VectorTypeElems, - shape: ArrayAttr[IntegerAttr] = ArrayAttr.from_list( + shape: AnyArrayAttr = AnyArrayAttr.from_list( [IntegerAttr.from_int_and_width(1, 64)]) ) -> TensorType[_VectorTypeElems]: return TensorType([shape, referenced_type]) +AnyTensorType: TypeAlias = TensorType[Attribute] + + @irdl_attr_definition class DenseIntOrFPElementsAttr(ParametrizedAttribute): name = "dense" # TODO add support for FPElements - type: ParameterDef[VectorType | TensorType] + type: ParameterDef[AnyVectorType | AnyTensorType] # TODO add support for multi-dimensional data - data: ParameterDef[ArrayAttr[IntegerAttr]] + data: ParameterDef[AnyArrayAttr] @staticmethod @builder - def from_int_list(type: VectorType | TensorType, data: List[int], + def from_int_list(type: AnyVectorType | AnyTensorType, data: List[int], bitwidth: int) -> DenseIntOrFPElementsAttr: data_attr = [IntegerAttr.from_int_and_width(d, bitwidth) for d in data] return DenseIntOrFPElementsAttr([type, ArrayAttr.from_list(data_attr)]) @@ -327,8 +348,9 @@ def from_int_list(type: VectorType | TensorType, data: List[int], @staticmethod @builder def from_list( - type: VectorType | TensorType, - data: List[int] | List[IntegerAttr]) -> DenseIntOrFPElementsAttr: + type: AnyVectorType | AnyTensorType, + data: List[int] | List[AnyIntegerAttr] + ) -> DenseIntOrFPElementsAttr: element_type = type.element_type # Only use the element_type if the passed data is an int, o/w use the IntegerAttr data_attr = [(IntegerAttr.from_params(d, element_type) if isinstance( @@ -340,7 +362,7 @@ def from_list( def vector_from_list( data: List[int], typ: IntegerType | IndexType) -> DenseIntOrFPElementsAttr: - t = VectorType.from_type_and_list(typ, [len(data)]) + t = AnyVectorType.from_type_and_list(typ, [len(data)]) return DenseIntOrFPElementsAttr.from_list(t, data) @staticmethod @@ -348,7 +370,7 @@ def vector_from_list( def tensor_from_list( data: List[int], typ: IntegerType | IndexType) -> DenseIntOrFPElementsAttr: - t = TensorType.from_type_and_list(typ, [len(data)]) + t = AnyTensorType.from_type_and_list(typ, [len(data)]) return DenseIntOrFPElementsAttr.from_list(t, data) diff --git a/src/xdsl/dialects/memref.py b/src/xdsl/dialects/memref.py index b1b0ba43d0..3a3c309bdb 100644 --- a/src/xdsl/dialects/memref.py +++ b/src/xdsl/dialects/memref.py @@ -29,7 +29,7 @@ def __post_init__(self): class MemRefType(Generic[_MemRefTypeElement], ParametrizedAttribute): name = "memref" - shape: ParameterDef[ArrayAttr[IntegerAttr]] + shape: ParameterDef[ArrayAttr[AnyIntegerAttr]] element_type: ParameterDef[_MemRefTypeElement] def get_num_dims(self) -> int: @@ -42,20 +42,20 @@ def get_shape(self) -> List[int]: @builder def from_type_and_list( referenced_type: _MemRefTypeElement, - shape: Optional[List[int | IntegerAttr]] = None + shape: Optional[List[int | AnyIntegerAttr]] = None ) -> MemRefType[_MemRefTypeElement]: if shape is None: shape = [1] return MemRefType([ - ArrayAttr.from_list([IntegerAttr.build(d) for d in shape]), - referenced_type + ArrayAttr[AnyIntegerAttr].from_list( + [IntegerAttr.build(d) for d in shape]), referenced_type ]) @staticmethod @builder def from_params( referenced_type: _MemRefTypeElement, - shape: ArrayAttr[IntegerAttr] = ArrayAttr.from_list( + shape: ArrayAttr[AnyIntegerAttr] = ArrayAttr.from_list( [IntegerAttr.from_int_and_width(1, 64)]) ) -> MemRefType[_MemRefTypeElement]: return MemRefType([shape, referenced_type]) @@ -124,7 +124,7 @@ class Alloc(Operation): @staticmethod def get(return_type: Attribute, alignment: int, - shape: Optional[List[int | IntegerAttr]] = None) -> Alloc: + shape: Optional[List[int | AnyIntegerAttr]] = None) -> Alloc: if shape is None: shape = [1] return Alloc.build( @@ -152,7 +152,7 @@ class Alloca(Operation): @staticmethod def get(return_type: Attribute, alignment: int, - shape: Optional[List[int | IntegerAttr]] = None) -> Alloca: + shape: Optional[List[int | AnyIntegerAttr]] = None) -> Alloca: if shape is None: shape = [1] return Alloca.build( diff --git a/src/xdsl/dialects/scf.py b/src/xdsl/dialects/scf.py index 48bd7146b9..958178a491 100644 --- a/src/xdsl/dialects/scf.py +++ b/src/xdsl/dialects/scf.py @@ -24,9 +24,9 @@ class If(Operation): false_region = RegionDef() @staticmethod - def get(cond: Union[Operation, SSAValue], return_types: List[Attribute], - true_region: Union[Region, List[Block], List[Operation]], - false_region: Union[Region, List[Block], List[Operation]]): + def get(cond: SSAValue | Operation, return_types: List[Attribute], + true_region: Region | List[Block] | List[Operation], + false_region: Region | List[Block] | List[Operation]): return If.build(operands=[cond], result_types=[return_types], regions=[true_region, false_region]) @@ -38,8 +38,9 @@ class Yield(Operation): arguments = VarOperandDef(AnyAttr()) @staticmethod - def get(*operands: Union[Operation, SSAValue]) -> Yield: - return Yield.create(operands=[operand for operand in operands]) + def get(*operands: SSAValue | Operation) -> Yield: + return Yield.create( + operands=[SSAValue.get(operand) for operand in operands]) @irdl_op_definition @@ -49,8 +50,8 @@ class Condition(Operation): arguments = VarOperandDef(AnyAttr()) @staticmethod - def get(cond: Union[Operation, SSAValue], - *output_ops: Union[Operation, SSAValue]) -> Condition: + def get(cond: SSAValue | Operation, + *output_ops: SSAValue | Operation) -> Condition: return Condition.build( operands=[cond, [output for output in output_ops]]) @@ -79,10 +80,10 @@ def verify_(self): ) @staticmethod - def get(operands: List[Union[Operation, - SSAValue]], result_types: List[Attribute], - before: Union[Region, List[Operation], List[Block]], - after: Union[Region, List[Operation], List[Block]]) -> While: + def get(operands: List[SSAValue | Operation], + result_types: List[Attribute], + before: Region | List[Operation] | List[Block], + after: Region | List[Operation] | List[Block]) -> While: op = While.build(operands=operands, result_types=result_types, regions=[before, after]) diff --git a/tests/attribute_definition_test.py b/tests/attribute_definition_test.py index d8272ecff4..18eff5a999 100644 --- a/tests/attribute_definition_test.py +++ b/tests/attribute_definition_test.py @@ -298,7 +298,7 @@ class ParamWrapperAttr(Generic[_T], ParametrizedAttribute): def test_typevar_attribute_int(): """Test the verifier of a generic attribute.""" - attr = ParamWrapperAttr([IntData(42)]) + attr = ParamWrapperAttr[IntData]([IntData(42)]) stream = StringIO() p = Printer(stream=stream) p.print_attribute(attr) @@ -307,7 +307,7 @@ def test_typevar_attribute_int(): def test_typevar_attribute_bool(): """Test the verifier of a generic attribute.""" - attr = ParamWrapperAttr([BoolData(True)]) + attr = ParamWrapperAttr[BoolData]([BoolData(True)]) stream = StringIO() p = Printer(stream=stream) p.print_attribute(attr) @@ -363,7 +363,7 @@ def test_nested_generic_constraint(): Test the verifier of an attribute with a generic constraint used in a parametric constraint. """ - attr = NestedParamWrapperAttr([ParamWrapperAttr([IntData(42)])]) + attr = NestedParamWrapperAttr[IntData]([ParamWrapperAttr([IntData(42)])]) stream = StringIO() p = Printer(stream=stream) p.print_attribute(attr) @@ -471,7 +471,7 @@ class DataListAttr(AttrConstraint): elem_constr: AttrConstraint def verify(self, attr: Attribute) -> None: - attr = cast(ListData, attr) + attr = cast(ListData[Any], attr) for e in attr.data: self.elem_constr.verify(e) @@ -570,7 +570,7 @@ def test_generic_data_wrapper_verifier_failure(): class ListDataNoGenericsWrapper(ParametrizedAttribute): name = "list_no_generics_wrapper" - val: ParameterDef[ListData] + val: ParameterDef[ListData[Any]] def test_generic_data_no_generics_wrapper_verifier(): From fc328b7748bd9b0bbe9e425c3305d1f945a207d4 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Sat, 4 Jun 2022 19:17:56 +0200 Subject: [PATCH 35/36] Support PyRight strict typing --- src/xdsl/irdl.py | 3 +-- tests/attribute_definition_test.py | 7 +++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/xdsl/irdl.py b/src/xdsl/irdl.py index ccbfc1de8c..e6c3a54a18 100644 --- a/src/xdsl/irdl.py +++ b/src/xdsl/irdl.py @@ -279,12 +279,11 @@ def irdl_to_attr_constraint( type_var_mapping=type_var_mapping) for _, param in origin_parameters ] - print(origin_constraints) return ParamAttrConstraint(origin, origin_constraints) # Union case # This is a coercion for an `AnyOf` constraint. - if origin == types.UnionType: + if origin == types.UnionType or origin == Union: constraints: List[AttrConstraint] = [] for arg in get_args(irdl): # We should not try to convert IRDL annotations, which do not diff --git a/tests/attribute_definition_test.py b/tests/attribute_definition_test.py index 18eff5a999..332289c8c1 100644 --- a/tests/attribute_definition_test.py +++ b/tests/attribute_definition_test.py @@ -5,7 +5,7 @@ from __future__ import annotations from dataclasses import dataclass from io import StringIO -from typing import Any, List, TypeVar, cast, Annotated, Generic +from typing import Any, List, TypeVar, cast, Annotated, Generic, TypeAlias import pytest @@ -513,6 +513,9 @@ def verify(self) -> None: f"{idx} is of type {type(val)}.") +AnyListData: TypeAlias = ListData[Attribute] + + def test_generic_data_verifier(): """ Test that a GenericData can be created. @@ -570,7 +573,7 @@ def test_generic_data_wrapper_verifier_failure(): class ListDataNoGenericsWrapper(ParametrizedAttribute): name = "list_no_generics_wrapper" - val: ParameterDef[ListData[Any]] + val: ParameterDef[AnyListData] def test_generic_data_no_generics_wrapper_verifier(): From 3de15ba90e065e54d23ce92118b794099f46a129 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Sat, 4 Jun 2022 19:28:41 +0200 Subject: [PATCH 36/36] Fix rebase --- src/xdsl/printer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/xdsl/printer.py b/src/xdsl/printer.py index 55a5f11847..94c037bf65 100644 --- a/src/xdsl/printer.py +++ b/src/xdsl/printer.py @@ -55,9 +55,6 @@ def print_string(self, text) -> None: self._current_column += len(lines[-1]) print(text, end='', file=self.stream) - def print_string(self, string: str) -> None: - self.print(string) - def _add_message_on_next_line(self, message: str, begin_pos: int, end_pos: int): """Add a message that will be displayed on the next line."""