diff --git a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp index d7d1865bc56ba..1902757e83bf3 100644 --- a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp +++ b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp @@ -87,13 +87,52 @@ struct DoLoopConversion : public OpRewritePattern { return success(); } }; + +void copyBlockAndTransformResult(PatternRewriter &rewriter, Block &srcBlock, + Block &dstBlock) { + Operation *srcTerminator = srcBlock.getTerminator(); + auto resultOp = cast(srcTerminator); + + dstBlock.getOperations().splice(dstBlock.begin(), srcBlock.getOperations(), + srcBlock.begin(), std::prev(srcBlock.end())); + + if (!resultOp->getOperands().empty()) { + rewriter.setInsertionPointToEnd(&dstBlock); + scf::YieldOp::create(rewriter, resultOp->getLoc(), resultOp->getOperands()); + } + + rewriter.eraseOp(srcTerminator); +} + +struct IfConversion : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(fir::IfOp ifOp, + PatternRewriter &rewriter) const override { + bool hasElse = !ifOp.getElseRegion().empty(); + auto scfIfOp = + scf::IfOp::create(rewriter, ifOp.getLoc(), ifOp.getResultTypes(), + ifOp.getCondition(), hasElse); + + copyBlockAndTransformResult(rewriter, ifOp.getThenRegion().front(), + scfIfOp.getThenRegion().front()); + + if (hasElse) { + copyBlockAndTransformResult(rewriter, ifOp.getElseRegion().front(), + scfIfOp.getElseRegion().front()); + } + + scfIfOp->setAttrs(ifOp->getAttrs()); + rewriter.replaceOp(ifOp, scfIfOp); + return success(); + } +}; } // namespace void FIRToSCFPass::runOnOperation() { RewritePatternSet patterns(&getContext()); - patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); ConversionTarget target(getContext()); - target.addIllegalOp(); + target.addIllegalOp(); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/flang/test/Fir/FirToSCF/if.fir b/flang/test/Fir/FirToSCF/if.fir new file mode 100644 index 0000000000000..03be264c4cdf5 --- /dev/null +++ b/flang/test/Fir/FirToSCF/if.fir @@ -0,0 +1,57 @@ +// RUN: fir-opt %s --fir-to-scf | FileCheck %s + +// CHECK-LABEL: func.func @test_only( +// CHECK-SAME: %[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32) { +// CHECK: scf.if %[[ARG0]] { +// CHECK: %[[VAL_1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : i32 +// CHECK: } +// CHECK: return +// CHECK: } +func.func @test_only(%arg0 : i1, %arg1 : i32) { + fir.if %arg0 { + %0 = arith.addi %arg1, %arg1 : i32 + } + return +} + +// CHECK-LABEL: func.func @test_else() { +// CHECK: %[[VAL_1:.*]] = arith.constant false +// CHECK: %[[VAL_2:.*]] = arith.constant 2 : i32 +// CHECK: scf.if %[[VAL_1]] { +// CHECK: %[[VAL_3:.*]] = arith.constant 3 : i32 +// CHECK: } else { +// CHECK: %[[VAL_3:.*]] = arith.constant 3 : i32 +// CHECK: } +// CHECK: return +// CHECK: } +func.func @test_else() { + %false = arith.constant false + %1 = arith.constant 2 : i32 + fir.if %false { + %2 = arith.constant 3 : i32 + } else { + %3 = arith.constant 3 : i32 + } + return +} + +// CHECK-LABEL: func.func @test_two_result() { +// CHECK: %[[VAL_1:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK: %[[VAL_2:.*]] = arith.constant false +// CHECK: %[[RES:[0-9]+]]:2 = scf.if %[[VAL_2]] -> (f32, f32) { +// CHECK: scf.yield %[[VAL_1]], %[[VAL_1]] : f32, f32 +// CHECK: } else { +// CHECK: scf.yield %[[VAL_1]], %[[VAL_1]] : f32, f32 +// CHECK: } +// CHECK: return +// CHECK: } +func.func @test_two_result() { + %1 = arith.constant 2.0 : f32 + %cmp = arith.constant false + %x, %y = fir.if %cmp -> (f32, f32) { + fir.result %1, %1 : f32, f32 + } else { + fir.result %1, %1 : f32, f32 + } + return +}