Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

api: remove parameters field from Trait #2940

Merged
merged 2 commits into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions tests/test_traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ def test_has_trait_object():
"""
assert TestOp.has_trait(LargerOperandTrait)
assert not TestOp.has_trait(LargerResultTrait)
assert not TestOp.has_trait(BitwidthSumLessThanTrait, 0)
assert TestOp.has_trait(BitwidthSumLessThanTrait, 64)
assert not TestOp.has_trait(BitwidthSumLessThanTrait(0))
assert TestOp.has_trait(BitwidthSumLessThanTrait(64))


def test_get_traits_of_type():
Expand Down Expand Up @@ -445,8 +445,8 @@ def test_lazy_parent():
"""Test the trait infrastructure for an operation that defines a trait "lazily"."""
op = HasLazyParentOp.create()
assert len(op.get_traits_of_type(HasParent)) != 0
assert op.get_traits_of_type(HasParent)[0].parameters == (TestOp,)
assert op.has_trait(HasParent, (TestOp,))
assert op.get_traits_of_type(HasParent)[0].op_types == (TestOp,)
assert op.has_trait(HasParent(TestOp))
assert op.traits == frozenset([HasParent(TestOp)])


Expand All @@ -461,7 +461,7 @@ def test_has_ancestor():
op = AncestorOp()

assert op.get_traits_of_type(HasAncestor) == [HasAncestor(TestOp)]
assert op.has_trait(HasAncestor, (TestOp,))
assert op.has_trait(HasAncestor(TestOp))

with pytest.raises(
VerifyException, match="'test.ancestor' expects ancestor op 'test.test'"
Expand Down
11 changes: 6 additions & 5 deletions xdsl/dialects/csl/csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from abc import ABC
from collections.abc import Sequence
from dataclasses import dataclass
from dataclasses import KW_ONLY, dataclass, field
from typing import Annotated, ClassVar, TypeAlias

from xdsl.dialects import builtin
Expand Down Expand Up @@ -186,14 +186,15 @@ class InModuleKind(OpTrait):
Ops with this trait are always allowed inside a csl_wrapper.module
"""

def __init__(self, kind: ModuleKind, *, direct_child: bool = True):
super().__init__((kind, direct_child))
kind: ModuleKind = field()
_: KW_ONLY
direct_child: bool = field(default=True)

def verify(self, op: Operation) -> None:
from xdsl.dialects.csl import csl_wrapper

kind: ModuleKind = self.parameters[0]
direct_child: bool = self.parameters[1]
kind: ModuleKind = self.kind
direct_child: bool = self.direct_child

direct = "direct" if direct_child else "indirect"
parent_module = op.parent_op()
Expand Down
5 changes: 3 additions & 2 deletions xdsl/dialects/csl/csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ def verify_(self) -> None:

# As promised by HasAncestor(ApplyOp)
trait = cast(
HasAncestor, AccessOp.get_trait(HasAncestor, (stencil.ApplyOp, ApplyOp))
HasAncestor, AccessOp.get_trait(HasAncestor(stencil.ApplyOp, ApplyOp))
)
apply = trait.get_ancestor(self)
assert isinstance(apply, stencil.ApplyOp | ApplyOp)
Expand Down Expand Up @@ -529,7 +529,8 @@ def get_apply(self) -> stencil.ApplyOp | ApplyOp:
Simple helper to get the parent apply and raise otherwise.
"""
trait = cast(
HasAncestor, self.get_trait(HasAncestor, (stencil.ApplyOp, ApplyOp))
HasAncestor,
self.get_trait(HasAncestor(stencil.ApplyOp, ApplyOp)),
)
ancestor = trait.get_ancestor(self)
if ancestor is None:
Expand Down
14 changes: 9 additions & 5 deletions xdsl/dialects/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,8 @@ def get_apply(self):
"""
Simple helper to get the parent apply and raise otherwise.
"""
trait = cast(HasAncestor, self.get_trait(HasAncestor, (ApplyOp,)))
trait = self.get_trait(HasAncestor(ApplyOp))
assert trait is not None
ancestor = trait.get_ancestor(self)
if ancestor is None:
raise ValueError(
Expand Down Expand Up @@ -861,8 +862,9 @@ def print(self, printer: Printer):
print_keyword=True,
)

# IRDL-enforced, not supposed to use custom syntax if not veriied
trait = cast(HasAncestor, AccessOp.get_trait(HasAncestor, (ApplyOp,)))
# IRDL-enforced, not supposed to use custom syntax if not verified
trait = AccessOp.get_trait(HasAncestor(ApplyOp))
assert trait is not None
apply = cast(ApplyOp, trait.get_ancestor(self))

mapping = self.offset_mapping
Expand Down Expand Up @@ -949,7 +951,8 @@ def get(

def verify_(self) -> None:
# As promised by HasAncestor(ApplyOp)
trait = cast(HasAncestor, AccessOp.get_trait(HasAncestor, (ApplyOp,)))
trait = AccessOp.get_trait(HasAncestor(ApplyOp))
assert trait is not None
apply = trait.get_ancestor(self)
assert isinstance(apply, ApplyOp)

Expand Down Expand Up @@ -1002,7 +1005,8 @@ def get_apply(self):
"""
Simple helper to get the parent apply and raise otherwise.
"""
trait = cast(HasAncestor, self.get_trait(HasAncestor, (ApplyOp,)))
trait = self.get_trait(HasAncestor(ApplyOp))
assert trait is not None
ancestor = trait.get_ancestor(self)
if ancestor is None:
raise ValueError(
Expand Down
21 changes: 12 additions & 9 deletions xdsl/ir/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,8 +939,8 @@ def clone(
@classmethod
def has_trait(
cls,
trait: type[OpTrait],
parameters: Any = None,
trait: type[OpTrait] | OpTrait,
*,
value_if_unregistered: bool = True,
) -> bool:
"""
Expand All @@ -953,18 +953,21 @@ def has_trait(
if issubclass(cls, UnregisteredOp):
return value_if_unregistered

return cls.get_trait(trait, parameters) is not None
return cls.get_trait(trait) is not None

@classmethod
def get_trait(
cls, trait: type[OpTraitInvT], parameters: Any = None
) -> OpTraitInvT | None:
def get_trait(cls, trait: type[OpTraitInvT] | OpTraitInvT) -> OpTraitInvT | None:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a breaking change, but the fix is just a couple of backspaces to delete the comma and instantiate the class instead of passing the parameters separately

"""
Return a trait with the given type and parameters, if it exists.
"""
for t in cls.traits:
if isinstance(t, trait) and t.parameters == parameters:
return t
if isinstance(trait, type):
for t in cls.traits:
if isinstance(t, trait):
return t
else:
for t in cls.traits:
if t == trait:
return cast(OpTraitInvT, t)
return None

@classmethod
Expand Down
51 changes: 23 additions & 28 deletions xdsl/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import abc
from collections.abc import Iterator
from dataclasses import dataclass, field
from dataclasses import dataclass
from enum import Enum, auto
from typing import TYPE_CHECKING, Any, TypeVar, final
from typing import TYPE_CHECKING, TypeVar, final

from xdsl.utils.exceptions import VerifyException

Expand All @@ -20,13 +20,9 @@ class OpTrait:
A trait attached to an operation definition.
Traits can be used to define operation invariants, additional semantic information,
or to group operations that have similar properties.
Traits have parameters, which by default is just the `None` value. Parameters should
always be comparable and hashable.
Note that traits are the merge of traits and interfaces in MLIR.
"""

parameters: Any = field(default=None)

def verify(self, op: Operation) -> None:
"""Check that the operation satisfies the trait requirements."""
pass
Expand All @@ -47,22 +43,20 @@ class ConstantLike(OpTrait):
class HasParent(OpTrait):
"""Constraint the operation to have a specific parent operation."""

parameters: tuple[type[Operation], ...]
op_types: tuple[type[Operation], ...]

def __init__(self, *parameters: type[Operation]):
if not parameters:
raise ValueError("parameters must not be empty")
super().__init__(parameters)
def __init__(self, head_param: type[Operation], *tail_params: type[Operation]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not super() init?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no super init here that would set the op types property, as far as I understand

object.__setattr__(self, "op_types", (head_param, *tail_params))

def verify(self, op: Operation) -> None:
parent = op.parent_op()
if isinstance(parent, self.parameters):
if isinstance(parent, self.op_types):
return
if len(self.parameters) == 1:
if len(self.op_types) == 1:
raise VerifyException(
f"'{op.name}' expects parent op '{self.parameters[0].name}'"
f"'{op.name}' expects parent op '{self.op_types[0].name}'"
)
names = ", ".join(f"'{p.name}'" for p in self.parameters)
names = ", ".join(f"'{p.name}'" for p in self.op_types)
raise VerifyException(f"'{op.name}' expects parent op to be one of {names}")


Expand All @@ -73,18 +67,18 @@ class HasAncestor(OpTrait):
parent.
"""

parameters: tuple[type[Operation], ...]
op_types: tuple[type[Operation], ...]

def __init__(self, head_param: type[Operation], *tail_params: type[Operation]):
super().__init__((head_param, *tail_params))
object.__setattr__(self, "op_types", (head_param, *tail_params))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here


def verify(self, op: Operation) -> None:
if self.get_ancestor(op) is None:
if len(self.parameters) == 1:
if len(self.op_types) == 1:
raise VerifyException(
f"'{op.name}' expects ancestor op '{self.parameters[0].name}'"
f"'{op.name}' expects ancestor op '{self.op_types[0].name}'"
)
names = ", ".join(f"'{p.name}'" for p in self.parameters)
names = ", ".join(f"'{p.name}'" for p in self.op_types)
raise VerifyException(
f"'{op.name}' expects ancestor op to be one of {names}"
)
Expand All @@ -98,7 +92,7 @@ def walk_ancestors(self, op: Operation) -> Iterator[Operation]:

def get_ancestor(self, op: Operation) -> Operation | None:
ancestors = self.walk_ancestors(op)
matching_ancestors = (a for a in ancestors if isinstance(a, self.parameters))
matching_ancestors = (a for a in ancestors if isinstance(a, self.op_types))
return next(matching_ancestors, None)


Expand Down Expand Up @@ -133,6 +127,7 @@ def verify(self, op: Operation) -> None:
)


@dataclass(frozen=True)
class SingleBlockImplicitTerminator(OpTrait):
"""
Checks the existence of the specified terminator to an operation which has
Expand All @@ -144,7 +139,7 @@ class SingleBlockImplicitTerminator(OpTrait):
https://mlir.llvm.org/docs/Traits/#single-block-with-implicit-terminator
"""

parameters: type[Operation]
op_type: type[Operation]

def verify(self, op: Operation) -> None:
for region in op.regions:
Expand All @@ -156,13 +151,13 @@ def verify(self, op: Operation) -> None:
if (last_op := block.last_op) is None:
raise VerifyException(
f"'{op.name}' contains empty block instead of at least "
f"terminating with {self.parameters.name}"
f"terminating with {self.op_type.name}"
)

if not isinstance(last_op, self.parameters):
if not isinstance(last_op, self.op_type):
raise VerifyException(
f"'{op.name}' terminates with operation {last_op.name} "
f"instead of {self.parameters.name}"
f"instead of {self.op_type.name}"
)


Expand All @@ -181,11 +176,11 @@ def ensure_terminator(op: Operation, trait: SingleBlockImplicitTerminator) -> No
if (
(last_op := block.last_op) is not None
and last_op.has_trait(IsTerminator)
and not isinstance(last_op, trait.parameters)
and not isinstance(last_op, trait.op_type)
):
raise VerifyException(
f"'{op.name}' terminates with operation {last_op.name} "
f"instead of {trait.parameters.name}"
f"instead of {trait.op_type.name}"
)

from xdsl.builder import ImplicitBuilder
Expand All @@ -200,7 +195,7 @@ def ensure_terminator(op: Operation, trait: SingleBlockImplicitTerminator) -> No
IsTerminator
):
with ImplicitBuilder(block):
trait.parameters.create()
trait.op_type.create()


class IsolatedFromAbove(OpTrait):
Expand Down
Loading