diff --git a/tests/filecheck/dialects/stablehlo/ops.mlir b/tests/filecheck/dialects/stablehlo/ops.mlir index eddd4c0be4..6e4feaa266 100644 --- a/tests/filecheck/dialects/stablehlo/ops.mlir +++ b/tests/filecheck/dialects/stablehlo/ops.mlir @@ -31,5 +31,16 @@ // CHECK: %and = "stablehlo.and"(%t0, %t0) : (tensor, tensor) -> tensor %and = "stablehlo.and"(%t0, %t0) : (tensor, tensor) -> tensor -// CHECK: "stablehlo.return"(%t0) : (tensor) -> () -"stablehlo.return"(%t0) : (tensor) -> () +%index = "test.op"() : () -> tensor +%result_branch0 = "test.op"() : () -> tensor<2xi64> +%result_branch1 = "test.op"() : () -> tensor<2xi64> + +// CHECK: %result0, %result1 = "stablehlo.case"(%index) ({ +%0:2 = "stablehlo.case"(%index) ({ + // CHECK: "stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> () + "stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> () +}, { + // CHECK: "stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> () + "stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> () +// CHECK: }) : (tensor) -> (tensor<2xi64>, tensor<2xi64>) +}) : (tensor) -> (tensor<2xi64>, tensor<2xi64>) diff --git a/xdsl/dialects/stablehlo.py b/xdsl/dialects/stablehlo.py index a88926f14d..65255f8e12 100644 --- a/xdsl/dialects/stablehlo.py +++ b/xdsl/dialects/stablehlo.py @@ -9,12 +9,19 @@ import abc from typing import Annotated, TypeAlias, cast -from xdsl.dialects.builtin import AnyTensorType, DenseArrayBase, IntegerType, TensorType +from xdsl.dialects.builtin import ( + I32, + AnyTensorType, + DenseArrayBase, + IntegerType, + TensorType, +) from xdsl.ir import ( Attribute, Dialect, EnumAttribute, ParametrizedAttribute, + Region, SpacedOpaqueSyntaxAttribute, SSAValue, StrEnum, @@ -23,12 +30,15 @@ from xdsl.irdl import ( ConstraintVar, IRDLOperation, + VarRegion, attr_def, irdl_attr_definition, irdl_op_definition, operand_def, result_def, var_operand_def, + var_region_def, + var_result_def, ) from xdsl.traits import IsTerminator from xdsl.utils.exceptions import VerifyException @@ -176,6 +186,42 @@ def __init__( super().__init__(operands=(lhs, rhs), result_types=(result_type,)) +# TODO: Change to SI32 once StableHLO adopts signful integer semantics +# See: https://github.com/openxla/stablehlo/issues/22 +# https://github.com/openxla/stablehlo/issues/2489 +SI32TensorType: TypeAlias = TensorType[I32] + + +@irdl_op_definition +class CaseOp(IRDLOperation): + """ + Semantics + + Produces the output from executing exactly one function from branches depending on the value of index. + More formally, result = selected_branch() where: + + selected_branch = branches[index] if 0 <= index < size(branches). + selected_branch = branches[-1] otherwise. + + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#case + """ + + name = "stablehlo.case" + index = operand_def(SI32TensorType) + branches: VarRegion = var_region_def("single_block") + _results = var_result_def(AnyTensorType | TokenType) + + def __init__( + self, + index: SSAValue, + branches: list[Region], + results: list[AnyTensorType | TokenType], + ): + super().__init__( + operands=(index,), result_types=(results,), regions=(branches,) + ) + + @irdl_op_definition class MultiplyOp(ElementwiseBinaryOperation): """ @@ -294,6 +340,7 @@ def verify_(self) -> None: AbsOp, AddOp, AndOp, + CaseOp, MultiplyOp, ReturnOp, SubtractOp,