Skip to content

Commit

Permalink
[DC] Add + re-enable canonicalization patterns (#7952)
Browse files Browse the repository at this point in the history
Various canonicalization patterns were implemented as fold methods. However, these had been turned off a while ago, due to an MLIR bug.

This PR moves the fold methods to actual canonicalization patterns, adds a new one (join on branch), and re-enables previously disabled tests.

Co-authored-by: Morten Borup Petersen <mpetersen@microsoft.com>
  • Loading branch information
mortbopet and Morten Borup Petersen authored Dec 9, 2024
1 parent d2a3117 commit 6d894b9
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 42 deletions.
1 change: 1 addition & 0 deletions include/circt/Dialect/DC/DCOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def JoinOp : DCOp<"join", [Commutative]> {

let assemblyFormat = "$tokens attr-dict";
let hasFolder = 1;
let hasCanonicalizer = 1;
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins "mlir::ValueRange":$ins), [{
Expand Down
135 changes: 107 additions & 28 deletions lib/Dialect/DC/DCOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,46 +38,125 @@ OpFoldResult JoinOp::fold(FoldAdaptor adaptor) {
if (auto tokens = getTokens(); tokens.size() == 1)
return tokens.front();

// These folders are disabled to work around MLIR bugs when changing
// the number of operands. https://github.com/llvm/llvm-project/issues/64280
return {};
}

struct JoinOnBranchPattern : public OpRewritePattern<JoinOp> {
using OpRewritePattern<JoinOp>::OpRewritePattern;
LogicalResult matchAndRewrite(JoinOp op,
PatternRewriter &rewriter) const override {

// Remove operands which originate from a dc.source op (redundant).
auto *op = getOperation();
for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
if (auto source = operand.get().getDefiningOp<dc::SourceOp>()) {
op->eraseOperand(operand.getOperandNumber());
return getOutput();
struct BranchOperandInfo {
// Unique operands from the branch op, in case we have the same operand
// from the branch op multiple times.
SetVector<Value> uniqueOperands;
// Indices which the operands are at in the join op.
BitVector indices;
};

DenseMap<BranchOp, BranchOperandInfo> branchOperands;
for (auto &opOperand : op->getOpOperands()) {
auto branch = opOperand.get().getDefiningOp<BranchOp>();
if (!branch)
continue;

BranchOperandInfo &info = branchOperands[branch];
info.uniqueOperands.insert(opOperand.get());
info.indices.resize(op->getNumOperands());
info.indices.set(opOperand.getOperandNumber());
}

if (branchOperands.empty())
return failure();

// Do we have both operands from any given branch op?
for (auto &it : branchOperands) {
auto branch = it.first;
auto &operandInfo = it.second;
if (operandInfo.uniqueOperands.size() != 2) {
// We don't have both operands from the branch op.
continue;
}

// We have both operands from the branch op. Replace the join op with the
// branch op's data operand.

// Unpack the !dc.value<i1> input to the branch op
auto unpacked =
rewriter.create<UnpackOp>(op.getLoc(), branch.getCondition());
rewriter.modifyOpInPlace(op, [&]() {
op->eraseOperands(operandInfo.indices);
op.getTokensMutable().append({unpacked.getToken()});
});

// Only attempt a single branch at a time - else we'd have to maintain
// OpOperand indices during the loop... too complicated, let recursive
// pattern application handle this.
return success();
}

return failure();
}
};
struct StaggeredJoinCanonicalization : public OpRewritePattern<JoinOp> {
using OpRewritePattern<JoinOp>::OpRewritePattern;
LogicalResult matchAndRewrite(JoinOp op,
PatternRewriter &rewriter) const override {
for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
auto otherJoin = operand.get().getDefiningOp<dc::JoinOp>();
if (!otherJoin) {
// Operand does not originate from a join so it's a valid join input.
continue;
}

// Remove duplicate operands.
llvm::DenseSet<Value> uniqueOperands;
for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
if (!uniqueOperands.insert(operand.get()).second) {
op->eraseOperand(operand.getOperandNumber());
return getOutput();
// Operand originates from a join. Erase the current join operand and
// add all of the otherJoin op's inputs to this join.
// DCE will take care of otherJoin in case it's no longer used.
rewriter.modifyOpInPlace(op, [&]() {
op.getTokensMutable().erase(operand.getOperandNumber());
op.getTokensMutable().append(otherJoin.getTokens());
});
return success();
}
return failure();
}
};

// Canonicalization staggered joins where the sink join contains inputs also
// found in the source join.
for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
auto otherJoin = operand.get().getDefiningOp<dc::JoinOp>();
if (!otherJoin) {
// Operand does not originate from a join so it's a valid join input.
continue;
struct RemoveJoinOnSourcePattern : public OpRewritePattern<JoinOp> {
using OpRewritePattern<JoinOp>::OpRewritePattern;
LogicalResult matchAndRewrite(JoinOp op,
PatternRewriter &rewriter) const override {
for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
if (auto source = operand.get().getDefiningOp<dc::SourceOp>()) {
rewriter.modifyOpInPlace(
op, [&]() { op->eraseOperand(operand.getOperandNumber()); });
return success();
}
}
return failure();
}
};

// Operand originates from a join. Erase the current join operand and add
// all of the otherJoin op's inputs to this join.
// DCE will take care of otherJoin in case it's no longer used.
op->eraseOperand(operand.getOperandNumber());
op->insertOperands(getNumOperands(), otherJoin.getTokens());
return getOutput();
struct RemoveDuplicateJoinOperandsPattern : public OpRewritePattern<JoinOp> {
using OpRewritePattern<JoinOp>::OpRewritePattern;
LogicalResult matchAndRewrite(JoinOp op,
PatternRewriter &rewriter) const override {
llvm::DenseSet<Value> uniqueOperands;
for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) {
if (!uniqueOperands.insert(operand.get()).second) {
rewriter.modifyOpInPlace(
op, [&]() { op->eraseOperand(operand.getOperandNumber()); });
return success();
}
}
return failure();
}
};

return {};
void JoinOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<RemoveDuplicateJoinOperandsPattern, RemoveJoinOnSourcePattern,
StaggeredJoinCanonicalization, JoinOnBranchPattern>(context);
}

// =============================================================================
Expand Down
40 changes: 26 additions & 14 deletions test/Dialect/DC/canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@ func.func @staggeredJoin1(%a: !dc.token, %b : !dc.token) -> (!dc.token) {
return %1 : !dc.token
}

// TODO: For some reason, the canonicalizer no longer combines the two joins. Investigate.
// CHECK-LABEL: func.func @staggeredJoin2(
// CHECK-SAME: %[[VAL_0:.*]]: !dc.token, %[[VAL_1:.*]]: !dc.token, %[[VAL_2:.*]]: !dc.token, %[[VAL_3:.*]]: !dc.token) -> !dc.token {
// CHECKx: %[[VAL_4:.*]] = dc.join %[[VAL_2]], %[[VAL_3]], %[[VAL_0]], %[[VAL_1]]
// CHECKx: return %[[VAL_4]] : !dc.token
// CHECKx: }
// CHECK: %[[VAL_4:.*]] = dc.join %[[VAL_2]], %[[VAL_3]], %[[VAL_0]], %[[VAL_1]]
// CHECK: return %[[VAL_4]] : !dc.token
// CHECK: }
func.func @staggeredJoin2(%a: !dc.token, %b : !dc.token, %c : !dc.token, %d : !dc.token) -> (!dc.token) {
%0 = dc.join %a, %b
%1 = dc.join %c, %0, %d
Expand Down Expand Up @@ -102,13 +101,11 @@ func.func @forkToFork2(%a: !dc.token) -> (!dc.token, !dc.token, !dc.token) {
return %0, %2, %3 : !dc.token, !dc.token, !dc.token
}

// TODO: For some reason, the canonicalizer no longer simplifies this redundant
// triangle pattern. Investigate.
// CHECK-LABEL: func.func @merge(
// CHECK-SAME: %[[VAL_0:.*]]: !dc.value<i1>) -> !dc.token {
// CHECKx: %[[VAL_1:.*]], %[[VAL_2:.*]] = dc.unpack %[[VAL_0]] : !dc.value<i1>
// CHECKx: return %[[VAL_1]] : !dc.token
// CHECKx: }
// CHECK: %[[VAL_1:.*]], %[[VAL_2:.*]] = dc.unpack %[[VAL_0]] : !dc.value<i1>
// CHECK: return %[[VAL_1]] : !dc.token
// CHECK: }
func.func @merge(%sel : !dc.value<i1>) -> (!dc.token) {
// Canonicalize away a merge that is fed by a branch with the same select
// input.
Expand All @@ -117,20 +114,35 @@ func.func @merge(%sel : !dc.value<i1>) -> (!dc.token) {
return %0 : !dc.token
}

// TODO: For some reason, the canonicalizer no longer removes the source->join
// pattern. Investigate.
// CHECK-LABEL: func.func @joinOnSource(
// CHECK-SAME: %[[VAL_0:.*]]: !dc.token,
// CHECK-SAME: %[[VAL_1:.*]]: !dc.token) -> !dc.token {
// CHECKx: %[[VAL_2:.*]] = dc.join %[[VAL_0]], %[[VAL_1]]
// CHECKx: return %[[VAL_2]] : !dc.token
// CHECKx: }
// CHECK: %[[VAL_2:.*]] = dc.join %[[VAL_0]], %[[VAL_1]]
// CHECK: return %[[VAL_2]] : !dc.token
// CHECK: }
func.func @joinOnSource(%a : !dc.token, %b : !dc.token) -> (!dc.token) {
%0 = dc.source
%out = dc.join %a, %0, %b
return %out : !dc.token
}


// Join on branch, where all branch results are used in the join is a no-op,
// and the join can use the token of the input value to the branch.
// CHECK-LABEL: func.func @joinOnBranch(
// CHECK-SAME: %[[VAL_0:.*]]: !dc.value<i1>, %[[VAL_1:.*]]: !dc.value<i1>, %[[VAL_2:.*]]: !dc.token) -> (!dc.token, !dc.token) {
// CHECK: %[[VAL_3:.*]], %[[VAL_4:.*]] = dc.branch %[[VAL_1]]
// CHECK: %[[VAL_5:.*]], %[[VAL_6:.*]] = dc.unpack %[[VAL_0]] : !dc.value<i1>
// CHECK: %[[VAL_7:.*]] = dc.join %[[VAL_2]], %[[VAL_3]], %[[VAL_5]]
// CHECK: return %[[VAL_7]], %[[VAL_4]] : !dc.token, !dc.token
// CHECK: }
func.func @joinOnBranch(%sel : !dc.value<i1>, %sel2 : !dc.value<i1>, %other : !dc.token) -> (!dc.token, !dc.token) {
%true, %false = dc.branch %sel
%true2, %false2 = dc.branch %sel2
%out = dc.join %true, %false, %other, %true2
return %out, %false2 : !dc.token, !dc.token
}

// CHECK-LABEL: func.func @forkOfSource() -> (!dc.token, !dc.token) {
// CHECK: %[[VAL_0:.*]] = dc.source
// CHECK: return %[[VAL_0]], %[[VAL_0]] : !dc.token, !dc.token
Expand Down

0 comments on commit 6d894b9

Please sign in to comment.