diff --git a/tests/filecheck/dialects/scf/index_switch.mlir b/tests/filecheck/dialects/scf/index_switch.mlir new file mode 100644 index 0000000000..78c1199545 --- /dev/null +++ b/tests/filecheck/dialects/scf/index_switch.mlir @@ -0,0 +1,48 @@ +// RUN: xdsl-opt %s --split-input-file --verify-diagnostics | filecheck %s + +"builtin.module"() ({ + %0 = "arith.constant"() <{value = 0 : index}> : () -> index + "scf.index_switch"(%0) <{cases = array}> ({ + // CHECK: case values should have type i64 + "scf.yield"() : () -> () + }, { + "scf.yield"() : () -> () + }) : (index) -> () +}) : () -> () + +// ----- + +"builtin.module"() ({ + %0 = "arith.constant"() <{value = 0 : index}> : () -> index + "scf.index_switch"(%0) <{cases = array}> ({ + // CHECK: has 1 case regions but 2 case values + "scf.yield"() : () -> () + }, { + "scf.yield"() : () -> () + }) : (index) -> () +}) : () -> () + +// ----- + +"builtin.module"() ({ + %0 = "arith.constant"() <{value = 0 : index}> : () -> index + "scf.index_switch"(%0) <{cases = array}> ({ + // CHECK: 'scf.index_switch' terminates with operation test.termop instead of scf.yield + "test.termop"() : () -> () + }, { + "scf.yield"() : () -> () + }) : (index) -> () +}) : () -> () + +// ----- +"builtin.module"() ({ + %0 = "arith.constant"() <{value = 0 : index}> : () -> index + "scf.index_switch"(%0) <{cases = array}> ({ + %1 = "arith.constant"() <{value = 0 : i64}> : () -> i64 + "scf.yield"(%1) : (i64) -> () + }, { + %2 = "arith.constant"() <{value = 0 : i32}> : () -> i32 + "scf.yield"(%2) : (i32) -> () + // CHECK: region 0 returns values of types (i32) but expected (i64) + }) : (index) -> (i64) +}) : () -> () diff --git a/tests/filecheck/dialects/scf/scf_ops.mlir b/tests/filecheck/dialects/scf/scf_ops.mlir index b5e7420ca5..f9b06b3e53 100644 --- a/tests/filecheck/dialects/scf/scf_ops.mlir +++ b/tests/filecheck/dialects/scf/scf_ops.mlir @@ -180,4 +180,46 @@ builtin.module { // CHECK-NEXT: } // CHECK-NEXT: func.return // CHECK-NEXT: } + + func.func @index_switch(%flag: index) -> i32 { + %a = arith.constant 0 : i32 + %b = arith.constant 1 : i32 + %c, %d = scf.index_switch %flag -> i32, i32 + case 1 { + scf.yield %a, %a : i32, i32 + } + default { + scf.yield %b, %b : i32, i32 + } + func.return %c : i32 + } + + // CHECK: func.func @index_switch(%flag : index) -> i32 { + // CHECK-NEXT: %a = arith.constant 0 : i32 + // CHECK-NEXT: %b = arith.constant 1 : i32 + // CHECK-NEXT: %c, %d = scf.index_switch %flag -> i32, i32 + // CHECK-NEXT: case 1 { + // CHECK-NEXT: scf.yield %a, %a : i32, i32 + // CHECK-NEXT: } + // CHECK-NEXT: default { + // CHECK-NEXT: scf.yield %b, %b : i32, i32 + // CHECK-NEXT: } + // CHECK-NEXT: func.return %c : i32 + // CHECK-NEXT: } + + func.func @switch_trivial(%flag: index) { + scf.index_switch %flag + default { + scf.yield + } + func.return + } + + // CHECK: func.func @switch_trivial(%flag : index) { + // CHECK-NEXT: scf.index_switch %flag + // CHECK-NEXT: default { + // CHECK-NEXT: scf.yield + // CHECK-NEXT: } + // CHECK-NEXT: func.return + // CHECK-NEXT: } } diff --git a/xdsl/dialects/scf.py b/xdsl/dialects/scf.py index e3f6945556..831ec71aed 100644 --- a/xdsl/dialects/scf.py +++ b/xdsl/dialects/scf.py @@ -7,8 +7,10 @@ from xdsl.dialects.builtin import ( AnySignlessIntegerOrIndexType, + DenseArrayBase, IndexType, IntegerType, + i64, ) from xdsl.dialects.utils import ( AbstractYieldOperation, @@ -17,15 +19,16 @@ ) from xdsl.ir import Attribute, Block, Dialect, Operation, Region, SSAValue from xdsl.irdl import ( - AnyAttr, AttrSizedOperandSegments, ConstraintVar, IRDLOperation, irdl_op_definition, operand_def, + prop_def, region_def, traits_def, var_operand_def, + var_region_def, var_result_def, ) from xdsl.parser import Parser, UnresolvedOperand @@ -47,9 +50,9 @@ @irdl_op_definition class While(IRDLOperation): name = "scf.while" - arguments = var_operand_def(AnyAttr()) + arguments = var_operand_def() - res = var_result_def(AnyAttr()) + res = var_result_def() before_region = region_def() after_region = region_def() @@ -156,7 +159,11 @@ class Yield(AbstractYieldOperation[Attribute]): traits = traits_def( lambda: frozenset( - [IsTerminator(), HasParent(For, If, ParallelOp, While), Pure()] + [ + IsTerminator(), + HasParent(For, If, ParallelOp, While, IndexSwitchOp), + Pure(), + ] ) ) @@ -164,7 +171,7 @@ class Yield(AbstractYieldOperation[Attribute]): @irdl_op_definition class If(IRDLOperation): name = "scf.if" - output = var_result_def(AnyAttr()) + output = var_result_def() cond = operand_def(IntegerType(1)) true_region = region_def("single_block") @@ -283,9 +290,9 @@ class For(IRDLOperation): ub = operand_def(T) step = operand_def(T) - iter_args = var_operand_def(AnyAttr()) + iter_args = var_operand_def() - res = var_result_def(AnyAttr()) + res = var_result_def() body = region_def("single_block") @@ -452,8 +459,8 @@ class ParallelOp(IRDLOperation): lowerBound = var_operand_def(IndexType) upperBound = var_operand_def(IndexType) step = var_operand_def(IndexType) - initVals = var_operand_def(AnyAttr()) - res = var_result_def(AnyAttr()) + initVals = var_operand_def() + res = var_result_def() body = region_def("single_block") @@ -576,7 +583,7 @@ def get_arg_type_of_nth_reduction_op(self, index: int): @irdl_op_definition class ReduceOp(IRDLOperation): name = "scf.reduce" - argument = operand_def(AnyAttr()) + argument = operand_def() body = region_def("single_block") @@ -626,7 +633,7 @@ def verify_(self) -> None: @irdl_op_definition class ReduceReturnOp(IRDLOperation): name = "scf.reduce.return" - result = operand_def(AnyAttr()) + result = operand_def() traits = frozenset([HasParent(ReduceOp), IsTerminator(), Pure()]) @@ -638,7 +645,7 @@ def __init__(self, result: SSAValue | Operation): class Condition(IRDLOperation): name = "scf.condition" cond = operand_def(IntegerType(1)) - arguments = var_operand_def(AnyAttr()) + arguments = var_operand_def() traits = frozenset([HasParent(While), IsTerminator(), Pure()]) @@ -689,6 +696,110 @@ def parse(cls, parser: Parser) -> Self: return op +@irdl_op_definition +class IndexSwitchOp(IRDLOperation): + name = "scf.index_switch" + + arg = operand_def(IndexType) + cases = prop_def(DenseArrayBase) + + output = var_result_def() + + default_region = region_def("single_block") + case_regions = var_region_def("single_block") + + traits = frozenset([RecursiveMemoryEffect(), SingleBlockImplicitTerminator(Yield)]) + + def __init__( + self, + arg: Operation | SSAValue, + cases: DenseArrayBase, + default_region: Region, + case_regions: Sequence[Region], + result_types: Sequence[Attribute], + attr_dict: dict[str, Attribute] | None = None, + ): + properties = { + "cases": cases, + } + + super().__init__( + operands=(arg,), + attributes=attr_dict, + properties=properties, + regions=(default_region, case_regions), + result_types=(result_types,), + ) + + def verify_(self) -> None: + if self.cases.elt_type != i64: + raise VerifyException("case values should have type i64") + + if len(self.cases.data) != len(self.case_regions): + raise VerifyException( + f"has {len(self.case_regions)} case regions but {len(self.cases.data)} case values" + ) + + cases = self.cases.data.data + if len(set(cases)) != len(cases): + raise VerifyException("has duplicate case value") + + def verify_region(region: Region, name: str): + yield_op = region.block.last_op + assert isinstance(yield_op, Yield) + + if yield_op.operand_types != self.result_types: + raise VerifyException( + f'region {name} returns values of types ({", ".join(str(x) for x in yield_op.operand_types)})' + f' but expected ({", ".join(str(x) for x in self.result_types)})' + ) + + verify_region(self.default_region, "default") + for name, region in zip(cases, self.case_regions): + verify_region(region, str(name.data)) + + def print(self, printer: Printer): + printer.print_string(" ") + printer.print_operand(self.arg) + attr_dict = {k: v for k, v in self.attributes.items() if k != "cases"} + if attr_dict: + printer.print_string(" ") + printer.print_attr_dict(attr_dict) + if self.result_types: + printer.print_string(" -> ") + printer.print_list(self.result_types, printer.print_attribute) + printer.print_string("\n") + for case_value, case_region in zip(self.cases.data.data, self.case_regions): + printer.print_string(f"case {case_value.data} ") + printer.print_region(case_region) + printer.print_string("\n") + + printer.print_string("default ") + printer.print_region(self.default_region) + + @classmethod + def parse(cls, parser: Parser) -> Self: + arg = parser.parse_operand() + attr_dict = parser.parse_optional_attr_dict() + result_types: list[Attribute] = [] + if parser.parse_optional_punctuation("->"): + types = parser.parse_optional_undelimited_comma_separated_list( + parser.parse_optional_type, parser.parse_type + ) + if types is None: + parser.raise_error("result types not found") + result_types = types + case_values: list[int] = [] + case_regions: list[Region] = [] + while parser.parse_optional_keyword("case"): + case_values.append(parser.parse_integer()) + case_regions.append(parser.parse_region()) + cases = DenseArrayBase.from_list(i64, case_values) + parser.parse_keyword("default") + default_region = parser.parse_region() + return cls(arg, cases, default_region, case_regions, result_types, attr_dict) + + Scf = Dialect( "scf", [ @@ -700,6 +811,7 @@ def parse(cls, parser: Parser) -> Self: ReduceOp, ReduceReturnOp, While, + IndexSwitchOp, ], [], )