Skip to content

Commit

Permalink
dialects: (stablehlo) add stablehlo.case
Browse files Browse the repository at this point in the history
  • Loading branch information
Erick Ochoa committed Aug 25, 2024
1 parent c5c77e2 commit 954c46c
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 3 deletions.
15 changes: 13 additions & 2 deletions tests/filecheck/dialects/stablehlo/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,16 @@
// CHECK: %and = "stablehlo.and"(%t0, %t0) : (tensor<i32>, tensor<i32>) -> tensor<i32>
%and = "stablehlo.and"(%t0, %t0) : (tensor<i32>, tensor<i32>) -> tensor<i32>

// CHECK: "stablehlo.return"(%t0) : (tensor<i32>) -> ()
"stablehlo.return"(%t0) : (tensor<i32>) -> ()
%index = "test.op"() : () -> tensor<i32>
%result_branch0 = "test.op"() : () -> tensor<2xi64>
%result_branch1 = "test.op"() : () -> tensor<2xi64>

// CHECK: %result0, %result1 = "stablehlo.case"(%index) ({
%0:2 = "stablehlo.case"(%index) ({
// CHECK: "stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
"stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
// CHECK: "stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
"stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
// CHECK: }) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
49 changes: 48 additions & 1 deletion xdsl/dialects/stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,19 @@
import abc
from typing import Annotated, TypeAlias, cast

from xdsl.dialects.builtin import AnyTensorType, DenseArrayBase, IntegerType, TensorType
from xdsl.dialects.builtin import (
I32,
AnyTensorType,
DenseArrayBase,
IntegerType,
TensorType,
)
from xdsl.ir import (
Attribute,
Dialect,
EnumAttribute,
ParametrizedAttribute,
Region,
SpacedOpaqueSyntaxAttribute,
SSAValue,
StrEnum,
Expand All @@ -23,12 +30,15 @@
from xdsl.irdl import (
ConstraintVar,
IRDLOperation,
VarRegion,
attr_def,
irdl_attr_definition,
irdl_op_definition,
operand_def,
result_def,
var_operand_def,
var_region_def,
var_result_def,
)
from xdsl.traits import IsTerminator
from xdsl.utils.exceptions import VerifyException
Expand Down Expand Up @@ -176,6 +186,42 @@ def __init__(
super().__init__(operands=(lhs, rhs), result_types=(result_type,))


# TODO: Change to SI32 once StableHLO adopts signful integer semantics
# See: https://github.com/openxla/stablehlo/issues/22
# https://github.com/openxla/stablehlo/issues/2489
SI32TensorType: TypeAlias = TensorType[I32]


@irdl_op_definition
class CaseOp(IRDLOperation):
"""
Semantics
Produces the output from executing exactly one function from branches depending on the value of index.
More formally, result = selected_branch() where:
selected_branch = branches[index] if 0 <= index < size(branches).
selected_branch = branches[-1] otherwise.
https://github.com/openxla/stablehlo/blob/main/docs/spec.md#case
"""

name = "stablehlo.case"
index = operand_def(SI32TensorType)
branches: VarRegion = var_region_def("single_block")
_results = var_result_def(AnyTensorType | TokenType)

def __init__(
self,
index: SSAValue,
branches: list[Region],
results: list[AnyTensorType | TokenType],
):
super().__init__(
operands=(index,), result_types=(results,), regions=(branches,)
)


@irdl_op_definition
class MultiplyOp(ElementwiseBinaryOperation):
"""
Expand Down Expand Up @@ -294,6 +340,7 @@ def verify_(self) -> None:
AbsOp,
AddOp,
AndOp,
CaseOp,
MultiplyOp,
ReturnOp,
SubtractOp,
Expand Down

0 comments on commit 954c46c

Please sign in to comment.