Skip to content

Commit

Permalink
dialects: (stablehlo) add stablehlo.case (#3095)
Browse files Browse the repository at this point in the history
Adds stablehlo.case

Co-authored-by: Erick Ochoa <erick@ceci-nest-pas.me>
  • Loading branch information
efferifick and Erick Ochoa authored Aug 28, 2024
1 parent 219b74e commit 69a1d05
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 3 deletions.
15 changes: 13 additions & 2 deletions tests/filecheck/dialects/stablehlo/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,16 @@
// %bitcast = "stablehlo.bitcast_convert"(%t0) : (tensor<i32>) -> tensor<2xi16>
%bitcast = "stablehlo.bitcast_convert"(%t0) : (tensor<i32>) -> tensor<2xi16>

// CHECK: "stablehlo.return"(%t0) : (tensor<i32>) -> ()
"stablehlo.return"(%t0) : (tensor<i32>) -> ()
%index = "test.op"() : () -> tensor<i32>
%result_branch0 = "test.op"() : () -> tensor<2xi64>
%result_branch1 = "test.op"() : () -> tensor<2xi64>

// CHECK: %0, %1 = "stablehlo.case"(%index) ({
%0:2 = "stablehlo.case"(%index) ({
// CHECK: "stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
"stablehlo.return"(%result_branch0, %result_branch0) : (tensor<2xi64>, tensor<2xi64>) -> ()
}, {
// CHECK: "stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
"stablehlo.return"(%result_branch1, %result_branch1) : (tensor<2xi64>, tensor<2xi64>) -> ()
// CHECK: }) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
}) : (tensor<i32>) -> (tensor<2xi64>, tensor<2xi64>)
49 changes: 48 additions & 1 deletion xdsl/dialects/stablehlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,22 @@
"""

import abc
from collections.abc import Sequence
from typing import Annotated, TypeAlias, cast

from xdsl.dialects.builtin import AnyTensorType, DenseArrayBase, IntegerType, TensorType
from xdsl.dialects.builtin import (
I32,
AnyTensorType,
DenseArrayBase,
IntegerType,
TensorType,
)
from xdsl.ir import (
Attribute,
Dialect,
EnumAttribute,
ParametrizedAttribute,
Region,
SpacedOpaqueSyntaxAttribute,
SSAValue,
StrEnum,
Expand All @@ -29,6 +37,8 @@
operand_def,
result_def,
var_operand_def,
var_region_def,
var_result_def,
)
from xdsl.traits import IsTerminator
from xdsl.utils.exceptions import VerifyException
Expand Down Expand Up @@ -176,6 +186,42 @@ def __init__(
super().__init__(operands=(lhs, rhs), result_types=(result_type,))


# TODO: Change to SI32 once StableHLO adopts signful integer semantics
# See: https://github.com/openxla/stablehlo/issues/22
# https://github.com/openxla/stablehlo/issues/2489
SI32TensorType: TypeAlias = TensorType[I32]


@irdl_op_definition
class CaseOp(IRDLOperation):
"""
Semantics
Produces the output from executing exactly one function from branches depending on the value of index.
More formally, result = selected_branch() where:
selected_branch = branches[index] if 0 <= index < size(branches).
selected_branch = branches[-1] otherwise.
https://github.com/openxla/stablehlo/blob/main/docs/spec.md#case
"""

name = "stablehlo.case"
index = operand_def(SI32TensorType)
branches = var_region_def("single_block")
_results = var_result_def(AnyTensorType | TokenType)

def __init__(
self,
index: SSAValue,
branches: Sequence[Region],
result_types: Sequence[AnyTensorType | TokenType],
):
super().__init__(
operands=(index,), result_types=(result_types,), regions=(branches,)
)


@irdl_op_definition
class BitcastConvertOp(IRDLOperation):
"""
Expand Down Expand Up @@ -320,6 +366,7 @@ def verify_(self) -> None:
AddOp,
AndOp,
BitcastConvertOp,
CaseOp,
MultiplyOp,
ReturnOp,
SubtractOp,
Expand Down

0 comments on commit 69a1d05

Please sign in to comment.