Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (scf) add index_switch #3157

Merged
merged 8 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions tests/filecheck/dialects/scf/index_switch.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// RUN: xdsl-opt %s --split-input-file --verify-diagnostics | filecheck %s

"builtin.module"() ({
%0 = "arith.constant"() <{value = 0 : index}> : () -> index
"scf.index_switch"(%0) <{cases = array<i32: 0>}> ({
// CHECK: case values should have type i64
"scf.yield"() : () -> ()
}, {
"scf.yield"() : () -> ()
}) : (index) -> ()
}) : () -> ()

// -----

"builtin.module"() ({
%0 = "arith.constant"() <{value = 0 : index}> : () -> index
"scf.index_switch"(%0) <{cases = array<i64: 0, 1>}> ({
// CHECK: has 1 case regions but 2 case values
"scf.yield"() : () -> ()
}, {
"scf.yield"() : () -> ()
}) : (index) -> ()
}) : () -> ()

// -----

"builtin.module"() ({
%0 = "arith.constant"() <{value = 0 : index}> : () -> index
"scf.index_switch"(%0) <{cases = array<i64: 0, 1>}> ({
// CHECK: 'scf.index_switch' terminates with operation test.termop instead of scf.yield
"test.termop"() : () -> ()
}, {
"scf.yield"() : () -> ()
}) : (index) -> ()
}) : () -> ()

// -----
"builtin.module"() ({
%0 = "arith.constant"() <{value = 0 : index}> : () -> index
"scf.index_switch"(%0) <{cases = array<i64: 0>}> ({
%1 = "arith.constant"() <{value = 0 : i64}> : () -> i64
"scf.yield"(%1) : (i64) -> ()
}, {
%2 = "arith.constant"() <{value = 0 : i32}> : () -> i32
"scf.yield"(%2) : (i32) -> ()
// CHECK: region 0 returns values of types (i32) but expected (i64)
}) : (index) -> (i64)
}) : () -> ()
42 changes: 42 additions & 0 deletions tests/filecheck/dialects/scf/scf_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -180,4 +180,46 @@ builtin.module {
// CHECK-NEXT: }
// CHECK-NEXT: func.return
// CHECK-NEXT: }

func.func @index_switch(%flag: index) -> i32 {
%a = arith.constant 0 : i32
%b = arith.constant 1 : i32
%c, %d = scf.index_switch %flag -> i32, i32
case 1 {
scf.yield %a, %a : i32, i32
}
default {
scf.yield %b, %b : i32, i32
}
func.return %c : i32
}

// CHECK: func.func @index_switch(%flag : index) -> i32 {
// CHECK-NEXT: %a = arith.constant 0 : i32
// CHECK-NEXT: %b = arith.constant 1 : i32
// CHECK-NEXT: %c, %d = scf.index_switch %flag -> i32, i32
// CHECK-NEXT: case 1 {
// CHECK-NEXT: scf.yield %a, %a : i32, i32
// CHECK-NEXT: }
// CHECK-NEXT: default {
// CHECK-NEXT: scf.yield %b, %b : i32, i32
// CHECK-NEXT: }
// CHECK-NEXT: func.return %c : i32
// CHECK-NEXT: }

func.func @switch_trivial(%flag: index) {
scf.index_switch %flag
default {
scf.yield
}
func.return
}

// CHECK: func.func @switch_trivial(%flag : index) {
// CHECK-NEXT: scf.index_switch %flag
// CHECK-NEXT: default {
// CHECK-NEXT: scf.yield
// CHECK-NEXT: }
// CHECK-NEXT: func.return
// CHECK-NEXT: }
}
136 changes: 124 additions & 12 deletions xdsl/dialects/scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@

from xdsl.dialects.builtin import (
AnySignlessIntegerOrIndexType,
DenseArrayBase,
IndexType,
IntegerType,
i64,
)
from xdsl.dialects.utils import (
AbstractYieldOperation,
Expand All @@ -17,15 +19,16 @@
)
from xdsl.ir import Attribute, Block, Dialect, Operation, Region, SSAValue
from xdsl.irdl import (
AnyAttr,
AttrSizedOperandSegments,
ConstraintVar,
IRDLOperation,
irdl_op_definition,
operand_def,
prop_def,
region_def,
traits_def,
var_operand_def,
var_region_def,
var_result_def,
)
from xdsl.parser import Parser, UnresolvedOperand
Expand All @@ -47,9 +50,9 @@
@irdl_op_definition
class While(IRDLOperation):
name = "scf.while"
arguments = var_operand_def(AnyAttr())
arguments = var_operand_def()

res = var_result_def(AnyAttr())
res = var_result_def()
before_region = region_def()
after_region = region_def()

Expand Down Expand Up @@ -156,15 +159,19 @@ class Yield(AbstractYieldOperation[Attribute]):

traits = traits_def(
lambda: frozenset(
[IsTerminator(), HasParent(For, If, ParallelOp, While), Pure()]
[
IsTerminator(),
HasParent(For, If, ParallelOp, While, IndexSwitchOp),
Pure(),
]
)
)


@irdl_op_definition
class If(IRDLOperation):
name = "scf.if"
output = var_result_def(AnyAttr())
output = var_result_def()
cond = operand_def(IntegerType(1))

true_region = region_def("single_block")
Expand Down Expand Up @@ -283,9 +290,9 @@ class For(IRDLOperation):
ub = operand_def(T)
step = operand_def(T)

iter_args = var_operand_def(AnyAttr())
iter_args = var_operand_def()

res = var_result_def(AnyAttr())
res = var_result_def()

body = region_def("single_block")

Expand Down Expand Up @@ -452,8 +459,8 @@ class ParallelOp(IRDLOperation):
lowerBound = var_operand_def(IndexType)
upperBound = var_operand_def(IndexType)
step = var_operand_def(IndexType)
initVals = var_operand_def(AnyAttr())
res = var_result_def(AnyAttr())
initVals = var_operand_def()
res = var_result_def()

body = region_def("single_block")

Expand Down Expand Up @@ -576,7 +583,7 @@ def get_arg_type_of_nth_reduction_op(self, index: int):
@irdl_op_definition
class ReduceOp(IRDLOperation):
name = "scf.reduce"
argument = operand_def(AnyAttr())
argument = operand_def()

body = region_def("single_block")

Expand Down Expand Up @@ -626,7 +633,7 @@ def verify_(self) -> None:
@irdl_op_definition
class ReduceReturnOp(IRDLOperation):
name = "scf.reduce.return"
result = operand_def(AnyAttr())
result = operand_def()

traits = frozenset([HasParent(ReduceOp), IsTerminator(), Pure()])

Expand All @@ -638,7 +645,7 @@ def __init__(self, result: SSAValue | Operation):
class Condition(IRDLOperation):
name = "scf.condition"
cond = operand_def(IntegerType(1))
arguments = var_operand_def(AnyAttr())
arguments = var_operand_def()

traits = frozenset([HasParent(While), IsTerminator(), Pure()])

Expand Down Expand Up @@ -689,6 +696,110 @@ def parse(cls, parser: Parser) -> Self:
return op


@irdl_op_definition
class IndexSwitchOp(IRDLOperation):
name = "scf.index_switch"

arg = operand_def(IndexType)
cases = prop_def(DenseArrayBase)

output = var_result_def()

default_region = region_def("single_block")
case_regions = var_region_def("single_block")

traits = frozenset([RecursiveMemoryEffect(), SingleBlockImplicitTerminator(Yield)])

def __init__(
self,
arg: Operation | SSAValue,
cases: DenseArrayBase,
default_region: Region,
case_regions: Sequence[Region],
result_types: Sequence[Attribute],
attr_dict: dict[str, Attribute] | None = None,
):
properties = {
"cases": cases,
}

super().__init__(
operands=(arg,),
attributes=attr_dict,
properties=properties,
regions=(default_region, case_regions),
result_types=(result_types,),
)

def verify_(self) -> None:
if self.cases.elt_type != i64:
raise VerifyException("case values should have type i64")

if len(self.cases.data) != len(self.case_regions):
raise VerifyException(
f"has {len(self.case_regions)} case regions but {len(self.cases.data)} case values"
)

cases = self.cases.data.data
if len(set(cases)) != len(cases):
raise VerifyException("has duplicate case value")

def verify_region(region: Region, name: str):
yield_op = region.block.last_op
assert isinstance(yield_op, Yield)

if yield_op.operand_types != self.result_types:
raise VerifyException(
f'region {name} returns values of types ({", ".join(str(x) for x in yield_op.operand_types)})'
f' but expected ({", ".join(str(x) for x in self.result_types)})'
)

verify_region(self.default_region, "default")
for name, region in zip(cases, self.case_regions):
verify_region(region, str(name.data))

def print(self, printer: Printer):
printer.print_string(" ")
printer.print_operand(self.arg)
attr_dict = {k: v for k, v in self.attributes.items() if k != "cases"}
if attr_dict:
printer.print_string(" ")
printer.print_attr_dict(attr_dict)
if self.result_types:
printer.print_string(" -> ")
printer.print_list(self.result_types, printer.print_attribute)
printer.print_string("\n")
for case_value, case_region in zip(self.cases.data.data, self.case_regions):
printer.print_string(f"case {case_value.data} ")
printer.print_region(case_region)
printer.print_string("\n")

printer.print_string("default ")
printer.print_region(self.default_region)

@classmethod
def parse(cls, parser: Parser) -> Self:
arg = parser.parse_operand()
attr_dict = parser.parse_optional_attr_dict()
result_types: list[Attribute] = []
if parser.parse_optional_punctuation("->"):
types = parser.parse_optional_undelimited_comma_separated_list(
parser.parse_optional_type, parser.parse_type
)
if types is None:
parser.raise_error("result types not found")
result_types = types
case_values: list[int] = []
case_regions: list[Region] = []
while parser.parse_optional_keyword("case"):
case_values.append(parser.parse_integer())
case_regions.append(parser.parse_region())
cases = DenseArrayBase.from_list(i64, case_values)
parser.parse_keyword("default")
default_region = parser.parse_region()
return cls(arg, cases, default_region, case_regions, result_types, attr_dict)


Scf = Dialect(
"scf",
[
Expand All @@ -700,6 +811,7 @@ def parse(cls, parser: Parser) -> Self:
ReduceOp,
ReduceReturnOp,
While,
IndexSwitchOp,
],
[],
)
Loading