From 8f04b1bccdc974b9c8a2f5e34ae02f37b0a83837 Mon Sep 17 00:00:00 2001 From: Erick Ochoa Date: Tue, 27 Aug 2024 19:41:12 -0400 Subject: [PATCH] dialects: (stablehlo) add stablehlo.after_all --- tests/filecheck/dialects/stablehlo/ops.mlir | 5 +++++ xdsl/dialects/stablehlo.py | 18 ++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/tests/filecheck/dialects/stablehlo/ops.mlir b/tests/filecheck/dialects/stablehlo/ops.mlir index 08d568ca50..b06e54f28f 100644 --- a/tests/filecheck/dialects/stablehlo/ops.mlir +++ b/tests/filecheck/dialects/stablehlo/ops.mlir @@ -8,6 +8,11 @@ // CHECK: %add = "stablehlo.add"(%t0, %t0) : (tensor, tensor) -> tensor %add = "stablehlo.add"(%t0, %t0) : (tensor, tensor) -> tensor +%token0 = "test.op"() : () -> !stablehlo.token +%token1 = "test.op"() : () -> !stablehlo.token +// CHECK: %after_all = "stablehlo.after_all"(%token0, %token1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token +%after_all = "stablehlo.after_all"(%token0, %token1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token + // CHECK: %multiply = "stablehlo.multiply"(%t0, %t0) : (tensor, tensor) -> tensor %multiply = "stablehlo.multiply"(%t0, %t0) : (tensor, tensor) -> tensor diff --git a/xdsl/dialects/stablehlo.py b/xdsl/dialects/stablehlo.py index e39ba4872f..4b86ff6787 100644 --- a/xdsl/dialects/stablehlo.py +++ b/xdsl/dialects/stablehlo.py @@ -145,6 +145,23 @@ class AddOp(ElementwiseBinaryOperation): name = "stablehlo.add" +@irdl_op_definition +class AfterAllOp(IRDLOperation): + """ + Ensures that the operations producing the inputs are executed before any operations that depend on result. + Execution of this operation does nothing, it only exists to establish data dependencies from result to inputs. + + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#after_all + """ + + name = "stablehlo.after_all" + inputs = var_operand_def(TokenType) + result = result_def(TokenType) + + def __init__(self, inputs: SSAValue): + super().__init__(operands=[inputs], result_types=(TokenType(),)) + + IntegerTensorType: TypeAlias = TensorType[IntegerType] @@ -318,6 +335,7 @@ def verify_(self) -> None: [ AbsOp, AddOp, + AfterAllOp, AndOp, BitcastConvertOp, MultiplyOp,