[mlir][tensor] Add a PadOp::FoldReifiedShape canonicalization #145732
```diff
@@ -3791,6 +3791,47 @@ struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
   }
 };
 
+struct FoldReifiedShape : public OpRewritePattern<tensor::PadOp> {
+  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    if (padOp.getNofold()) {
+      return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
+    }
+
+    ReifiedRankedShapedTypeDims reifiedResultShapes;
+    if (failed(reifyResultShapes(rewriter, padOp, reifiedResultShapes)))
+      return failure();
+
+    SmallVector<int64_t> newShape;
+    for (const auto &[s, ofr] : llvm::zip_equal(
+             padOp.getResultType().getShape(), reifiedResultShapes.front())) {
+      std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
+      // Reification does not add static information, just use existing shape.
+      if (!maybeCst.has_value()) {
+        newShape.push_back(s);
+        continue;
+      }
+      int64_t cst = *maybeCst;
+      assert((ShapedType::isDynamic(s) || s == cst) && "constants must agree!");
+      newShape.push_back(cst);
+    }
+    if (newShape == padOp.getResultType().getShape())
+      return failure();
+
+    Type oldType = padOp.getResultType();
+    Type newType =
+        RankedTensorType::Builder(padOp.getResultType()).setShape(newShape);
+    Location loc = padOp->getLoc();
+    Operation *newPad = rewriter.clone(*padOp);
+    newPad->getResult(0).setType(newType);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(padOp, oldType,
+                                                newPad->getResult(0));
+    return success();
+  }
+};
+
 } // namespace
 
 LogicalResult
```

(nicolasvasilache marked a review conversation on this hunk as resolved.)
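In effect, the pattern asks `reifyResultShapes` for the result dimensions, and when a dimension reifies to a constant that the result type left dynamic, it clones the pad with the tightened static type and inserts a `tensor.cast` back to the original type for existing users. A minimal hypothetical illustration (the SSA names and shapes below are illustrative, not taken from the PR's test):

```mlir
// Before: the pad amounts are static, but the result type was left dynamic.
%0 = tensor.pad %src low[1, 1] high[1, 1] {
^bb0(%i: index, %j: index):
  tensor.yield %pad_value : f32
} : tensor<6x6xf32> to tensor<?x?xf32>

// After FoldReifiedShape: the cloned pad carries the reified static shape
// (6 + 1 + 1 = 8 in each dimension), and a cast restores the original
// dynamic type.
%1 = tensor.pad %src low[1, 1] high[1, 1] {
^bb0(%i: index, %j: index):
  tensor.yield %pad_value : f32
} : tensor<6x6xf32> to tensor<8x8xf32>
%2 = tensor.cast %1 : tensor<8x8xf32> to tensor<?x?xf32>
```

The cast-back keeps the rewrite type-safe locally; the existing `FoldTargetTensorCast`-style foldings and cast canonicalizations can then propagate the static type to consumers.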
```diff
@@ -3820,7 +3861,7 @@ void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
   results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
               FoldOrthogonalPaddings, FoldStaticPadding,
-              FoldConsecutiveConstantPadding>(context);
+              FoldConsecutiveConstantPadding, FoldReifiedShape>(context);
 }
 
 /// Return the padding value of the PadOp if it constant. In this context,
```

Review discussion on this hunk:

- Please don't add this as a canonicalization. I think we decided that we could just use the dim-op folding patterns as and when needed, instead of running this as a canonicalization without control. I think this pattern is really unnecessary.
- If I'm not mistaken (I might be missing something), the test added by Nico is not covered by any transforms, passes, or canonicalizers upstream.
- Ok,
- Ok no,
- Try
- I also tried that one and it doesn't do the trick.
- Tomorrow I'll take a look at whether I can augment the pass and get the desired behavior, because I agree this doesn't really fit as a canonicalization. The main thing we want from this transform is that in some circumstances, when we tile, we can get static shapes after applying it; this is shown in the test.
- @MaheshRavishankar is there any written discussion about that decision? I'd like to better understand the needs here, and why we keep adding specialized passes, and hopefully figure out a more general upstream solution.