
[mlir][tensor] Add a PadOp::FoldReifiedShape canonicalization #145732

Draft · wants to merge 1 commit into base: main
43 changes: 42 additions & 1 deletion mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3791,6 +3791,47 @@ struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
}
};

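/// Refine the result type of a tensor.pad using its reified result shape:
/// if reification proves additional extents constant, rebuild the pad with
/// the more static type and insert a tensor.cast back to the original type.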
struct FoldReifiedShape : public OpRewritePattern<tensor::PadOp> {
using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

LogicalResult matchAndRewrite(tensor::PadOp padOp,
PatternRewriter &rewriter) const override {
if (padOp.getNofold()) {
return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
}

ReifiedRankedShapedTypeDims reifiedResultShapes;
if (failed(reifyResultShapes(rewriter, padOp, reifiedResultShapes)))
return failure();

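// Merge reified constant extents into the pad's static result shape.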
SmallVector<int64_t> newShape;
for (const auto &[s, ofr] : llvm::zip_equal(
padOp.getResultType().getShape(), reifiedResultShapes.front())) {
std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
// Reification does not add static information; just use the existing shape.
if (!maybeCst.has_value()) {
newShape.push_back(s);
continue;
}
int64_t cst = *maybeCst;
assert((ShapedType::isDynamic(s) || s == cst) && "constants must agree!");
newShape.push_back(cst);
}
if (newShape == padOp.getResultType().getShape())
return failure();

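// Rebuild the pad with the refined result type and cast back to the
// original type so existing uses remain valid.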
Type oldType = padOp.getResultType();
Type newType =
RankedTensorType::Builder(padOp.getResultType()).setShape(newShape);
Location loc = padOp->getLoc();
Operation *newPad = rewriter.clone(*padOp);
newPad->getResult(0).setType(newType);
rewriter.replaceOpWithNewOp<tensor::CastOp>(padOp, oldType,
newPad->getResult(0));
return success();
}
};

} // namespace

LogicalResult
@@ -3820,7 +3861,7 @@ void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
FoldOrthogonalPaddings, FoldStaticPadding,
FoldConsecutiveConstantPadding>(context);
FoldConsecutiveConstantPadding, FoldReifiedShape>(context);

Contributor:
Please don't add this as a canonicalization. I think we decided that we could just use the dim op folding patterns as and when needed, instead of running this as a canonicalization without control. I think this pattern is really unnecessary.
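
For illustration, a minimal sketch of that controlled alternative: a client opts into the pattern where needed rather than having it run under every canonicalization. The wiring below is hypothetical; FoldReifiedShape would have to be exposed outside its anonymous namespace, and in older MLIR the driver entry point is spelled applyPatternsAndFoldGreedily.

#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical helper: apply the fold on demand, under caller control.
static LogicalResult foldPadReifiedShapes(Operation *root) {
  MLIRContext *ctx = root->getContext();
  RewritePatternSet patterns(ctx);
  patterns.add<FoldReifiedShape>(ctx); // only the pattern we opted into
  return applyPatternsGreedily(root, std::move(patterns));
}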

Contributor:
If I'm not mistaken (I might be missing something), the test added by Nico is not covered by any upstream transform, pass, or canonicalizer.

Contributor:
OK, --linalg-fold-unit-extent-dims --canonicalize without the folder gets close, so I might be missing another transform.

Contributor:
OK, no: --linalg-fold-unit-extent-dims --canonicalize doesn't work in all cases. I can't find a transform that simplifies this without Nico's canonicalization pattern.

Contributor:
Try --resolve-ranked-shaped-type-result-dims.

Contributor:
I also tried that one and it doesn't do the trick.

Contributor:
Tomorrow I'll take a look at whether I can augment the pass to get the desired behavior, because I agree this doesn't really fit as a canonicalization. The main thing we want from this transform is that, in some circumstances, applying it after tiling gives us static shapes; this is shown in the test.
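
One hypothetical shape for that augmentation (a sketch only; the populate function below does not exist upstream, and FoldReifiedShape would need to be made visible to clients):

// Hypothetical: let the pass behind --resolve-ranked-shaped-type-result-dims
// (or any other client) opt into the fold alongside the dim-op patterns.
void tensor::populateFoldPadReifiedShapePatterns(RewritePatternSet &patterns) {
  patterns.add<FoldReifiedShape>(patterns.getContext());
}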

Member:
@MaheshRavishankar is there any written discussion about that decision? I'd like to better understand the needs here, and why we keep adding specialized passes, and hopefully figure out a more general upstream solution.

}

/// Return the padding value of the PadOp if it is constant. In this context,
18 changes: 18 additions & 0 deletions mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2561,3 +2561,21 @@ func.func @partial_sink_expand_of_cast(%arg0 : tensor<10x10xf32>, %arg1 : index,
// CHECK: %[[RES:.+]] = tensor.cast %[[EXPAND]]
// CHECK-SAME: tensor<?x?x10xf32> to tensor<?x?x?xf32>
// CHECK: return %[[RES]]

// -----

// CHECK-LABEL: func.func @pad_reification
func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>)
-> tensor<1x?x64xf32> {
%pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
%es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>

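// Reification folds the padded extent to %idx + (-%idx + 256) = 256, so the
// pad's result type can be refined to tensor<1x256x64xf32>, with a
// tensor.cast back to tensor<1x?x64xf32> for the function's return type.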
// CHECK: tensor.pad
// CHECK: : tensor<1x?x64xf32> to tensor<1x256x64xf32>
%padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
^bb0(%a: index, %b: index, %c: index):
tensor.yield %cst : f32
} : tensor<1x?x64xf32> to tensor<1x?x64xf32>

return %padded : tensor<1x?x64xf32>
}