Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (bufferization) Simplify tensor/memref constraint #3475

Merged
merged 3 commits into from
Nov 19, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 45 additions & 75 deletions xdsl/dialects/bufferization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, ClassVar

from xdsl.dialects.builtin import (
Expand All @@ -17,9 +18,9 @@
)
from xdsl.ir import Attribute, Dialect, Operation, SSAValue
from xdsl.irdl import (
AnyOf,
AttrSizedOperandSegments,
ConstraintContext,
GenericAttrConstraint,
IRDLOperation,
VarConstraint,
irdl_op_definition,
Expand All @@ -33,51 +34,41 @@
from xdsl.utils.hints import isa


class TensorMemrefInferenceConstraint(VarConstraint[Attribute]):
@dataclass(frozen=True)
class TensorFromMemrefConstraint(
GenericAttrConstraint[TensorType[Attribute] | UnrankedTensorType[Attribute]]
):
"""
Constraint to infer tensor shapes from memref shapes, inferring ranked tensor from ranked memref
(and unranked from unranked, respectively).

Verification checks that attributes of the same variable name are either all ranked or all unranked,
and checks for matching element type, shape (ranked only), as well as verifying sub constraints.
Converts an input memref constraint to the corresponding tensor constraint, i.e. the constraints
on element type and shape are the same as the input constraint, but the attribute is verified to be
a tensor instead of a memref.
"""

memref_constraint: GenericAttrConstraint[
MemRefType[Attribute] | UnrankedMemrefType[Attribute]
]

def can_infer(self, constraint_names: set[str]) -> bool:
return self.memref_constraint.can_infer(constraint_names)

def infer(self, constraint_context: ConstraintContext) -> Attribute:
if self.name in constraint_context.variables:
m_type = constraint_context.get_variable(self.name)
if isa(m_type, MemRefType[Attribute]):
return TensorType(m_type.get_element_type(), m_type.get_shape())
if isa(m_type, UnrankedMemrefType[Attribute]):
return UnrankedTensorType(m_type.element_type)
raise ValueError(f"Unexpected {self.name} - cannot infer attribute")
memref_type = self.memref_constraint.infer(constraint_context)
if isa(memref_type, MemRefType[Attribute]):
return TensorType(memref_type.element_type, memref_type.shape)
assert isa(memref_type, UnrankedMemrefType[Attribute])
return UnrankedTensorType(memref_type.element_type)

def get_resolved_variables(self) -> set[str]:
return self.memref_constraint.get_resolved_variables()

def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None:
if self.name in constraint_context.variables:
seen = constraint_context.get_variable(self.name)
if not (
isinstance(attr, ContainerType)
and isinstance(seen, ContainerType)
and attr.get_element_type() == seen.get_element_type()
):
raise VerifyException(
f"Unexpected {self.name} - cannot verify element type of attribute {attr}"
)
if (
isinstance(attr, ShapedType) != isinstance(seen, ShapedType)
or isinstance(attr, ShapedType)
and isinstance(seen, ShapedType)
and attr.get_shape() != seen.get_shape()
):
raise VerifyException(
f"Unexpected {self.name} - cannot verify shape of attribute {attr}"
)
elif isinstance(attr, ContainerType):
self.constraint.verify(attr, constraint_context)
constraint_context.set_variable(self.name, attr)
else:
raise VerifyException(
f"Unexpected {self.name} - attribute must be ContainerType"
)
if isa(attr, TensorType[Attribute]):
memref_type = MemRefType(attr.element_type, attr.shape)
return self.memref_constraint.verify(memref_type, constraint_context)
if isa(attr, UnrankedTensorType[Attribute]):
memref_type = UnrankedMemrefType.from_type(attr.element_type)
return self.memref_constraint.verify(memref_type, constraint_context)
raise VerifyException(f"Expceted TensorType or UnrankedTensorType, got {attr}")


@irdl_op_definition
Expand Down Expand Up @@ -147,16 +138,11 @@ def __init__(
class ToTensorOp(IRDLOperation):
name = "bufferization.to_tensor"

memref = operand_def(
TensorMemrefInferenceConstraint(
"T", AnyOf([AnyMemRefTypeConstr, AnyUnrankedMemrefTypeConstr])
)
)
tensor = result_def(
TensorMemrefInferenceConstraint(
"T", AnyOf([AnyTensorTypeConstr, AnyUnrankedTensorTypeConstr])
)
)
T: ClassVar = VarConstraint("T", AnyMemRefTypeConstr | AnyUnrankedMemrefTypeConstr)

memref = operand_def(T)
tensor = result_def(TensorFromMemrefConstraint(T))

writable = opt_prop_def(UnitAttr)
restrict = opt_prop_def(UnitAttr)

Expand Down Expand Up @@ -192,16 +178,10 @@ def __init__(
class ToMemrefOp(IRDLOperation):
name = "bufferization.to_memref"

tensor = operand_def(
TensorMemrefInferenceConstraint(
"T", AnyOf([AnyTensorTypeConstr, AnyUnrankedTensorTypeConstr])
)
)
memref = result_def(
TensorMemrefInferenceConstraint(
"T", AnyOf([AnyMemRefTypeConstr, AnyUnrankedMemrefTypeConstr])
)
)
T: ClassVar = VarConstraint("T", AnyMemRefTypeConstr | AnyUnrankedMemrefTypeConstr)
tensor = operand_def(TensorFromMemrefConstraint(T))
memref = result_def(T)

read_only = opt_prop_def(UnitAttr)

assembly_format = "$tensor (`read_only` $read_only^)? `:` attr-dict type($memref)"
Expand All @@ -211,21 +191,11 @@ class ToMemrefOp(IRDLOperation):
class MaterializeInDestination(IRDLOperation):
name = "bufferization.materialize_in_destination"

source = operand_def(
TensorMemrefInferenceConstraint(
"T", AnyTensorTypeConstr | AnyUnrankedTensorTypeConstr
)
)
dest = operand_def(
TensorMemrefInferenceConstraint(
"T", AnyTensorTypeConstr | AnyUnrankedTensorTypeConstr
)
)
result = result_def(
TensorMemrefInferenceConstraint(
"T", AnyTensorTypeConstr | AnyUnrankedTensorTypeConstr
)
)
T: ClassVar = VarConstraint("T", AnyTensorTypeConstr | AnyUnrankedTensorTypeConstr)
source = operand_def(T)
dest = operand_def(T)
result = result_def(T)

restrict = opt_prop_def(UnitAttr)
writable = opt_prop_def(UnitAttr)

Expand Down
Loading