From 615cd54ef87bdb14ef1d6217c154945d3522d7d9 Mon Sep 17 00:00:00 2001 From: lfrenot Date: Tue, 12 Nov 2024 12:44:09 +0000 Subject: [PATCH] dialects: (llvm) added disjoint flag (#3428) --- tests/dialects/test_llvm.py | 19 +++++++++- tests/filecheck/dialects/llvm/arithmetic.mlir | 3 ++ xdsl/dialects/llvm.py | 35 ++++++++++++++++++- 3 files changed, 55 insertions(+), 2 deletions(-) diff --git a/tests/dialects/test_llvm.py b/tests/dialects/test_llvm.py index 25a4318661..54507fb5cd 100644 --- a/tests/dialects/test_llvm.py +++ b/tests/dialects/test_llvm.py @@ -16,7 +16,6 @@ (llvm.URemOp, {}), (llvm.SRemOp, {}), (llvm.AndOp, {}), - (llvm.OrOp, {}), (llvm.XOrOp, {}), ], ) @@ -73,6 +72,24 @@ def test_llvm_exact_arithmetic_ops( ) +@pytest.mark.parametrize( + "op_type, attributes, disjoint", + [ + (llvm.OrOp, {}, llvm.UnitAttr()), + (llvm.OrOp, {}, None), + ], +) +def test_llvm_disjoint_arithmetic_ops( + op_type: type[llvm.ArithmeticBinOpDisjoint], + attributes: dict[str, Attribute], + disjoint: llvm.UnitAttr | None, +): + op1, op2 = test.TestOp(result_types=[i32, i32]).results + assert op_type(op1, op2, attributes, disjoint).is_structurally_equivalent( + op_type(lhs=op1, rhs=op2, attributes=attributes, is_disjoint=disjoint) + ) + + def test_llvm_pointer_ops(): module = builtin.ModuleOp( [ diff --git a/tests/filecheck/dialects/llvm/arithmetic.mlir b/tests/filecheck/dialects/llvm/arithmetic.mlir index 334927a00b..ada534f4da 100644 --- a/tests/filecheck/dialects/llvm/arithmetic.mlir +++ b/tests/filecheck/dialects/llvm/arithmetic.mlir @@ -44,6 +44,9 @@ %or = llvm.or %arg0, %arg1 : i32 // CHECK: %or = llvm.or %arg0, %arg1 : i32 +%or_disjoint = llvm.or disjoint %arg0, %arg1 : i32 +// CHECK: %or_disjoint = llvm.or disjoint %arg0, %arg1 : i32 + %xor = llvm.xor %arg0, %arg1 : i32 // CHECK: %xor = llvm.xor %arg0, %arg1 : i32 diff --git a/xdsl/dialects/llvm.py b/xdsl/dialects/llvm.py index b43b11ac03..c66021e04c 100644 --- a/xdsl/dialects/llvm.py +++ b/xdsl/dialects/llvm.py @@ -532,6 +532,39 @@ def print(self, printer: Printer) -> None: printer.print(self.lhs.type) +class ArithmeticBinOpDisjoint(IRDLOperation, ABC): + """Class for arithmetic binary operations that use a disjoint flag.""" + + T: ClassVar = VarConstraint("T", BaseAttr(IntegerType)) + + lhs = operand_def(T) + rhs = operand_def(T) + res = result_def(T) + is_disjoint = opt_prop_def(UnitAttr, prop_name="isDisjoint") + + traits = traits_def(NoMemoryEffect()) + + assembly_format = ( + "(`disjoint` $isDisjoint^)? $lhs `,` $rhs attr-dict `:` type($res)" + ) + + def __init__( + self, + lhs: SSAValue, + rhs: SSAValue, + attributes: dict[str, Attribute] = {}, + is_disjoint: UnitAttr | None = None, + ): + super().__init__( + operands=[lhs, rhs], + attributes=attributes, + result_types=[lhs.type], + properties={ + "isDisjoint": is_disjoint, + }, + ) + + class IntegerConversionOp(IRDLOperation, ABC): arg = operand_def(IntegerType) @@ -674,7 +707,7 @@ class AndOp(ArithmeticBinOperation): @irdl_op_definition -class OrOp(ArithmeticBinOperation): +class OrOp(ArithmeticBinOpDisjoint): name = "llvm.or"