Skip to content

Commit

Permalink
updating naming of Attribute Types
Browse files Browse the repository at this point in the history
  • Loading branch information
Fergtic committed Mar 12, 2023
1 parent 447e40c commit bc05f99
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 29 deletions.
55 changes: 28 additions & 27 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from xdsl.ir import (Block, Data, MLContext, MLIRType, ParametrizedAttribute,
Operation, Region, Attribute, Dialect, SSAValue,
AttributeTCov, AttributeT)
AttributeCovT, AttributeInvT)

from xdsl.irdl import (AllOf, OpAttr, VarOpResult, VarOperand, VarRegion,
irdl_attr_definition, attr_constr_coercion,
Expand Down Expand Up @@ -42,21 +42,21 @@ def verify(self, attr: Attribute) -> None:


@irdl_attr_definition
class ArrayAttr(GenericData[tuple[AttributeTCov, ...]]):
class ArrayAttr(GenericData[tuple[AttributeCovT, ...]]):
name: str = "array"

def __init__(self: ArrayAttr[AttributeTCov],
param: Iterable[AttributeTCov]) -> None:
def __init__(self: ArrayAttr[AttributeCovT],
param: Iterable[AttributeCovT]) -> None:
super().__init__(tuple(param))

@staticmethod
def parse_parameter(parser: BaseParser) -> tuple[AttributeTCov]:
def parse_parameter(parser: BaseParser) -> tuple[AttributeCovT]:
parser.parse_char("[")
data = parser.parse_list_of(parser.try_parse_attribute,
"Expected attribute")
parser.parse_char("]")
# the type system can't ensure that the elements are of type _ArrayAttrT
result = cast(tuple[AttributeTCov], tuple(data))
result = cast(tuple[AttributeCovT], tuple(data))
return result

def print_parameter(self, printer: Printer) -> None:
Expand Down Expand Up @@ -87,8 +87,8 @@ def verify(self) -> None:

@staticmethod
@deprecated_constructor
def from_list(data: List[AttributeTCov]) -> ArrayAttr[AttributeTCov]:
return ArrayAttr[AttributeTCov](data)
def from_list(data: List[AttributeCovT]) -> ArrayAttr[AttributeCovT]:
return ArrayAttr[AttributeCovT](data)

def __len__(self):
return len(self.data)
Expand Down Expand Up @@ -472,11 +472,11 @@ def from_type_list(types: List[Attribute]) -> TupleType:


@irdl_attr_definition
class VectorType(Generic[AttributeT], ParametrizedAttribute, MLIRType):
class VectorType(Generic[AttributeInvT], ParametrizedAttribute, MLIRType):
name = "vector"

shape: ParameterDef[ArrayAttr[AnyIntegerAttr]]
element_type: ParameterDef[AttributeT]
element_type: ParameterDef[AttributeInvT]

def get_num_dims(self) -> int:
return len(self.shape.data)
Expand All @@ -486,9 +486,9 @@ def get_shape(self) -> List[int]:

@staticmethod
def from_element_type_and_shape(
referenced_type: AttributeT,
shape: List[int | IntegerAttr[IndexType]]
) -> VectorType[AttributeT]:
referenced_type: AttributeInvT,
shape: List[int | IntegerAttr[IndexType]]
) -> VectorType[AttributeInvT]:
return VectorType([
ArrayAttr([
IntegerAttr[IntegerType].from_index_int_value(d) if isinstance(
Expand All @@ -498,22 +498,22 @@ def from_element_type_and_shape(

@staticmethod
def from_params(
referenced_type: AttributeT,
referenced_type: AttributeInvT,
shape: ArrayAttr[IntegerAttr[IntegerType]] = ArrayAttr(
[IntegerAttr.from_int_and_width(1, 64)])
) -> VectorType[AttributeT]:
) -> VectorType[AttributeInvT]:
return VectorType([shape, referenced_type])


AnyVectorType: TypeAlias = VectorType[Attribute]


@irdl_attr_definition
class TensorType(Generic[AttributeTCov], ParametrizedAttribute, MLIRType):
class TensorType(Generic[AttributeCovT], ParametrizedAttribute, MLIRType):
name = "tensor"

shape: ParameterDef[ArrayAttr[AnyIntegerAttr]]
element_type: ParameterDef[AttributeTCov]
element_type: ParameterDef[AttributeCovT]

def get_num_dims(self) -> int:
return len(self.shape.data)
Expand All @@ -523,9 +523,9 @@ def get_shape(self) -> List[int]:

@staticmethod
def from_type_and_list(
referenced_type: AttributeT,
referenced_type: AttributeInvT,
shape: Sequence[int | IntegerAttr[IndexType]] | None = None
) -> TensorType[AttributeT]:
) -> TensorType[AttributeInvT]:
if shape is None:
shape = [1]
return TensorType([
Expand All @@ -537,26 +537,27 @@ def from_type_and_list(

@staticmethod
def from_params(
referenced_type: AttributeT,
referenced_type: AttributeInvT,
shape: AnyArrayAttr = AnyArrayAttr(
[IntegerAttr.from_int_and_width(1, 64)])
) -> TensorType[AttributeT]:
) -> TensorType[AttributeInvT]:
return TensorType([shape, referenced_type])


AnyTensorType: TypeAlias = TensorType[Attribute]


@irdl_attr_definition
class UnrankedTensorType(Generic[AttributeTCov], ParametrizedAttribute,
class UnrankedTensorType(Generic[AttributeCovT], ParametrizedAttribute,
MLIRType):
name = "unranked_tensor"

element_type: ParameterDef[AttributeTCov]
element_type: ParameterDef[AttributeCovT]

@staticmethod
def from_type(
referenced_type: AttributeT) -> UnrankedTensorType[AttributeT]:
referenced_type: AttributeInvT
) -> UnrankedTensorType[AttributeInvT]:
return UnrankedTensorType([referenced_type])


Expand All @@ -580,9 +581,9 @@ def verify(self, attr: Attribute) -> None:
self.elem_constr.verify(attr)


VectorOrTensorOf: TypeAlias = (VectorType[AttributeT]
| TensorType[AttributeT]
| UnrankedTensorType[AttributeT])
VectorOrTensorOf: TypeAlias = (VectorType[AttributeInvT]
| TensorType[AttributeInvT]
| UnrankedTensorType[AttributeInvT])


@dataclass
Expand Down
4 changes: 2 additions & 2 deletions xdsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,8 @@ def __str__(self) -> str:

_D = TypeVar("_D", bound="Data[Any]")

AttributeTCov = TypeVar("AttributeTCov", bound=Attribute, covariant=True)
AttributeT = TypeVar("AttributeT", bound=Attribute)
AttributeCovT = TypeVar("AttributeCovT", bound=Attribute, covariant=True)
AttributeInvT = TypeVar("AttributeInvT", bound=Attribute)


@dataclass(frozen=True)
Expand Down

0 comments on commit bc05f99

Please sign in to comment.