diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index 024dc257ba..34f7c08207 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -7,7 +7,7 @@ from xdsl.ir import (Block, Data, MLContext, MLIRType, ParametrizedAttribute, Operation, Region, Attribute, Dialect, SSAValue, - AttributeTCov, AttributeT) + AttributeCovT, AttributeInvT) from xdsl.irdl import (AllOf, OpAttr, VarOpResult, VarOperand, VarRegion, irdl_attr_definition, attr_constr_coercion, @@ -42,21 +42,21 @@ def verify(self, attr: Attribute) -> None: @irdl_attr_definition -class ArrayAttr(GenericData[tuple[AttributeTCov, ...]]): +class ArrayAttr(GenericData[tuple[AttributeCovT, ...]]): name: str = "array" - def __init__(self: ArrayAttr[AttributeTCov], - param: Iterable[AttributeTCov]) -> None: + def __init__(self: ArrayAttr[AttributeCovT], + param: Iterable[AttributeCovT]) -> None: super().__init__(tuple(param)) @staticmethod - def parse_parameter(parser: BaseParser) -> tuple[AttributeTCov]: + def parse_parameter(parser: BaseParser) -> tuple[AttributeCovT]: parser.parse_char("[") data = parser.parse_list_of(parser.try_parse_attribute, "Expected attribute") parser.parse_char("]") # the type system can't ensure that the elements are of type _ArrayAttrT - result = cast(tuple[AttributeTCov], tuple(data)) + result = cast(tuple[AttributeCovT], tuple(data)) return result def print_parameter(self, printer: Printer) -> None: @@ -87,8 +87,8 @@ def verify(self) -> None: @staticmethod @deprecated_constructor - def from_list(data: List[AttributeTCov]) -> ArrayAttr[AttributeTCov]: - return ArrayAttr[AttributeTCov](data) + def from_list(data: List[AttributeCovT]) -> ArrayAttr[AttributeCovT]: + return ArrayAttr[AttributeCovT](data) def __len__(self): return len(self.data) @@ -472,11 +472,11 @@ def from_type_list(types: List[Attribute]) -> TupleType: @irdl_attr_definition -class VectorType(Generic[AttributeT], ParametrizedAttribute, MLIRType): +class VectorType(Generic[AttributeInvT], ParametrizedAttribute, MLIRType): name = "vector" shape: ParameterDef[ArrayAttr[AnyIntegerAttr]] - element_type: ParameterDef[AttributeT] + element_type: ParameterDef[AttributeInvT] def get_num_dims(self) -> int: return len(self.shape.data) @@ -486,9 +486,9 @@ def get_shape(self) -> List[int]: @staticmethod def from_element_type_and_shape( - referenced_type: AttributeT, - shape: List[int | IntegerAttr[IndexType]] - ) -> VectorType[AttributeT]: + referenced_type: AttributeInvT, + shape: List[int | IntegerAttr[IndexType]] + ) -> VectorType[AttributeInvT]: return VectorType([ ArrayAttr([ IntegerAttr[IntegerType].from_index_int_value(d) if isinstance( @@ -498,10 +498,10 @@ def from_element_type_and_shape( @staticmethod def from_params( - referenced_type: AttributeT, + referenced_type: AttributeInvT, shape: ArrayAttr[IntegerAttr[IntegerType]] = ArrayAttr( [IntegerAttr.from_int_and_width(1, 64)]) - ) -> VectorType[AttributeT]: + ) -> VectorType[AttributeInvT]: return VectorType([shape, referenced_type]) @@ -509,11 +509,11 @@ def from_params( @irdl_attr_definition -class TensorType(Generic[AttributeTCov], ParametrizedAttribute, MLIRType): +class TensorType(Generic[AttributeCovT], ParametrizedAttribute, MLIRType): name = "tensor" shape: ParameterDef[ArrayAttr[AnyIntegerAttr]] - element_type: ParameterDef[AttributeTCov] + element_type: ParameterDef[AttributeCovT] def get_num_dims(self) -> int: return len(self.shape.data) @@ -523,9 +523,9 @@ def get_shape(self) -> List[int]: @staticmethod def from_type_and_list( - referenced_type: AttributeT, + referenced_type: AttributeInvT, shape: Sequence[int | IntegerAttr[IndexType]] | None = None - ) -> TensorType[AttributeT]: + ) -> TensorType[AttributeInvT]: if shape is None: shape = [1] return TensorType([ @@ -537,10 +537,10 @@ def from_type_and_list( @staticmethod def from_params( - referenced_type: AttributeT, + referenced_type: AttributeInvT, shape: AnyArrayAttr = AnyArrayAttr( [IntegerAttr.from_int_and_width(1, 64)]) - ) -> TensorType[AttributeT]: + ) -> TensorType[AttributeInvT]: return TensorType([shape, referenced_type]) @@ -548,15 +548,16 @@ def from_params( @irdl_attr_definition -class UnrankedTensorType(Generic[AttributeTCov], ParametrizedAttribute, +class UnrankedTensorType(Generic[AttributeCovT], ParametrizedAttribute, MLIRType): name = "unranked_tensor" - element_type: ParameterDef[AttributeTCov] + element_type: ParameterDef[AttributeCovT] @staticmethod def from_type( - referenced_type: AttributeT) -> UnrankedTensorType[AttributeT]: + referenced_type: AttributeInvT + ) -> UnrankedTensorType[AttributeInvT]: return UnrankedTensorType([referenced_type]) @@ -580,9 +581,9 @@ def verify(self, attr: Attribute) -> None: self.elem_constr.verify(attr) -VectorOrTensorOf: TypeAlias = (VectorType[AttributeT] - | TensorType[AttributeT] - | UnrankedTensorType[AttributeT]) +VectorOrTensorOf: TypeAlias = (VectorType[AttributeInvT] + | TensorType[AttributeInvT] + | UnrankedTensorType[AttributeInvT]) @dataclass diff --git a/xdsl/ir.py b/xdsl/ir.py index 2f8c7eaede..ec10591673 100644 --- a/xdsl/ir.py +++ b/xdsl/ir.py @@ -317,8 +317,8 @@ def __str__(self) -> str: _D = TypeVar("_D", bound="Data[Any]") -AttributeTCov = TypeVar("AttributeTCov", bound=Attribute, covariant=True) -AttributeT = TypeVar("AttributeT", bound=Attribute) +AttributeCovT = TypeVar("AttributeCovT", bound=Attribute, covariant=True) +AttributeInvT = TypeVar("AttributeInvT", bound=Attribute) @dataclass(frozen=True)