diff --git a/include/circt/Dialect/DC/DCOps.td b/include/circt/Dialect/DC/DCOps.td index aeca9f2d1ab1..aef37afa6637 100644 --- a/include/circt/Dialect/DC/DCOps.td +++ b/include/circt/Dialect/DC/DCOps.td @@ -94,6 +94,7 @@ def JoinOp : DCOp<"join", [Commutative]> { let assemblyFormat = "$tokens attr-dict"; let hasFolder = 1; + let hasCanonicalizer = 1; let skipDefaultBuilders = 1; let builders = [ OpBuilder<(ins "mlir::ValueRange":$ins), [{ diff --git a/lib/Dialect/DC/DCOps.cpp b/lib/Dialect/DC/DCOps.cpp index 880a2438f3cf..00c48c18cd60 100644 --- a/lib/Dialect/DC/DCOps.cpp +++ b/lib/Dialect/DC/DCOps.cpp @@ -38,46 +38,125 @@ OpFoldResult JoinOp::fold(FoldAdaptor adaptor) { if (auto tokens = getTokens(); tokens.size() == 1) return tokens.front(); - // These folders are disabled to work around MLIR bugs when changing - // the number of operands. https://github.com/llvm/llvm-project/issues/64280 return {}; +} + +struct JoinOnBranchPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(JoinOp op, + PatternRewriter &rewriter) const override { - // Remove operands which originate from a dc.source op (redundant). - auto *op = getOperation(); - for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) { - if (auto source = operand.get().getDefiningOp()) { - op->eraseOperand(operand.getOperandNumber()); - return getOutput(); + struct BranchOperandInfo { + // Unique operands from the branch op, in case we have the same operand + // from the branch op multiple times. + SetVector uniqueOperands; + // Indices which the operands are at in the join op. + BitVector indices; + }; + + DenseMap branchOperands; + for (auto &opOperand : op->getOpOperands()) { + auto branch = opOperand.get().getDefiningOp(); + if (!branch) + continue; + + BranchOperandInfo &info = branchOperands[branch]; + info.uniqueOperands.insert(opOperand.get()); + info.indices.resize(op->getNumOperands()); + info.indices.set(opOperand.getOperandNumber()); } + + if (branchOperands.empty()) + return failure(); + + // Do we have both operands from any given branch op? + for (auto &it : branchOperands) { + auto branch = it.first; + auto &operandInfo = it.second; + if (operandInfo.uniqueOperands.size() != 2) { + // We don't have both operands from the branch op. + continue; + } + + // We have both operands from the branch op. Replace the join op with the + // branch op's data operand. + + // Unpack the !dc.value input to the branch op + auto unpacked = + rewriter.create(op.getLoc(), branch.getCondition()); + rewriter.modifyOpInPlace(op, [&]() { + op->eraseOperands(operandInfo.indices); + op.getTokensMutable().append({unpacked.getToken()}); + }); + + // Only attempt a single branch at a time - else we'd have to maintain + // OpOperand indices during the loop... too complicated, let recursive + // pattern application handle this. + return success(); + } + + return failure(); } +}; +struct StaggeredJoinCanonicalization : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(JoinOp op, + PatternRewriter &rewriter) const override { + for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) { + auto otherJoin = operand.get().getDefiningOp(); + if (!otherJoin) { + // Operand does not originate from a join so it's a valid join input. + continue; + } - // Remove duplicate operands. - llvm::DenseSet uniqueOperands; - for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) { - if (!uniqueOperands.insert(operand.get()).second) { - op->eraseOperand(operand.getOperandNumber()); - return getOutput(); + // Operand originates from a join. Erase the current join operand and + // add all of the otherJoin op's inputs to this join. + // DCE will take care of otherJoin in case it's no longer used. + rewriter.modifyOpInPlace(op, [&]() { + op.getTokensMutable().erase(operand.getOperandNumber()); + op.getTokensMutable().append(otherJoin.getTokens()); + }); + return success(); } + return failure(); } +}; - // Canonicalization staggered joins where the sink join contains inputs also - // found in the source join. - for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) { - auto otherJoin = operand.get().getDefiningOp(); - if (!otherJoin) { - // Operand does not originate from a join so it's a valid join input. - continue; +struct RemoveJoinOnSourcePattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(JoinOp op, + PatternRewriter &rewriter) const override { + for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) { + if (auto source = operand.get().getDefiningOp()) { + rewriter.modifyOpInPlace( + op, [&]() { op->eraseOperand(operand.getOperandNumber()); }); + return success(); + } } + return failure(); + } +}; - // Operand originates from a join. Erase the current join operand and add - // all of the otherJoin op's inputs to this join. - // DCE will take care of otherJoin in case it's no longer used. - op->eraseOperand(operand.getOperandNumber()); - op->insertOperands(getNumOperands(), otherJoin.getTokens()); - return getOutput(); +struct RemoveDuplicateJoinOperandsPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(JoinOp op, + PatternRewriter &rewriter) const override { + llvm::DenseSet uniqueOperands; + for (OpOperand &operand : llvm::make_early_inc_range(op->getOpOperands())) { + if (!uniqueOperands.insert(operand.get()).second) { + rewriter.modifyOpInPlace( + op, [&]() { op->eraseOperand(operand.getOperandNumber()); }); + return success(); + } + } + return failure(); } +}; - return {}; +void JoinOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.insert(context); } // ============================================================================= diff --git a/test/Dialect/DC/canonicalization.mlir b/test/Dialect/DC/canonicalization.mlir index c99dd7da7e77..baf316e8a6ce 100644 --- a/test/Dialect/DC/canonicalization.mlir +++ b/test/Dialect/DC/canonicalization.mlir @@ -13,12 +13,11 @@ func.func @staggeredJoin1(%a: !dc.token, %b : !dc.token) -> (!dc.token) { return %1 : !dc.token } -// TODO: For some reason, the canonicalizer no longer combines the two joins. Investigate. // CHECK-LABEL: func.func @staggeredJoin2( // CHECK-SAME: %[[VAL_0:.*]]: !dc.token, %[[VAL_1:.*]]: !dc.token, %[[VAL_2:.*]]: !dc.token, %[[VAL_3:.*]]: !dc.token) -> !dc.token { -// CHECKx: %[[VAL_4:.*]] = dc.join %[[VAL_2]], %[[VAL_3]], %[[VAL_0]], %[[VAL_1]] -// CHECKx: return %[[VAL_4]] : !dc.token -// CHECKx: } +// CHECK: %[[VAL_4:.*]] = dc.join %[[VAL_2]], %[[VAL_3]], %[[VAL_0]], %[[VAL_1]] +// CHECK: return %[[VAL_4]] : !dc.token +// CHECK: } func.func @staggeredJoin2(%a: !dc.token, %b : !dc.token, %c : !dc.token, %d : !dc.token) -> (!dc.token) { %0 = dc.join %a, %b %1 = dc.join %c, %0, %d @@ -102,13 +101,11 @@ func.func @forkToFork2(%a: !dc.token) -> (!dc.token, !dc.token, !dc.token) { return %0, %2, %3 : !dc.token, !dc.token, !dc.token } -// TODO: For some reason, the canonicalizer no longer simplifies this redundant -// triangle pattern. Investigate. // CHECK-LABEL: func.func @merge( // CHECK-SAME: %[[VAL_0:.*]]: !dc.value) -> !dc.token { -// CHECKx: %[[VAL_1:.*]], %[[VAL_2:.*]] = dc.unpack %[[VAL_0]] : !dc.value -// CHECKx: return %[[VAL_1]] : !dc.token -// CHECKx: } +// CHECK: %[[VAL_1:.*]], %[[VAL_2:.*]] = dc.unpack %[[VAL_0]] : !dc.value +// CHECK: return %[[VAL_1]] : !dc.token +// CHECK: } func.func @merge(%sel : !dc.value) -> (!dc.token) { // Canonicalize away a merge that is fed by a branch with the same select // input. @@ -117,20 +114,35 @@ func.func @merge(%sel : !dc.value) -> (!dc.token) { return %0 : !dc.token } -// TODO: For some reason, the canonicalizer no longer removes the source->join -// pattern. Investigate. // CHECK-LABEL: func.func @joinOnSource( // CHECK-SAME: %[[VAL_0:.*]]: !dc.token, // CHECK-SAME: %[[VAL_1:.*]]: !dc.token) -> !dc.token { -// CHECKx: %[[VAL_2:.*]] = dc.join %[[VAL_0]], %[[VAL_1]] -// CHECKx: return %[[VAL_2]] : !dc.token -// CHECKx: } +// CHECK: %[[VAL_2:.*]] = dc.join %[[VAL_0]], %[[VAL_1]] +// CHECK: return %[[VAL_2]] : !dc.token +// CHECK: } func.func @joinOnSource(%a : !dc.token, %b : !dc.token) -> (!dc.token) { %0 = dc.source %out = dc.join %a, %0, %b return %out : !dc.token } + +// Join on branch, where all branch results are used in the join is a no-op, +// and the join can use the token of the input value to the branch. +// CHECK-LABEL: func.func @joinOnBranch( +// CHECK-SAME: %[[VAL_0:.*]]: !dc.value, %[[VAL_1:.*]]: !dc.value, %[[VAL_2:.*]]: !dc.token) -> (!dc.token, !dc.token) { +// CHECK: %[[VAL_3:.*]], %[[VAL_4:.*]] = dc.branch %[[VAL_1]] +// CHECK: %[[VAL_5:.*]], %[[VAL_6:.*]] = dc.unpack %[[VAL_0]] : !dc.value +// CHECK: %[[VAL_7:.*]] = dc.join %[[VAL_2]], %[[VAL_3]], %[[VAL_5]] +// CHECK: return %[[VAL_7]], %[[VAL_4]] : !dc.token, !dc.token +// CHECK: } +func.func @joinOnBranch(%sel : !dc.value, %sel2 : !dc.value, %other : !dc.token) -> (!dc.token, !dc.token) { + %true, %false = dc.branch %sel + %true2, %false2 = dc.branch %sel2 + %out = dc.join %true, %false, %other, %true2 + return %out, %false2 : !dc.token, !dc.token +} + // CHECK-LABEL: func.func @forkOfSource() -> (!dc.token, !dc.token) { // CHECK: %[[VAL_0:.*]] = dc.source // CHECK: return %[[VAL_0]], %[[VAL_0]] : !dc.token, !dc.token