Skip to content

Commit

Permalink
dialects: (llvm) added disjoint flag (#3428)
Browse files Browse the repository at this point in the history
  • Loading branch information
lfrenot authored Nov 12, 2024
1 parent 265f59f commit 615cd54
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 2 deletions.
19 changes: 18 additions & 1 deletion tests/dialects/test_llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
(llvm.URemOp, {}),
(llvm.SRemOp, {}),
(llvm.AndOp, {}),
(llvm.OrOp, {}),
(llvm.XOrOp, {}),
],
)
Expand Down Expand Up @@ -73,6 +72,24 @@ def test_llvm_exact_arithmetic_ops(
)


@pytest.mark.parametrize(
"op_type, attributes, disjoint",
[
(llvm.OrOp, {}, llvm.UnitAttr()),
(llvm.OrOp, {}, None),
],
)
def test_llvm_disjoint_arithmetic_ops(
op_type: type[llvm.ArithmeticBinOpDisjoint],
attributes: dict[str, Attribute],
disjoint: llvm.UnitAttr | None,
):
op1, op2 = test.TestOp(result_types=[i32, i32]).results
assert op_type(op1, op2, attributes, disjoint).is_structurally_equivalent(
op_type(lhs=op1, rhs=op2, attributes=attributes, is_disjoint=disjoint)
)


def test_llvm_pointer_ops():
module = builtin.ModuleOp(
[
Expand Down
3 changes: 3 additions & 0 deletions tests/filecheck/dialects/llvm/arithmetic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
%or = llvm.or %arg0, %arg1 : i32
// CHECK: %or = llvm.or %arg0, %arg1 : i32

%or_disjoint = llvm.or disjoint %arg0, %arg1 : i32
// CHECK: %or_disjoint = llvm.or disjoint %arg0, %arg1 : i32

%xor = llvm.xor %arg0, %arg1 : i32
// CHECK: %xor = llvm.xor %arg0, %arg1 : i32

Expand Down
35 changes: 34 additions & 1 deletion xdsl/dialects/llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,39 @@ def print(self, printer: Printer) -> None:
printer.print(self.lhs.type)


class ArithmeticBinOpDisjoint(IRDLOperation, ABC):
"""Class for arithmetic binary operations that use a disjoint flag."""

T: ClassVar = VarConstraint("T", BaseAttr(IntegerType))

lhs = operand_def(T)
rhs = operand_def(T)
res = result_def(T)
is_disjoint = opt_prop_def(UnitAttr, prop_name="isDisjoint")

traits = traits_def(NoMemoryEffect())

assembly_format = (
"(`disjoint` $isDisjoint^)? $lhs `,` $rhs attr-dict `:` type($res)"
)

def __init__(
self,
lhs: SSAValue,
rhs: SSAValue,
attributes: dict[str, Attribute] = {},
is_disjoint: UnitAttr | None = None,
):
super().__init__(
operands=[lhs, rhs],
attributes=attributes,
result_types=[lhs.type],
properties={
"isDisjoint": is_disjoint,
},
)


class IntegerConversionOp(IRDLOperation, ABC):
arg = operand_def(IntegerType)

Expand Down Expand Up @@ -674,7 +707,7 @@ class AndOp(ArithmeticBinOperation):


@irdl_op_definition
class OrOp(ArithmeticBinOperation):
class OrOp(ArithmeticBinOpDisjoint):
name = "llvm.or"


Expand Down

0 comments on commit 615cd54

Please sign in to comment.