Skip to content

Commit

Permalink
dialects: (stablehlo) add stablehlo.after_all
Browse files Browse the repository at this point in the history
  • Loading branch information
Erick Ochoa committed Aug 27, 2024
1 parent 1e488d7 commit 8f04b1b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tests/filecheck/dialects/stablehlo/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
// CHECK: %add = "stablehlo.add"(%t0, %t0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
%add = "stablehlo.add"(%t0, %t0) : (tensor<i32>, tensor<i32>) -> tensor<i32>

%token0 = "test.op"() : () -> !stablehlo.token
%token1 = "test.op"() : () -> !stablehlo.token
// CHECK: %after_all = "stablehlo.after_all"(%token0, %token1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token
%after_all = "stablehlo.after_all"(%token0, %token1) : (!stablehlo.token, !stablehlo.token) -> !stablehlo.token

// CHECK: %multiply = "stablehlo.multiply"(%t0, %t0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
%multiply = "stablehlo.multiply"(%t0, %t0) : (tensor<i32>, tensor<i32>) -> tensor<i32>

Expand Down
18 changes: 18 additions & 0 deletions xdsl/dialects/stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,23 @@ class AddOp(ElementwiseBinaryOperation):
name = "stablehlo.add"


@irdl_op_definition
class AfterAllOp(IRDLOperation):
"""
Ensures that the operations producing the inputs are executed before any operations that depend on result.
Execution of this operation does nothing, it only exists to establish data dependencies from result to inputs.
https://github.com/openxla/stablehlo/blob/main/docs/spec.md#after_all
"""

name = "stablehlo.after_all"
inputs = var_operand_def(TokenType)
result = result_def(TokenType)

def __init__(self, inputs: SSAValue):
super().__init__(operands=[inputs], result_types=(TokenType(),))


IntegerTensorType: TypeAlias = TensorType[IntegerType]


Expand Down Expand Up @@ -318,6 +335,7 @@ def verify_(self) -> None:
[
AbsOp,
AddOp,
AfterAllOp,
AndOp,
BitcastConvertOp,
MultiplyOp,
Expand Down

0 comments on commit 8f04b1b

Please sign in to comment.