Skip to content

Commit

Permalink
dialects: (bufferization) Simplify tensor/memref constraint (xdslproj…
Browse files Browse the repository at this point in the history
…ect#3475)

Removes the `TensorMemrefInferenceConstraint` in favour of a simpler
`TensorFromMemrefConstraint`.

---------

Co-authored-by: Sasha Lopoukhine <superlopuh@gmail.com>
  • Loading branch information
2 people authored and EdmundGoodman committed Dec 6, 2024
1 parent aa43761 commit 287d694
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 77 deletions.
137 changes: 135 additions & 2 deletions tests/dialects/test_bufferization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,139 @@
from xdsl.dialects.bufferization import AllocTensorOp, ToTensorOp
from xdsl.dialects.builtin import MemRefType, TensorType, UnitAttr, f64
from typing import ClassVar

import pytest

from xdsl.dialects.bufferization import (
AllocTensorOp,
TensorFromMemrefConstraint,
ToTensorOp,
)
from xdsl.dialects.builtin import (
AnyMemRefTypeConstr,
AnyUnrankedMemrefTypeConstr,
IndexType,
IntegerType,
MemRefType,
TensorType,
UnitAttr,
UnrankedMemrefType,
UnrankedTensorType,
f64,
)
from xdsl.dialects.test import TestOp
from xdsl.ir import Attribute
from xdsl.irdl import (
ConstraintContext,
EqAttrConstraint,
IRDLOperation,
VarConstraint,
irdl_op_definition,
operand_def,
)
from xdsl.utils.exceptions import VerifyException


def test_tensor_from_memref_inference():
constr = TensorFromMemrefConstraint(AnyMemRefTypeConstr)
assert not constr.can_infer(set())

constr2 = TensorFromMemrefConstraint(
EqAttrConstraint(MemRefType(f64, [10, 20, 30]))
)
assert constr2.can_infer(set())
assert constr2.infer(ConstraintContext()) == TensorType(f64, [10, 20, 30])

constr3 = TensorFromMemrefConstraint(
EqAttrConstraint(UnrankedMemrefType.from_type(f64))
)
assert constr3.can_infer(set())
assert constr3.infer(ConstraintContext()) == UnrankedTensorType(f64)


@irdl_op_definition
class TensorFromMemref(IRDLOperation):
name = "test.tensor_from_memref"
T: ClassVar = VarConstraint("T", AnyMemRefTypeConstr | AnyUnrankedMemrefTypeConstr)

in_tensor = operand_def(
TensorFromMemrefConstraint(
MemRefType.constr(element_type=EqAttrConstraint(IndexType()))
)
)

in_var_memref = operand_def(T)

in_var_tensor = operand_def(TensorFromMemrefConstraint(T))


def test_tensor_from_memref_constraint():
[v_memref, v_tensor] = TestOp(
result_types=[
MemRefType(IndexType(), [10, 20, 30]),
TensorType(IndexType(), [10, 20, 30]),
]
).res
op1 = TensorFromMemref(operands=(v_tensor, v_memref, v_tensor))
op1.verify()

[v_unranked_memref, v_unranked_tensor] = TestOp(
result_types=[
UnrankedMemrefType.from_type(IndexType()),
UnrankedTensorType(IndexType()),
]
).res
op2 = TensorFromMemref(operands=(v_tensor, v_unranked_memref, v_unranked_tensor))
op2.verify()


@pytest.mark.parametrize(
"type1, type2, type3, error",
[
(
IndexType(),
MemRefType(IndexType(), [10, 20, 30]),
TensorType(IndexType(), [10, 20, 30]),
"Expected tensor or unranked tensor type, got index",
),
(
TensorType(IntegerType(32), [10, 10, 10]),
MemRefType(IndexType(), [10, 20, 30]),
TensorType(IndexType(), [10, 20, 30]),
"Expected attribute index but got i32",
),
(
UnrankedTensorType(IndexType()),
MemRefType(IndexType(), [10, 20, 30]),
TensorType(IndexType(), [10, 20, 30]),
"memref<\\*xindex> should be of base attribute memref",
),
(
TensorType(IndexType(), [10, 10, 10]),
MemRefType(IndexType(), [10, 20, 30]),
TensorType(IndexType(), [10, 20, 20]),
"attribute memref<10x20x30xindex> expected from variable 'T', but got memref<10x20x20xindex>",
),
(
TensorType(IndexType(), [10, 10, 10]),
MemRefType(IntegerType(32), [10, 20, 30]),
TensorType(IndexType(), [10, 20, 30]),
"attribute memref<10x20x30xi32> expected from variable 'T', but got memref<10x20x30xindex>",
),
],
)
def test_tensor_from_memref_constraint_failure(
type1: Attribute, type2: Attribute, type3: Attribute, error: str
):
[v1, v2, v3] = TestOp(
result_types=[
type1,
type2,
type3,
]
).res

op1 = TensorFromMemref(operands=(v1, v2, v3))
with pytest.raises(VerifyException, match=error):
op1.verify()


def test_to_tensor():
Expand Down
120 changes: 45 additions & 75 deletions xdsl/dialects/bufferization.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, ClassVar

from xdsl.dialects.builtin import (
Expand All @@ -17,9 +18,9 @@
)
from xdsl.ir import Attribute, Dialect, Operation, SSAValue
from xdsl.irdl import (
AnyOf,
AttrSizedOperandSegments,
ConstraintContext,
GenericAttrConstraint,
IRDLOperation,
VarConstraint,
irdl_op_definition,
Expand All @@ -33,51 +34,41 @@
from xdsl.utils.hints import isa


class TensorMemrefInferenceConstraint(VarConstraint[Attribute]):
@dataclass(frozen=True)
class TensorFromMemrefConstraint(
GenericAttrConstraint[TensorType[Attribute] | UnrankedTensorType[Attribute]]
):
"""
Constraint to infer tensor shapes from memref shapes, inferring ranked tensor from ranked memref
(and unranked from unranked, respectively).
Verification checks that attributes of the same variable name are either all ranked or all unranked,
and checks for matching element type, shape (ranked only), as well as verifying sub constraints.
Converts an input memref constraint to the corresponding tensor constraint, i.e. the constraints
on element type and shape are the same as the input constraint, but the attribute is verified to be
a tensor instead of a memref.
"""

memref_constraint: GenericAttrConstraint[
MemRefType[Attribute] | UnrankedMemrefType[Attribute]
]

def can_infer(self, constraint_names: set[str]) -> bool:
return self.memref_constraint.can_infer(constraint_names)

def infer(self, constraint_context: ConstraintContext) -> Attribute:
if self.name in constraint_context.variables:
m_type = constraint_context.get_variable(self.name)
if isa(m_type, MemRefType[Attribute]):
return TensorType(m_type.get_element_type(), m_type.get_shape())
if isa(m_type, UnrankedMemrefType[Attribute]):
return UnrankedTensorType(m_type.element_type)
raise ValueError(f"Unexpected {self.name} - cannot infer attribute")
memref_type = self.memref_constraint.infer(constraint_context)
if isa(memref_type, MemRefType[Attribute]):
return TensorType(memref_type.element_type, memref_type.shape)
assert isa(memref_type, UnrankedMemrefType[Attribute])
return UnrankedTensorType(memref_type.element_type)

def get_resolved_variables(self) -> set[str]:
return self.memref_constraint.get_resolved_variables()

def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None:
if self.name in constraint_context.variables:
seen = constraint_context.get_variable(self.name)
if not (
isinstance(attr, ContainerType)
and isinstance(seen, ContainerType)
and attr.get_element_type() == seen.get_element_type()
):
raise VerifyException(
f"Unexpected {self.name} - cannot verify element type of attribute {attr}"
)
if (
isinstance(attr, ShapedType) != isinstance(seen, ShapedType)
or isinstance(attr, ShapedType)
and isinstance(seen, ShapedType)
and attr.get_shape() != seen.get_shape()
):
raise VerifyException(
f"Unexpected {self.name} - cannot verify shape of attribute {attr}"
)
elif isinstance(attr, ContainerType):
self.constraint.verify(attr, constraint_context)
constraint_context.set_variable(self.name, attr)
else:
raise VerifyException(
f"Unexpected {self.name} - attribute must be ContainerType"
)
if isa(attr, TensorType[Attribute]):
memref_type = MemRefType(attr.element_type, attr.shape)
return self.memref_constraint.verify(memref_type, constraint_context)
if isa(attr, UnrankedTensorType[Attribute]):
memref_type = UnrankedMemrefType.from_type(attr.element_type)
return self.memref_constraint.verify(memref_type, constraint_context)
raise VerifyException(f"Expected tensor or unranked tensor type, got {attr}")


@irdl_op_definition
Expand Down Expand Up @@ -147,16 +138,11 @@ def __init__(
class ToTensorOp(IRDLOperation):
name = "bufferization.to_tensor"

memref = operand_def(
TensorMemrefInferenceConstraint(
"T", AnyOf([AnyMemRefTypeConstr, AnyUnrankedMemrefTypeConstr])
)
)
tensor = result_def(
TensorMemrefInferenceConstraint(
"T", AnyOf([AnyTensorTypeConstr, AnyUnrankedTensorTypeConstr])
)
)
T: ClassVar = VarConstraint("T", AnyMemRefTypeConstr | AnyUnrankedMemrefTypeConstr)

memref = operand_def(T)
tensor = result_def(TensorFromMemrefConstraint(T))

writable = opt_prop_def(UnitAttr)
restrict = opt_prop_def(UnitAttr)

Expand Down Expand Up @@ -192,16 +178,10 @@ def __init__(
class ToMemrefOp(IRDLOperation):
name = "bufferization.to_memref"

tensor = operand_def(
TensorMemrefInferenceConstraint(
"T", AnyOf([AnyTensorTypeConstr, AnyUnrankedTensorTypeConstr])
)
)
memref = result_def(
TensorMemrefInferenceConstraint(
"T", AnyOf([AnyMemRefTypeConstr, AnyUnrankedMemrefTypeConstr])
)
)
T: ClassVar = VarConstraint("T", AnyMemRefTypeConstr | AnyUnrankedMemrefTypeConstr)
tensor = operand_def(TensorFromMemrefConstraint(T))
memref = result_def(T)

read_only = opt_prop_def(UnitAttr)

assembly_format = "$tensor (`read_only` $read_only^)? `:` attr-dict type($memref)"
Expand All @@ -211,21 +191,11 @@ class ToMemrefOp(IRDLOperation):
class MaterializeInDestination(IRDLOperation):
name = "bufferization.materialize_in_destination"

source = operand_def(
TensorMemrefInferenceConstraint(
"T", AnyTensorTypeConstr | AnyUnrankedTensorTypeConstr
)
)
dest = operand_def(
TensorMemrefInferenceConstraint(
"T", AnyTensorTypeConstr | AnyUnrankedTensorTypeConstr
)
)
result = result_def(
TensorMemrefInferenceConstraint(
"T", AnyTensorTypeConstr | AnyUnrankedTensorTypeConstr
)
)
T: ClassVar = VarConstraint("T", AnyTensorTypeConstr | AnyUnrankedTensorTypeConstr)
source = operand_def(T)
dest = operand_def(T)
result = result_def(T)

restrict = opt_prop_def(UnitAttr)
writable = opt_prop_def(UnitAttr)

Expand Down

0 comments on commit 287d694

Please sign in to comment.