Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/cudaq/Optimizer/Dialect/CC/CCOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ def cc_IfOp : CCOp<"if",

let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
let hasCanonicalizer = 1;

let extraClassDeclaration = [{
using RegionBuilderFn = llvm::function_ref<void(mlir::OpBuilder &,
Expand Down
52 changes: 52 additions & 0 deletions lib/Optimizer/Dialect/CC/CCOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1937,6 +1937,58 @@ LogicalResult cudaq::cc::verifyConvergentLinearTypesInRegions(Operation *op) {
return success();
}

namespace {
struct KillRegionIfConstant : public OpRewritePattern<cudaq::cc::IfOp> {
using Base = OpRewritePattern<cudaq::cc::IfOp>;
using Base::Base;

// This rewrite will determine if the condition is constant. If it is, then it
// will elide the true or false region completely, depending on the constant's
// value.
LogicalResult matchAndRewrite(cudaq::cc::IfOp ifOp,
PatternRewriter &rewriter) const override {
auto cond = ifOp.getCondition();
if (!ifOp.getResults().empty())
return failure();
auto con = cond.getDefiningOp<arith::ConstantIntOp>();
if (!con)
return failure();
auto val = con.value();
auto loc = ifOp.getLoc();
auto truth = rewriter.create<arith::ConstantIntOp>(loc, 1, 1);
Region *newRegion = nullptr;
if (val) {
// The else block, if any, is dead.
if (ifOp.getElseRegion().empty())
return failure();
newRegion = &ifOp.getThenRegion();
} else {
// The then block is dead.
newRegion = &ifOp.getElseRegion();
if (newRegion->empty()) {
// If there was no else, then build an empty dummy Region.
OpBuilder::InsertionGuard guard(rewriter);
Block *block = new Block();
rewriter.setInsertionPointToEnd(block);
rewriter.create<cudaq::cc::ContinueOp>(loc);
newRegion->push_back(block);
}
}
rewriter.replaceOpWithNewOp<cudaq::cc::IfOp>(
ifOp, ifOp.getResultTypes(), truth,
[&](OpBuilder &, Location, Region &region) {
region.takeBody(*newRegion);
});
return success();
}
};
} // namespace

void cudaq::cc::IfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<KillRegionIfConstant>(context);
}

//===----------------------------------------------------------------------===//
// CreateLambdaOp
//===----------------------------------------------------------------------===//
Expand Down
69 changes: 69 additions & 0 deletions test/Quake/canonical-3.qke
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,72 @@ func.func @__nvqpp__mlirgen__super() attributes {"cudaq-entrypoint", "cudaq-kern
// CHECK: return
// CHECK: }

func.func private @if_test_call()

func.func @if_test_1() {
%0 = arith.constant true
cc.if (%0) {
func.call @if_test_call() : () -> ()
cc.continue
}
return
}

func.func @if_test_2() {
%0 = arith.constant true
cc.if (%0) {
func.call @if_test_call() : () -> ()
cc.continue
} else {
cc.continue
}
return
}

func.func @if_test_3() {
%0 = arith.constant false
cc.if (%0) {
cc.continue
} else {
func.call @if_test_call() : () -> ()
cc.continue
}
return
}

func.func @if_test_4() {
%0 = arith.constant false
cc.if (%0) {
func.call @if_test_call() : () -> ()
cc.continue
}
return
}

// CHECK-LABEL: func.func @if_test_1() {
// CHECK: %[[VAL_0:.*]] = arith.constant true
// CHECK: cc.if(%[[VAL_0]]) {
// CHECK: func.call @if_test_call() : () -> ()
// CHECK: }
// CHECK: return
// CHECK: }

// CHECK-LABEL: func.func @if_test_2() {
// CHECK: %[[VAL_0:.*]] = arith.constant true
// CHECK: cc.if(%[[VAL_0]]) {
// CHECK: func.call @if_test_call() : () -> ()
// CHECK: }
// CHECK: return
// CHECK: }

// CHECK-LABEL: func.func @if_test_3() {
// CHECK: %[[VAL_0:.*]] = arith.constant true
// CHECK: cc.if(%[[VAL_0]]) {
// CHECK: func.call @if_test_call() : () -> ()
// CHECK: }
// CHECK: return
// CHECK: }

// CHECK-LABEL: func.func @if_test_4() {
// CHECK: return
// CHECK: }
Loading