diff --git a/tests/filecheck/dialects/stablehlo/ops.mlir b/tests/filecheck/dialects/stablehlo/ops.mlir index 08d568ca50..04e02171e4 100644 --- a/tests/filecheck/dialects/stablehlo/ops.mlir +++ b/tests/filecheck/dialects/stablehlo/ops.mlir @@ -34,5 +34,16 @@ // %bitcast = "stablehlo.bitcast_convert"(%t0) : (tensor) -> tensor<2xi16> %bitcast = "stablehlo.bitcast_convert"(%t0) : (tensor) -> tensor<2xi16> -// CHECK: "stablehlo.return"(%t0) : (tensor) -> () -"stablehlo.return"(%t0) : (tensor) -> () +%index = "test.op"() : () -> tensor +%result_branch0 = "test.op"() : () -> tensor<2xi64> +%result_branch1 = "test.op"() : () -> tensor<2xi64> + +// CHECK: %0, %1 = "stablehlo.case"(%index) ({ +%0:2 = "stablehlo.case"(%index) ({ + // CHECK: "stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> () + "stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> () +}, { + // CHECK: "stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> () + "stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> () +// CHECK: }) : (tensor) -> (tensor<2xi64>, tensor<2xi64>) +}) : (tensor) -> (tensor<2xi64>, tensor<2xi64>) diff --git a/xdsl/dialects/stablehlo.py b/xdsl/dialects/stablehlo.py index e39ba4872f..edede55cba 100644 --- a/xdsl/dialects/stablehlo.py +++ b/xdsl/dialects/stablehlo.py @@ -7,14 +7,22 @@ """ import abc +from collections.abc import Sequence from typing import Annotated, TypeAlias, cast -from xdsl.dialects.builtin import AnyTensorType, DenseArrayBase, IntegerType, TensorType +from xdsl.dialects.builtin import ( + I32, + AnyTensorType, + DenseArrayBase, + IntegerType, + TensorType, +) from xdsl.ir import ( Attribute, Dialect, EnumAttribute, ParametrizedAttribute, + Region, SpacedOpaqueSyntaxAttribute, SSAValue, StrEnum, @@ -29,6 +37,8 @@ operand_def, result_def, var_operand_def, + var_region_def, + var_result_def, ) from xdsl.traits import IsTerminator from xdsl.utils.exceptions import VerifyException @@ -176,6 +186,42 @@ def __init__( super().__init__(operands=(lhs, rhs), result_types=(result_type,)) +# TODO: Change to SI32 once StableHLO adopts signful integer semantics +# See: https://github.com/openxla/stablehlo/issues/22 +# https://github.com/openxla/stablehlo/issues/2489 +SI32TensorType: TypeAlias = TensorType[I32] + + +@irdl_op_definition +class CaseOp(IRDLOperation): + """ + Semantics + + Produces the output from executing exactly one function from branches depending on the value of index. + More formally, result = selected_branch() where: + + selected_branch = branches[index] if 0 <= index < size(branches). + selected_branch = branches[-1] otherwise. + + https://github.com/openxla/stablehlo/blob/main/docs/spec.md#case + """ + + name = "stablehlo.case" + index = operand_def(SI32TensorType) + branches = var_region_def("single_block") + _results = var_result_def(AnyTensorType | TokenType) + + def __init__( + self, + index: SSAValue, + branches: Sequence[Region], + result_types: Sequence[AnyTensorType | TokenType], + ): + super().__init__( + operands=(index,), result_types=(result_types,), regions=(branches,) + ) + + @irdl_op_definition class BitcastConvertOp(IRDLOperation): """ @@ -320,6 +366,7 @@ def verify_(self) -> None: AddOp, AndOp, BitcastConvertOp, + CaseOp, MultiplyOp, ReturnOp, SubtractOp,