From 79accaad62c11113432d6e875c3531736579752c Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Sun, 1 Sep 2024 08:58:24 +0100 Subject: [PATCH] dialects: (scf) Add parse/print to `scf.if` --- tests/filecheck/dialects/scf/scf_ops.mlir | 33 +++---- .../stencil/oec-kernels/fvtp2d_qi.mlir | 60 ++++++------ tests/filecheck/frontend/dialects/scf.py | 43 +++------ .../with-mlir/control_flow_hoist_collab.mlir | 20 ++-- .../convert-stencil-to-ll-mlir.mlir | 12 +-- tests/filecheck/transforms/cse.mlir | 66 ++++++------- .../transforms/distribute-stencil.mlir | 12 +-- .../transforms/function-constant-pinning.mlir | 92 ++++++++----------- .../transforms/stencil-bufferize.mlir | 12 +-- xdsl/dialects/scf.py | 71 +++++++++++++- 10 files changed, 233 insertions(+), 188 deletions(-) diff --git a/tests/filecheck/dialects/scf/scf_ops.mlir b/tests/filecheck/dialects/scf/scf_ops.mlir index 3698f7bedb..309007d355 100644 --- a/tests/filecheck/dialects/scf/scf_ops.mlir +++ b/tests/filecheck/dialects/scf/scf_ops.mlir @@ -5,40 +5,41 @@ builtin.module { %0 = "test.op"() : () -> i1 - "scf.if"(%0) ({ + scf.if %0 { %1 = "test.op"() : () -> i32 scf.yield - }, { + } else { %2 = "test.op"() : () -> i32 - scf.yield - }) : (i1) -> () + } // CHECK: %{{.*}} = "test.op"() : () -> i1 - // CHECK-NEXT: "scf.if"(%{{.*}}) ({ + // CHECK-NEXT: scf.if %{{.*}} { // CHECK-NEXT: %{{.*}} = "test.op"() : () -> i32 - // CHECK-NEXT: scf.yield - // CHECK-NEXT: }, { + // CHECK-NEXT: } else { // CHECK-NEXT: %{{.*}} = "test.op"() : () -> i32 - // CHECK-NEXT: scf.yield - // CHECK-NEXT: }) : (i1) -> () + // CHECK-NEXT: } - %3 = "scf.if"(%0) ({ + %3 = scf.if %0 -> (i32) { %4 = "test.op"() : () -> i32 scf.yield %4 : i32 - }, { + } else { %5 = "test.op"() : () -> i32 scf.yield %5 : i32 - }) : (i1) -> i32 - + } - // CHECK: %{{.*}} = "scf.if"(%{{.*}}) ({ + // CHECK: %{{.*}} = scf.if %{{.*}} -> (i32) { // CHECK-NEXT: %{{.*}} = "test.op"() : () -> i32 // CHECK-NEXT: scf.yield %{{.*}} : i32 - // CHECK-NEXT: }, { + // CHECK-NEXT: } else { // CHECK-NEXT: %{{.*}} = "test.op"() : () -> i32 // CHECK-NEXT: scf.yield %{{.*}} : i32 - // CHECK-NEXT: }) : (i1) -> i32 + // CHECK-NEXT: } + + scf.if %0 {} + + // CHECK: scf.if %{{.*}} { + // CHECK-NEXT: } func.func @while() { %init = arith.constant 0 : i32 diff --git a/tests/filecheck/dialects/stencil/oec-kernels/fvtp2d_qi.mlir b/tests/filecheck/dialects/stencil/oec-kernels/fvtp2d_qi.mlir index 054a4d2e38..a65fc0f060 100644 --- a/tests/filecheck/dialects/stencil/oec-kernels/fvtp2d_qi.mlir +++ b/tests/filecheck/dialects/stencil/oec-kernels/fvtp2d_qi.mlir @@ -63,7 +63,7 @@ func.func @fvtp2d_qi(%arg0: !stencil.field, %arg1: !stencil.field %24 = arith.cmpf ogt, %23, %cst : f64 - %25 = "scf.if"(%24) ({ + %25 = scf.if %24 -> (f64) { %29 = stencil.access %arg10 [0, -1, 0] : !stencil.temp %30 = stencil.access %arg11 [0, -1, 0] : !stencil.temp %31 = arith.mulf %23, %30 : f64 @@ -71,7 +71,7 @@ func.func @fvtp2d_qi(%arg0: !stencil.field, %arg1: !stencil.field %30 = stencil.access %arg11 [0, 0, 0] : !stencil.temp %31 = arith.mulf %23, %30 : f64 @@ -79,17 +79,17 @@ func.func @fvtp2d_qi(%arg0: !stencil.field, %arg1: !stencil.field (f64) + } %26 = arith.mulf %25, %22 : f64 - %27 = "scf.if"(%24) ({ + %27 = scf.if %24 -> (f64) { %29 = stencil.access %arg7 [0, -1, 0] : !stencil.temp %30 = arith.addf %29, %26 : f64 scf.yield %30 : f64 - }, { + } else { %29 = stencil.access %arg7 [0, 0, 0] : !stencil.temp %30 = arith.addf %29, %26 : f64 scf.yield %30 : f64 - }) : (i1) -> (f64) + } %28 = stencil.store_result %27 : !stencil.result stencil.return %28 : !stencil.result } @@ -178,7 +178,7 @@ func.func @fvtp2d_qi(%arg0: !stencil.field, %arg1: !stencil.field // CHECK-NEXT: %25 = arith.cmpf ogt, %24, %cst : f64 -// CHECK-NEXT: %26 = "scf.if"(%25) ({ +// CHECK-NEXT: %26 = scf.if %25 -> (f64) { // CHECK-NEXT: %27 = stencil.access %arg10[0, -1, 0] : !stencil.temp // CHECK-NEXT: %28 = stencil.access %arg11[0, -1, 0] : !stencil.temp // CHECK-NEXT: %29 = arith.mulf %24, %28 : f64 @@ -186,7 +186,7 @@ func.func @fvtp2d_qi(%arg0: !stencil.field, %arg1: !stencil.field // CHECK-NEXT: %34 = stencil.access %arg11[0, 0, 0] : !stencil.temp // CHECK-NEXT: %35 = arith.mulf %24, %34 : f64 @@ -194,17 +194,17 @@ func.func @fvtp2d_qi(%arg0: !stencil.field, %arg1: !stencil.field f64 +// CHECK-NEXT: } // CHECK-NEXT: %39 = arith.mulf %26, %23 : f64 -// CHECK-NEXT: %40 = "scf.if"(%25) ({ +// CHECK-NEXT: %40 = scf.if %25 -> (f64) { // CHECK-NEXT: %41 = stencil.access %arg7[0, -1, 0] : !stencil.temp // CHECK-NEXT: %42 = arith.addf %41, %39 : f64 // CHECK-NEXT: scf.yield %42 : f64 -// CHECK-NEXT: }, { +// CHECK-NEXT: } else { // CHECK-NEXT: %43 = stencil.access %arg7[0, 0, 0] : !stencil.temp // CHECK-NEXT: %44 = arith.addf %43, %39 : f64 // CHECK-NEXT: scf.yield %44 : f64 -// CHECK-NEXT: }) : (i1) -> f64 +// CHECK-NEXT: } // CHECK-NEXT: %45 = stencil.store_result %40 : !stencil.result // CHECK-NEXT: stencil.return %45 : !stencil.result // CHECK-NEXT: } @@ -298,7 +298,7 @@ func.func @fvtp2d_qi(%arg0: !stencil.field, %arg1: !stencil.field // SHAPE-NEXT: %30 = arith.cmpf ogt, %29, %cst : f64 -// SHAPE-NEXT: %31 = "scf.if"(%30) ({ +// SHAPE-NEXT: %31 = scf.if %30 -> (f64) { // SHAPE-NEXT: %32 = stencil.access %arg10[0, -1, 0] : !stencil.temp<[0,64]x[-1,65]x[0,64]xf64> // SHAPE-NEXT: %33 = stencil.access %arg11[0, -1, 0] : !stencil.temp<[0,64]x[-1,65]x[0,64]xf64> // SHAPE-NEXT: %34 = arith.mulf %29, %33 : f64 @@ -306,7 +306,7 @@ func.func @fvtp2d_qi(%arg0: !stencil.field, %arg1: !stencil.field // SHAPE-NEXT: %39 = stencil.access %arg11[0, 0, 0] : !stencil.temp<[0,64]x[-1,65]x[0,64]xf64> // SHAPE-NEXT: %40 = arith.mulf %29, %39 : f64 @@ -314,17 +314,17 @@ func.func @fvtp2d_qi(%arg0: !stencil.field, %arg1: !stencil.field f64 +// SHAPE-NEXT: } // SHAPE-NEXT: %44 = arith.mulf %31, %28 : f64 -// SHAPE-NEXT: %45 = "scf.if"(%30) ({ +// SHAPE-NEXT: %45 = scf.if %30 -> (f64) { // SHAPE-NEXT: %46 = stencil.access %arg7[0, -1, 0] : !stencil.temp<[0,64]x[-3,67]x[0,64]xf64> // SHAPE-NEXT: %47 = arith.addf %46, %44 : f64 // SHAPE-NEXT: scf.yield %47 : f64 -// SHAPE-NEXT: }, { +// SHAPE-NEXT: } else { // SHAPE-NEXT: %48 = stencil.access %arg7[0, 0, 0] : !stencil.temp<[0,64]x[-3,67]x[0,64]xf64> // SHAPE-NEXT: %49 = arith.addf %48, %44 : f64 // SHAPE-NEXT: scf.yield %49 : f64 -// SHAPE-NEXT: }) : (i1) -> f64 +// SHAPE-NEXT: } // SHAPE-NEXT: %50 = stencil.store_result %45 : !stencil.result // SHAPE-NEXT: stencil.return %50 : !stencil.result // SHAPE-NEXT: } @@ -462,7 +462,7 @@ func.func @fvtp2d_qi(%arg0: !stencil.field, %arg1: !stencil.field> // MLIR-NEXT: %86 = arith.cmpf ogt, %85, %cst_5 : f64 -// MLIR-NEXT: %87 = "scf.if"(%86) ({ +// MLIR-NEXT: %87 = scf.if %86 -> (f64) { // MLIR-NEXT: %88 = arith.constant -1 : index // MLIR-NEXT: %89 = arith.addi %75, %88 : index // MLIR-NEXT: %90 = memref.load %arg10[%74, %89, %76] : memref<64x66x64xf64, strided<[4224, 64, 1], offset: 64>> @@ -474,7 +474,7 @@ func.func @fvtp2d_qi(%arg0: !stencil.field, %arg1: !stencil.field> // MLIR-NEXT: %99 = memref.load %arg11[%74, %75, %76] : memref<64x66x64xf64, strided<[4224, 64, 1], offset: 64>> // MLIR-NEXT: %100 = arith.mulf %85, %99 : f64 @@ -482,19 +482,19 @@ func.func @fvtp2d_qi(%arg0: !stencil.field, %arg1: !stencil.field f64 +// MLIR-NEXT: } // MLIR-NEXT: %104 = arith.mulf %87, %84 : f64 -// MLIR-NEXT: %105 = "scf.if"(%86) ({ +// MLIR-NEXT: %105 = scf.if %86 -> (f64) { // MLIR-NEXT: %106 = arith.constant -1 : index // MLIR-NEXT: %107 = arith.addi %75, %106 : index // MLIR-NEXT: %108 = memref.load %8[%74, %107, %76] : memref<64x70x64xf64, strided<[5184, 72, 1], offset: 21028>> // MLIR-NEXT: %109 = arith.addf %108, %104 : f64 // MLIR-NEXT: scf.yield %109 : f64 -// MLIR-NEXT: }, { +// MLIR-NEXT: } else { // MLIR-NEXT: %110 = memref.load %8[%74, %75, %76] : memref<64x70x64xf64, strided<[5184, 72, 1], offset: 21028>> // MLIR-NEXT: %111 = arith.addf %110, %104 : f64 // MLIR-NEXT: scf.yield %111 : f64 -// MLIR-NEXT: }) : (i1) -> f64 +// MLIR-NEXT: } // MLIR-NEXT: memref.store %105, %arg8_1[%74, %75, %76] : memref<64x65x64xf64, strided<[5184, 72, 1], offset: 21028>> // MLIR-NEXT: scf.yield // MLIR-NEXT: }) : (index, index, index, index, index, index, index, index, index) -> () @@ -610,7 +610,7 @@ func.func @fvtp2d_qi(%arg0: !stencil.field, %arg1: !stencil.field // BUFF-NEXT: %20 = arith.cmpf ogt, %19, %cst : f64 -// BUFF-NEXT: %21 = "scf.if"(%20) ({ +// BUFF-NEXT: %21 = scf.if %20 -> (f64) { // BUFF-NEXT: %22 = stencil.access %arg10[0, -1, 0] : !stencil.field<[0,64]x[-1,65]x[0,64]xf64> // BUFF-NEXT: %23 = stencil.access %arg11[0, -1, 0] : !stencil.field<[0,64]x[-1,65]x[0,64]xf64> // BUFF-NEXT: %24 = arith.mulf %19, %23 : f64 @@ -618,7 +618,7 @@ func.func @fvtp2d_qi(%arg0: !stencil.field, %arg1: !stencil.field // BUFF-NEXT: %29 = stencil.access %arg11[0, 0, 0] : !stencil.field<[0,64]x[-1,65]x[0,64]xf64> // BUFF-NEXT: %30 = arith.mulf %19, %29 : f64 @@ -626,17 +626,17 @@ func.func @fvtp2d_qi(%arg0: !stencil.field, %arg1: !stencil.field f64 +// BUFF-NEXT: } // BUFF-NEXT: %34 = arith.mulf %21, %18 : f64 -// BUFF-NEXT: %35 = "scf.if"(%20) ({ +// BUFF-NEXT: %35 = scf.if %20 -> (f64) { // BUFF-NEXT: %36 = stencil.access %arg7[0, -1, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> // BUFF-NEXT: %37 = arith.addf %36, %34 : f64 // BUFF-NEXT: scf.yield %37 : f64 -// BUFF-NEXT: }, { +// BUFF-NEXT: } else { // BUFF-NEXT: %38 = stencil.access %arg7[0, 0, 0] : !stencil.field<[-4,68]x[-4,68]x[-4,68]xf64> // BUFF-NEXT: %39 = arith.addf %38, %34 : f64 // BUFF-NEXT: scf.yield %39 : f64 -// BUFF-NEXT: }) : (i1) -> f64 +// BUFF-NEXT: } // BUFF-NEXT: %40 = stencil.store_result %35 : !stencil.result // BUFF-NEXT: stencil.return %40 : !stencil.result // BUFF-NEXT: } to <[0, 0, 0], [64, 65, 64]> diff --git a/tests/filecheck/frontend/dialects/scf.py b/tests/filecheck/frontend/dialects/scf.py index 7dd133a1f0..4e45d47acf 100644 --- a/tests/filecheck/frontend/dialects/scf.py +++ b/tests/filecheck/frontend/dialects/scf.py @@ -133,21 +133,18 @@ def test_not_supported_loop_III(start: index, end: index, step: f32): p = FrontendProgram() with CodeContext(p): - # CHECK: %{{.*}} = "scf.if"(%{{.*}}) ({ + # CHECK: %{{.*}} = scf.if %{{.*}} -> (i32) { # CHECK-NEXT: %{{.*}} = "symref.fetch"() {"symbol" = @x} : () -> i32 # CHECK-NEXT: scf.yield %{{.*}} : i32 - # CHECK-NEXT: }, { + # CHECK-NEXT: } else { # CHECK-NEXT: %{{.*}} = "symref.fetch"() {"symbol" = @y} : () -> i32 # CHECK-NEXT: scf.yield %{{.*}} : i32 - # CHECK-NEXT: }) : (i1) -> i32 + # CHECK-NEXT: } def test_if_expr(cond: i1, x: i32, y: i32) -> i32: return x if cond else y - # CHECK: "scf.if"(%{{.*}}) ({ - # CHECK-NEXT: scf.yield - # CHECK-NEXT: }, { - # CHECK-NEXT: scf.yield - # CHECK-NEXT: }) : (i1) -> () + # CHECK: scf.if %{{.*}} { + # CHECK-NEXT: } def test_if_I(cond: i1): if cond: pass @@ -156,23 +153,16 @@ def test_if_I(cond: i1): return # CHECK: %{{.*}} = "symref.fetch"() {"symbol" = @a} : () -> i1 - # CHECK-NEXT: "scf.if"(%{{.*}}) ({ - # CHECK-NEXT: scf.yield - # CHECK-NEXT: }, { + # CHECK-NEXT: scf.if %{{.*}} { + # CHECK-NEXT: } else { # CHECK-NEXT: %{{.*}} = "symref.fetch"() {"symbol" = @b} : () -> i1 - # CHECK-NEXT: "scf.if"(%{{.*}}) ({ - # CHECK-NEXT: scf.yield - # CHECK-NEXT: }, { + # CHECK-NEXT: scf.if %{{.*}} { + # CHECK-NEXT: } else { # CHECK-NEXT: %{{.*}} = "symref.fetch"() {"symbol" = @c} : () -> i1 - # CHECK-NEXT: "scf.if"(%{{.*}}) ({ - # CHECK-NEXT: scf.yield - # CHECK-NEXT: }, { - # CHECK-NEXT: scf.yield - # CHECK-NEXT: }) : (i1) -> () - # CHECK-NEXT: scf.yield - # CHECK-NEXT: }) : (i1) -> () - # CHECK-NEXT: scf.yield - # CHECK-NEXT: }) : (i1) -> () + # CHECK-NEXT: scf.if %{{.*}} { + # CHECK-NEXT: } + # CHECK-NEXT: } + # CHECK-NEXT: } def test_if_II(a: i1, b: i1, c: i1): if a: pass @@ -183,11 +173,8 @@ def test_if_II(a: i1, b: i1, c: i1): return # CHECK: %{{.*}} = "symref.fetch"() {"symbol" = @cond} : () -> i1 - # CHECK-NEXT: "scf.if"(%{{.*}}) ({ - # CHECK-NEXT: scf.yield - # CHECK-NEXT: }, { - # CHECK-NEXT: scf.yield - # CHECK-NEXT: }) : (i1) -> () + # CHECK-NEXT: scf.if %{{.*}} { + # CHECK-NEXT: } def test_if_III(cond: i1): if cond: pass diff --git a/tests/filecheck/mlir-conversion/with-mlir/control_flow_hoist_collab.mlir b/tests/filecheck/mlir-conversion/with-mlir/control_flow_hoist_collab.mlir index 6bfadaf406..c48ced018c 100644 --- a/tests/filecheck/mlir-conversion/with-mlir/control_flow_hoist_collab.mlir +++ b/tests/filecheck/mlir-conversion/with-mlir/control_flow_hoist_collab.mlir @@ -7,17 +7,17 @@ func.func @nested_loop_invariant(%n : index) { %100 = arith.constant 100 : index scf.for %i = %0 to %100 step %1 { %cond = "test.op"() : () -> (i1) - %thing = "scf.if"(%cond) ({ + %thing = scf.if %cond -> (index) { // This is loop invariant // Also is nested in conditional // Contradictory local intuitions :S // MLIR really want to keep operations nested if they only occur on one // branch, which locally makes sense! - %n100 = arith.muli %n, %100 : index + %n100 = arith.muli %n, %100 : index scf.yield %n100 :index - }, { + } else { scf.yield %n : index - }) : (i1) -> index + } "test.op"(%thing) : (index) -> () scf.yield } @@ -32,12 +32,12 @@ func.func @nested_loop_invariant(%n : index) { // WITHOUT-NEXT: %2 = arith.constant 100 : index // WITHOUT-NEXT: scf.for %arg1 = %0 to %2 step %1 { // WITHOUT-NEXT: %3 = "test.op"() : () -> i1 -// WITHOUT-NEXT: %4 = "scf.if"(%3) ({ +// WITHOUT-NEXT: %4 = scf.if %3 -> (index) { // WITHOUT-NEXT: %5 = arith.muli %arg0, %2 : index // WITHOUT-NEXT: scf.yield %5 : index -// WITHOUT-NEXT: }, { +// WITHOUT-NEXT: } else { // WITHOUT-NEXT: scf.yield %arg0 : index -// WITHOUT-NEXT: }) : (i1) -> index +// WITHOUT-NEXT: } // WITHOUT-NEXT: "test.op"(%4) : (index) -> () // WITHOUT-NEXT: } // WITHOUT-NEXT: func.return @@ -54,11 +54,11 @@ func.func @nested_loop_invariant(%n : index) { // WITH-NEXT: %3 = arith.muli %arg0, %2 : index // WITH-NEXT: scf.for %arg1 = %0 to %2 step %1 { // WITH-NEXT: %4 = "test.op"() : () -> i1 -// WITH-NEXT: %5 = "scf.if"(%4) ({ +// WITH-NEXT: %5 = scf.if %4 -> (index) { // WITH-NEXT: scf.yield %3 : index -// WITH-NEXT: }, { +// WITH-NEXT: } else { // WITH-NEXT: scf.yield %arg0 : index -// WITH-NEXT: }) : (i1) -> index +// WITH-NEXT: } // WITH-NEXT: "test.op"(%5) : (index) -> () // WITH-NEXT: } // WITH-NEXT: func.return diff --git a/tests/filecheck/transforms/convert-stencil-to-ll-mlir.mlir b/tests/filecheck/transforms/convert-stencil-to-ll-mlir.mlir index 920fd398b9..5a6bd53c44 100644 --- a/tests/filecheck/transforms/convert-stencil-to-ll-mlir.mlir +++ b/tests/filecheck/transforms/convert-stencil-to-ll-mlir.mlir @@ -622,13 +622,13 @@ func.func @store_result_lowering(%arg0 : f64) { func.func @if_lowering(%arg0_1 : f64, %b0 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>, %b1 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>) attributes {"stencil.program"}{ %101, %102 = stencil.apply(%arg1_1 = %arg0_1 : f64) -> (!stencil.temp<[0,7]x[0,7]x[0,7]xf64>, !stencil.temp<[0,7]x[0,7]x[0,7]xf64>) { %true = "test.op"() : () -> i1 - %103, %104 = "scf.if"(%true) ({ + %103, %104 = scf.if %true -> (!stencil.result, f64) { %105 = stencil.store_result %arg1_1 : !stencil.result scf.yield %105, %arg1_1 : !stencil.result, f64 - }, { + } else { %106 = stencil.store_result : !stencil.result scf.yield %106, %arg1_1 : !stencil.result, f64 - }) : (i1) -> (!stencil.result, f64) + } %107 = stencil.store_result %104 : !stencil.result stencil.return %103, %107 : !stencil.result, !stencil.result } @@ -652,13 +652,13 @@ func.func @if_lowering(%arg0_1 : f64, %b0 : !stencil.field<[0,7]x[0,7]x[0,7]xf64 // CHECK-NEXT: "scf.parallel"(%0, %1, %2, %6, %7, %8, %3, %4, %5) <{"operandSegmentSizes" = array}> ({ // CHECK-NEXT: ^0(%9 : index, %10 : index, %11 : index): // CHECK-NEXT: %true = "test.op"() : () -> i1 -// CHECK-NEXT: %12, %13 = "scf.if"(%true) ({ +// CHECK-NEXT: %12, %13 = scf.if %true -> (f64, f64) { // CHECK-NEXT: memref.store %arg0, %b0_storeview[%9, %10, %11] : memref<7x7x7xf64, strided<[49, 7, 1]>> // CHECK-NEXT: scf.yield %arg0, %arg0 : f64, f64 -// CHECK-NEXT: }, { +// CHECK-NEXT: } else { // CHECK-NEXT: %14 = builtin.unrealized_conversion_cast to f64 // CHECK-NEXT: scf.yield %14, %arg0 : f64, f64 -// CHECK-NEXT: }) : (i1) -> (f64, f64) +// CHECK-NEXT: } // CHECK-NEXT: memref.store %13, %b1_storeview[%9, %10, %11] : memref<7x7x7xf64, strided<[49, 7, 1]>> // CHECK-NEXT: scf.yield // CHECK-NEXT: }) : (index, index, index, index, index, index, index, index, index) -> () diff --git a/tests/filecheck/transforms/cse.mlir b/tests/filecheck/transforms/cse.mlir index ba0c827754..70470fa4c4 100644 --- a/tests/filecheck/transforms/cse.mlir +++ b/tests/filecheck/transforms/cse.mlir @@ -506,74 +506,74 @@ func.func @failing_issue_59135(%arg0_10 : tensor<2x2xi1>, %arg1_7 : f32, %arg2_8 // CHECK-NEXT: } func.func @cse_multiple_regions(%arg0_11 : i1, %arg1_8 : tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) { - %94 = "scf.if"(%arg0_11) ({ + %94 = scf.if %arg0_11 -> (tensor<5xf32>) { %95 = tensor.empty() : tensor<5xf32> scf.yield %95 : tensor<5xf32> - }, { + } else { scf.yield %arg1_8 : tensor<5xf32> - }) : (i1) -> tensor<5xf32> - %96 = "scf.if"(%arg0_11) ({ + } + %96 = scf.if %arg0_11 -> (tensor<5xf32>) { %97 = tensor.empty() : tensor<5xf32> scf.yield %97 : tensor<5xf32> - }, { + } else { scf.yield %arg1_8 : tensor<5xf32> - }) : (i1) -> tensor<5xf32> + } func.return %94, %96 : tensor<5xf32>, tensor<5xf32> } // CHECK: func.func @cse_multiple_regions(%arg0 : i1, %arg1 : tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) { -// CHECK-NEXT: %0 = "scf.if"(%arg0) ({ +// CHECK-NEXT: %0 = scf.if %arg0 -> (tensor<5xf32>) { // CHECK-NEXT: %1 = tensor.empty() : tensor<5xf32> // CHECK-NEXT: scf.yield %1 : tensor<5xf32> -// CHECK-NEXT: }, { +// CHECK-NEXT: } else { // CHECK-NEXT: scf.yield %arg1 : tensor<5xf32> -// CHECK-NEXT: }) : (i1) -> tensor<5xf32> +// CHECK-NEXT: } // CHECK-NEXT: func.return %0, %0 : tensor<5xf32>, tensor<5xf32> // CHECK-NEXT: } // Check that no CSE happens on a recursively side-effecting ops containing side-effects. func.func @no_cse_multiple_regions_side_effect(%arg0_12 : i1, %arg1_9 : memref<5xf32>) -> (memref<5xf32>, memref<5xf32>) { - %90 = "scf.if"(%arg0_12) ({ + %90 = scf.if %arg0_12 -> (memref<5xf32>) { %91 = memref.alloc() : memref<5xf32> scf.yield %91 : memref<5xf32> - }, { + } else { scf.yield %arg1_9 : memref<5xf32> - }) : (i1) -> memref<5xf32> - %92 = "scf.if"(%arg0_12) ({ + } + %92 = scf.if %arg0_12 -> (memref<5xf32>) { %93 = memref.alloc() : memref<5xf32> scf.yield %93 : memref<5xf32> - }, { + } else { scf.yield %arg1_9 : memref<5xf32> - }) : (i1) -> memref<5xf32> + } func.return %90, %92 : memref<5xf32>, memref<5xf32> } // CHECK: func.func @no_cse_multiple_regions_side_effect(%arg0 : i1, %arg1 : memref<5xf32>) -> (memref<5xf32>, memref<5xf32>) { -// CHECK-NEXT: %0 = "scf.if"(%arg0) ({ +// CHECK-NEXT: %0 = scf.if %arg0 -> (memref<5xf32>) { // CHECK-NEXT: %1 = memref.alloc() : memref<5xf32> // CHECK-NEXT: scf.yield %1 : memref<5xf32> -// CHECK-NEXT: }, { +// CHECK-NEXT: } else { // CHECK-NEXT: scf.yield %arg1 : memref<5xf32> -// CHECK-NEXT: }) : (i1) -> memref<5xf32> -// CHECK-NEXT: %2 = "scf.if"(%arg0) ({ +// CHECK-NEXT: } +// CHECK-NEXT: %2 = scf.if %arg0 -> (memref<5xf32>) { // CHECK-NEXT: %3 = memref.alloc() : memref<5xf32> // CHECK-NEXT: scf.yield %3 : memref<5xf32> -// CHECK-NEXT: }, { +// CHECK-NEXT: } else { // CHECK-NEXT: scf.yield %arg1 : memref<5xf32> -// CHECK-NEXT: }) : (i1) -> memref<5xf32> +// CHECK-NEXT: } // CHECK-NEXT: func.return %0, %2 : memref<5xf32>, memref<5xf32> // CHECK-NEXT: } func.func @cse_recursive_effects_success() -> (i32, i32, i32) { %98 = "test.op_with_memread"() : () -> i32 %99 = arith.constant true - %100 = "scf.if"(%99) ({ + %100 = scf.if %99 -> (i32) { %101 = arith.constant 42 : i32 scf.yield %101 : i32 - }, { + } else { %102 = arith.constant 24 : i32 scf.yield %102 : i32 - }) : (i1) -> i32 + } %103 = "test.op_with_memread"() : () -> i32 func.return %98, %103, %100 : i32, i32, i32 } @@ -581,13 +581,13 @@ func.func @no_cse_multiple_regions_side_effect(%arg0_12 : i1, %arg1_9 : memref<5 // CHECK: func.func @cse_recursive_effects_success() -> (i32, i32, i32) { // CHECK-NEXT: %0 = "test.op_with_memread"() : () -> i32 // CHECK-NEXT: %1 = arith.constant true -// CHECK-NEXT: %2 = "scf.if"(%1) ({ +// CHECK-NEXT: %2 = scf.if %1 -> (i32) { // CHECK-NEXT: %3 = arith.constant 42 : i32 // CHECK-NEXT: scf.yield %3 : i32 -// CHECK-NEXT: }, { +// CHECK-NEXT: } else { // CHECK-NEXT: %4 = arith.constant 24 : i32 // CHECK-NEXT: scf.yield %4 : i32 -// CHECK-NEXT: }) : (i1) -> i32 +// CHECK-NEXT: } // CHECK-NEXT: func.return %0, %0, %2 : i32, i32, i32 // CHECK-NEXT: } @@ -595,14 +595,14 @@ func.func @no_cse_multiple_regions_side_effect(%arg0_12 : i1, %arg1_9 : memref<5 func.func @cse_recursive_effects_failure() -> (i32, i32, i32) { %104 = "test.op_with_memread"() : () -> i32 %105 = arith.constant true - %106 = "scf.if"(%105) ({ + %106 = scf.if %105 -> (i32) { "test.op_with_memwrite"() : () -> () %107 = arith.constant 42 : i32 scf.yield %107 : i32 - }, { + } else { %108 = arith.constant 24 : i32 scf.yield %108 : i32 - }) : (i1) -> i32 + } %109 = "test.op_with_memread"() : () -> i32 func.return %104, %109, %106 : i32, i32, i32 } @@ -610,14 +610,14 @@ func.func @cse_recursive_effects_failure() -> (i32, i32, i32) { // CHECK: func.func @cse_recursive_effects_failure() -> (i32, i32, i32) { // CHECK-NEXT: %0 = "test.op_with_memread"() : () -> i32 // CHECK-NEXT: %1 = arith.constant true -// CHECK-NEXT: %2 = "scf.if"(%1) ({ +// CHECK-NEXT: %2 = scf.if %1 -> (i32) { // CHECK-NEXT: "test.op_with_memwrite"() : () -> () // CHECK-NEXT: %3 = arith.constant 42 : i32 // CHECK-NEXT: scf.yield %3 : i32 -// CHECK-NEXT: }, { +// CHECK-NEXT: } else { // CHECK-NEXT: %4 = arith.constant 24 : i32 // CHECK-NEXT: scf.yield %4 : i32 -// CHECK-NEXT: }) : (i1) -> i32 +// CHECK-NEXT: } // CHECK-NEXT: %5 = "test.op_with_memread"() : () -> i32 // CHECK-NEXT: func.return %0, %5, %2 : i32, i32, i32 // CHECK-NEXT: } diff --git a/tests/filecheck/transforms/distribute-stencil.mlir b/tests/filecheck/transforms/distribute-stencil.mlir index 9db7ccd7f5..15d270eaae 100644 --- a/tests/filecheck/transforms/distribute-stencil.mlir +++ b/tests/filecheck/transforms/distribute-stencil.mlir @@ -181,13 +181,13 @@ func.func @store_result_lowering(%arg0 : f64) { func.func @if_lowering(%arg0_1 : f64, %b0 : !stencil.field<[0,8]x[0,8]x[0,8]xf64>, %b1 : !stencil.field<[0,8]x[0,8]x[0,8]xf64>) attributes {"stencil.program"}{ %101, %102 = stencil.apply(%arg1_1 = %arg0_1 : f64) -> (!stencil.temp<[0,8]x[0,8]x[0,8]xf64>, !stencil.temp<[0,8]x[0,8]x[0,8]xf64>) { %true = "test.op"() : () -> i1 - %103, %104 = "scf.if"(%true) ({ + %103, %104 = scf.if %true -> (!stencil.result, f64) { %105 = stencil.store_result %arg1_1 : !stencil.result scf.yield %105, %arg1_1 : !stencil.result, f64 - }, { + } else { %106 = stencil.store_result : !stencil.result scf.yield %106, %arg1_1 : !stencil.result, f64 - }) : (i1) -> (!stencil.result, f64) + } %107 = stencil.store_result %104 : !stencil.result stencil.return %103, %107 : !stencil.result, !stencil.result } @@ -199,13 +199,13 @@ func.func @if_lowering(%arg0_1 : f64, %b0 : !stencil.field<[0,8]x[0,8]x[0,8]xf64 // SHAPE: func.func @if_lowering(%arg0 : f64, %b0 : !stencil.field<[0,8]x[0,8]x[0,8]xf64>, %b1 : !stencil.field<[0,8]x[0,8]x[0,8]xf64>) attributes {"stencil.program"}{ // SHAPE-NEXT: %0, %1 = stencil.apply(%arg1 = %arg0 : f64) -> (!stencil.temp<[0,8]x[0,8]x[0,8]xf64>, !stencil.temp<[0,8]x[0,8]x[0,8]xf64>) { // SHAPE-NEXT: %true = "test.op"() : () -> i1 -// SHAPE-NEXT: %2, %3 = "scf.if"(%true) ({ +// SHAPE-NEXT: %2, %3 = scf.if %true -> (!stencil.result, f64) { // SHAPE-NEXT: %4 = stencil.store_result %arg1 : !stencil.result // SHAPE-NEXT: scf.yield %4, %arg1 : !stencil.result, f64 -// SHAPE-NEXT: }, { +// SHAPE-NEXT: } else { // SHAPE-NEXT: %5 = stencil.store_result : !stencil.result // SHAPE-NEXT: scf.yield %5, %arg1 : !stencil.result, f64 -// SHAPE-NEXT: }) : (i1) -> (!stencil.result, f64) +// SHAPE-NEXT: } // SHAPE-NEXT: %6 = stencil.store_result %3 : !stencil.result // SHAPE-NEXT: stencil.return %2, %6 : !stencil.result, !stencil.result // SHAPE-NEXT: } diff --git a/tests/filecheck/transforms/function-constant-pinning.mlir b/tests/filecheck/transforms/function-constant-pinning.mlir index 79f4cfabe1..cc0872932f 100644 --- a/tests/filecheck/transforms/function-constant-pinning.mlir +++ b/tests/filecheck/transforms/function-constant-pinning.mlir @@ -13,13 +13,13 @@ func.func @basic() -> i32 { // compare the value to the constant we want to specialize for // CHECK-NEXT: %0 = arith.constant 0 : i64 // CHECK-NEXT: %1 = arith.cmpi eq, %v, %0 : i32 -// CHECK-NEXT: %2 = "scf.if"(%1) ({ +// CHECK-NEXT: %2 = scf.if %1 -> (i32) { // if they are equal, branch to specialized function // CHECK-NEXT: %3 = func.call @basic_pinned() : () -> i32 // CHECK-NEXT: scf.yield %3 : i32 -// CHECK-NEXT: }, { +// CHECK-NEXT: } else { // CHECK-NEXT: scf.yield %v : i32 -// CHECK-NEXT: }) : (i1) -> i32 +// CHECK-NEXT: } // CHECK-NEXT: func.return %2 : i32 // CHECK-NEXT: } // specialized function here @@ -35,12 +35,11 @@ func.func @control_flow() { %cond = "test.op"() {"pin_to_constants"= [true]} : () -> i1 - "scf.if"(%cond) ({ + scf.if %cond { "test.op"() {"inside_if"} : () -> () - scf.yield - }, { - scf.yield - }) : (i1) -> () + } else { + scf.yield + } "test.op"() {"after_op"} : () -> () @@ -52,21 +51,17 @@ func.func @control_flow() { // CHECK-NEXT: %cond = "test.op"() : () -> i1 // CHECK-NEXT: %0 = arith.constant true // CHECK-NEXT: %1 = arith.cmpi eq, %cond, %0 : i1 -// CHECK-NEXT: "scf.if"(%1) ({ +// CHECK-NEXT: scf.if %1 { // CHECK-NEXT: func.call @control_flow_pinned() : () -> () -// CHECK-NEXT: scf.yield -// CHECK-NEXT: }, { +// CHECK-NEXT: } else { // inline the rest of the function inside the else statement of the specialization block // (there is no early return in MLIR) -// CHECK-NEXT: "scf.if"(%cond) ({ +// CHECK-NEXT: scf.if %cond { // CHECK-NEXT: "test.op"() {"inside_if"} : () -> () -// CHECK-NEXT: scf.yield -// CHECK-NEXT: }, { -// CHECK-NEXT: scf.yield -// CHECK-NEXT: }) : (i1) -> () +// CHECK-NEXT: } else { +// CHECK-NEXT: } // CHECK-NEXT: "test.op"() {"after_op"} : () -> () -// CHECK-NEXT: scf.yield -// CHECK-NEXT: }) : (i1) -> () +// CHECK-NEXT: } // CHECK-NEXT: func.return // CHECK-NEXT: } // specialized function does not contain operations that happen before the specialized function @@ -74,12 +69,10 @@ func.func @control_flow() { // CHECK-NEXT: func.func @control_flow_pinned() { // CHECK-NEXT: %cond = arith.constant true // this scf.if can be constant folded by MLIR later on (not done as part of this pass) -// CHECK-NEXT: "scf.if"(%cond) ({ +// CHECK-NEXT: scf.if %cond { // CHECK-NEXT: "test.op"() {"inside_if"} : () -> () -// CHECK-NEXT: scf.yield -// CHECK-NEXT: }, { -// CHECK-NEXT: scf.yield -// CHECK-NEXT: }) : (i1) -> () +// CHECK-NEXT: } else { +// CHECK-NEXT: } // CHECK-NEXT: "test.op"() {"after_op"} : () -> () // CHECK-NEXT: func.return // CHECK-NEXT: } @@ -98,15 +91,15 @@ func.func @function_args(%arg0: memref<100xf32>) -> i32 { // CHECK-NEXT: %v = "test.op"() : () -> i32 // CHECK-NEXT: %0 = arith.constant 0 : i64 // CHECK-NEXT: %1 = arith.cmpi eq, %v, %0 : i32 -// CHECK-NEXT: %2 = "scf.if"(%1) ({ +// CHECK-NEXT: %2 = scf.if %1 -> (i32) { // make sure that we forward function args to the specialized function // and weave return values through the generated if/else // CHECK-NEXT: %3 = func.call @function_args_pinned(%arg0) : (memref<100xf32>) -> i32 // CHECK-NEXT: scf.yield %3 : i32 -// CHECK-NEXT: }, { +// CHECK-NEXT: } else { // CHECK-NEXT: "test.op"(%v, %arg0) : (i32, memref<100xf32>) -> () // CHECK-NEXT: scf.yield %v : i32 -// CHECK-NEXT: }) : (i1) -> i32 +// CHECK-NEXT: } // CHECK-NEXT: func.return %2 : i32 // CHECK-NEXT: } // CHECK-NEXT: func.func @function_args_pinned(%arg0 : memref<100xf32>) -> i32 { @@ -122,12 +115,11 @@ func.func @function_args(%arg0: memref<100xf32>) -> i32 { func.func @control_flow_and_function_args(%arg: i32) -> i32 { %cond = "test.op"() {"pin_to_constants"= [true]} : () -> i1 - "scf.if"(%cond) ({ + scf.if %cond { "test.op"() {"inside_if"} : () -> () + } else { scf.yield - }, { - scf.yield - }) : (i1) -> () + } %rval = "test.op"(%arg) {"after_op"} : (i32) -> i32 @@ -138,29 +130,25 @@ func.func @control_flow_and_function_args(%arg: i32) -> i32 { // CHECK-NEXT: %cond = "test.op"() : () -> i1 // CHECK-NEXT: %0 = arith.constant true // CHECK-NEXT: %1 = arith.cmpi eq, %cond, %0 : i1 -// CHECK-NEXT: %2 = "scf.if"(%1) ({ +// CHECK-NEXT: %2 = scf.if %1 -> (i32) { // CHECK-NEXT: %3 = func.call @control_flow_and_function_args_pinned(%arg) : (i32) -> i32 // CHECK-NEXT: scf.yield %3 : i32 -// CHECK-NEXT: }, { -// CHECK-NEXT: "scf.if"(%cond) ({ +// CHECK-NEXT: } else { +// CHECK-NEXT: scf.if %cond { // CHECK-NEXT: "test.op"() {"inside_if"} : () -> () -// CHECK-NEXT: scf.yield -// CHECK-NEXT: }, { -// CHECK-NEXT: scf.yield -// CHECK-NEXT: }) : (i1) -> () +// CHECK-NEXT: } else { +// CHECK-NEXT: } // CHECK-NEXT: %rval = "test.op"(%arg) {"after_op"} : (i32) -> i32 // CHECK-NEXT: scf.yield %rval : i32 -// CHECK-NEXT: }) : (i1) -> i32 +// CHECK-NEXT: } // CHECK-NEXT: func.return %2 : i32 // CHECK-NEXT: } // CHECK-NEXT: func.func @control_flow_and_function_args_pinned(%arg : i32) -> i32 { // CHECK-NEXT: %cond = arith.constant true -// CHECK-NEXT: "scf.if"(%cond) ({ +// CHECK-NEXT: scf.if %cond { // CHECK-NEXT: "test.op"() {"inside_if"} : () -> () -// CHECK-NEXT: scf.yield -// CHECK-NEXT: }, { -// CHECK-NEXT: scf.yield -// CHECK-NEXT: }) : (i1) -> () +// CHECK-NEXT: } else { +// CHECK-NEXT: } // CHECK-NEXT: %rval = "test.op"(%arg) {"after_op"} : (i32) -> i32 // CHECK-NEXT: func.return %rval : i32 // CHECK-NEXT: } @@ -178,20 +166,20 @@ func.func @specialize_multi_case() -> i32 { // CHECK-NEXT: %v = "test.op"() : () -> i32 // CHECK-NEXT: %0 = arith.constant 0 : i64 // CHECK-NEXT: %1 = arith.cmpi eq, %v, %0 : i32 -// CHECK-NEXT: %2 = "scf.if"(%1) ({ +// CHECK-NEXT: %2 = scf.if %1 -> (i32) { // CHECK-NEXT: %3 = func.call @specialize_multi_case_pinned_1() : () -> i32 // CHECK-NEXT: scf.yield %3 : i32 -// CHECK-NEXT: }, { +// CHECK-NEXT: } else { // CHECK-NEXT: %4 = arith.constant 1 : i64 // CHECK-NEXT: %5 = arith.cmpi eq, %v, %4 : i32 -// CHECK-NEXT: %6 = "scf.if"(%5) ({ +// CHECK-NEXT: %6 = scf.if %5 -> (i32) { // CHECK-NEXT: %7 = func.call @specialize_multi_case_pinned() : () -> i32 // CHECK-NEXT: scf.yield %7 : i32 -// CHECK-NEXT: }, { +// CHECK-NEXT: } else { // CHECK-NEXT: scf.yield %v : i32 -// CHECK-NEXT: }) : (i1) -> i32 +// CHECK-NEXT: } // CHECK-NEXT: scf.yield %6 : i32 -// CHECK-NEXT: }) : (i1) -> i32 +// CHECK-NEXT: } // CHECK-NEXT: func.return %2 : i32 // CHECK-NEXT: } // CHECK-NEXT: func.func @specialize_multi_case_pinned_1() -> i32 { @@ -200,12 +188,12 @@ func.func @specialize_multi_case() -> i32 { // CHECK-NEXT: %v = arith.constant 0 : i64 // CHECK-NEXT: %0 = arith.constant 1 : i64 // CHECK-NEXT: %1 = arith.cmpi eq, %v, %0 : i32 -// CHECK-NEXT: %2 = "scf.if"(%1) ({ +// CHECK-NEXT: %2 = scf.if %1 -> (i32) { // CHECK-NEXT: %3 = func.call @specialize_multi_case_pinned() : () -> i32 // CHECK-NEXT: scf.yield %3 : i32 -// CHECK-NEXT: }, { +// CHECK-NEXT: } else { // CHECK-NEXT: scf.yield %v : i32 -// CHECK-NEXT: }) : (i1) -> i32 +// CHECK-NEXT: } // CHECK-NEXT: func.return %2 : i32 // CHECK-NEXT: } // CHECK-NEXT: func.func @specialize_multi_case_pinned() -> i32 { diff --git a/tests/filecheck/transforms/stencil-bufferize.mlir b/tests/filecheck/transforms/stencil-bufferize.mlir index 1b3fb35e02..cd12d27b57 100644 --- a/tests/filecheck/transforms/stencil-bufferize.mlir +++ b/tests/filecheck/transforms/stencil-bufferize.mlir @@ -344,13 +344,13 @@ func.func @stencil_init_index_offset(%0 : !stencil.field<[0,64]x[0,64]x[0,64]xin func.func @if_lowering(%arg0 : f64, %b0 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>, %b1 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>) attributes {"stencil.program"}{ %0, %1 = stencil.apply(%arg1 = %arg0 : f64) -> (!stencil.temp<[0,7]x[0,7]x[0,7]xf64>, !stencil.temp<[0,7]x[0,7]x[0,7]xf64>) { %true = "test.pureop"() : () -> i1 - %2, %3 = "scf.if"(%true) ({ + %2, %3 = scf.if %true -> (!stencil.result, f64) { %4 = stencil.store_result %arg1 : !stencil.result scf.yield %4, %arg1 : !stencil.result, f64 - }, { + } else { %5 = stencil.store_result : !stencil.result scf.yield %5, %arg1 : !stencil.result, f64 - }) : (i1) -> (!stencil.result, f64) + } %6 = stencil.store_result %3 : !stencil.result stencil.return %2, %6 : !stencil.result, !stencil.result } @@ -362,13 +362,13 @@ func.func @if_lowering(%arg0 : f64, %b0 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>, // CHECK: func.func @if_lowering(%arg0 : f64, %b0 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>, %b1 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>) attributes {"stencil.program"}{ // CHECK-NEXT: stencil.apply(%arg1 = %arg0 : f64) outs (%b0 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>, %b1 : !stencil.field<[0,7]x[0,7]x[0,7]xf64>) { // CHECK-NEXT: %true = "test.pureop"() : () -> i1 -// CHECK-NEXT: %0, %1 = "scf.if"(%true) ({ +// CHECK-NEXT: %0, %1 = scf.if %true -> (!stencil.result, f64) { // CHECK-NEXT: %2 = stencil.store_result %arg1 : !stencil.result // CHECK-NEXT: scf.yield %2, %arg1 : !stencil.result, f64 -// CHECK-NEXT: }, { +// CHECK-NEXT: } else { // CHECK-NEXT: %3 = stencil.store_result : !stencil.result // CHECK-NEXT: scf.yield %3, %arg1 : !stencil.result, f64 -// CHECK-NEXT: }) : (i1) -> (!stencil.result, f64) +// CHECK-NEXT: } // CHECK-NEXT: %4 = stencil.store_result %1 : !stencil.result // CHECK-NEXT: stencil.return %0, %4 : !stencil.result, !stencil.result // CHECK-NEXT: } to <[0, 0, 0], [7, 7, 7]> diff --git a/xdsl/dialects/scf.py b/xdsl/dialects/scf.py index f947b586d7..96df4d24aa 100644 --- a/xdsl/dialects/scf.py +++ b/xdsl/dialects/scf.py @@ -170,7 +170,7 @@ class If(IRDLOperation): output: VarOpResult = var_result_def(AnyAttr()) cond: Operand = operand_def(IntegerType(1)) - true_region: Region = region_def() + true_region: Region = region_def("single_block") # TODO this should be optional under certain conditions false_region: Region = region_def() @@ -188,6 +188,7 @@ def __init__( return_types: Sequence[Attribute], true_region: Region | Sequence[Block] | Sequence[Operation], false_region: Region | Sequence[Block] | Sequence[Operation] | None = None, + attr_dict: dict[str, Attribute] | None = None, ): if false_region is None: false_region = Region() @@ -196,8 +197,76 @@ def __init__( operands=[cond], result_types=[return_types], regions=[true_region, false_region], + attributes=attr_dict, ) + @staticmethod + def parse_region_with_yield(parser: Parser) -> Region: + region = parser.parse_region() + block = region.blocks.last + if block is None: + block = Block() + region.add_block(block) + last_op = block.last_op + if last_op is not None and last_op.has_trait(IsTerminator): + return region + + block.add_op(Yield()) + + return region + + @classmethod + def parse(cls, parser: Parser) -> Self: + cond = parser.parse_operand() + return_types = [] + if parser.parse_optional_punctuation("->"): + return_types = parser.parse_comma_separated_list( + parser.Delimiter.PAREN, parser.parse_type + ) + else: + return_types = [] + + then_region = cls.parse_region_with_yield(parser) + + else_region = ( + cls.parse_region_with_yield(parser) + if parser.parse_optional_keyword("else") + else Region() + ) + + attr_dict = parser.parse_optional_attr_dict() + + return cls(cond, return_types, then_region, else_region, attr_dict) + + def print(self, printer: Printer): + printer.print_string(" ") + printer.print_operand(self.cond) + + print_block_terminators = False + if bool(self.output): + printer.print_string(" -> (") + printer.print_list(self.output.types, printer.print_attribute) + printer.print_string(")") + print_block_terminators = True + + printer.print_string(" ") + printer.print_region( + self.true_region, + print_entry_block_args=False, + print_block_terminators=print_block_terminators, + ) + + if bool(self.false_region.blocks): + printer.print_string(" else ") + printer.print_region( + self.false_region, + print_entry_block_args=False, + print_block_terminators=print_block_terminators, + ) + + if bool(self.attributes.keys): + printer.print_attr_dict(self.attributes) + class ForOpHasCanonicalizationPatternsTrait(HasCanonicalizationPatternsTrait): @classmethod