[mlir][tensor] Make tensor::PadOp a ReifyRankedShapedTypeOpInterface and add a PadOp::FoldReifiedShape canonicalization #145732
Conversation
@llvm/pr-subscribers-mlir

Author: Nicolas Vasilache (nicolasvasilache)

Changes

…and add a PadOp::FoldReifiedShape canonicalization

Full diff: https://github.com/llvm/llvm-project/pull/145732.diff

4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 35d0b16628417..821384eb7d15a 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1256,6 +1256,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
def Tensor_PadOp : Tensor_Op<"pad", [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+ DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
AttrSizedOperandSegments,
Pure,
SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 77c376fb9973a..c66110f6915e9 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -98,6 +98,9 @@ OpFoldResult getAsOpFoldResult(Value val);
SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values);
/// Convert `arrayAttr` to a vector of OpFoldResult.
SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr);
+// TODO: implement a mixed form of this and deprecate getMixedPadImpl.
+// SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr, ValueRange
+// values);
/// Convert int64_t to integer attributes of index type and return them as
/// OpFoldResult.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 72144ec71c5d2..95468caa87f18 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
@@ -3791,13 +3792,78 @@ struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
}
};
+struct FoldReifiedShape : public OpRewritePattern<tensor::PadOp> {
+ using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::PadOp padOp,
+ PatternRewriter &rewriter) const override {
+ if (padOp.getNofold()) {
+ return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
+ }
+
+ ReifiedRankedShapedTypeDims reifiedResultShapes;
+ if (failed(reifyResultShapes(rewriter, padOp, reifiedResultShapes)))
+ return failure();
+
+ SmallVector<int64_t> newShape;
+ for (const auto &[s, ofr] : llvm::zip_equal(
+ padOp.getResultType().getShape(), reifiedResultShapes.front())) {
+ std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
+ // Reification does not add static information, just use existing shape.
+ if (!maybeCst.has_value()) {
+ newShape.push_back(s);
+ continue;
+ }
+ int64_t cst = *maybeCst;
+ assert((ShapedType::isDynamic(s) || s == cst) && "constants must agree!");
+ newShape.push_back(cst);
+ }
+ if (newShape == padOp.getResultType().getShape())
+ return failure();
+
+ Type oldType = padOp.getResultType();
+ Type newType =
+ RankedTensorType::Builder(padOp.getResultType()).setShape(newShape);
+ Location loc = padOp->getLoc();
+ Operation *newPad = rewriter.clone(*padOp);
+ newPad->getResult(0).setType(newType);
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(padOp, oldType,
+ newPad->getResult(0));
+ return success();
+ }
+};
+
} // namespace
+LogicalResult
+PadOp::reifyResultShapes(OpBuilder &b,
+ ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+ reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
+ SmallVector<OpFoldResult> lp = getMixedLowPad();
+ SmallVector<OpFoldResult> hp = getMixedHighPad();
+ for (int64_t i = 0; i < getResultType().getRank(); ++i) {
+ if (!getType().isDynamicDim(i)) {
+ reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
+ continue;
+ }
+ Location loc = getLoc();
+ Value dim = b.createOrFold<tensor::DimOp>(
+ loc, getSource(), b.create<arith::ConstantIndexOp>(loc, i));
+
+ affine::AffineBuilder ab(b, loc);
+ AffineExpr d0, d1, d2;
+ bindDims(b.getContext(), d0, d1, d2);
+ reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
+ b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
+ }
+ return success();
+}
+
void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
FoldOrthogonalPaddings, FoldStaticPadding,
- FoldConsecutiveConstantPadding>(context);
+ FoldConsecutiveConstantPadding, FoldReifiedShape>(context);
}
/// Return the padding value of the PadOp if it constant. In this context,
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 3251c5a4a2bfd..358c1c214a3b1 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2543,3 +2543,21 @@ func.func @partial_sink_expand_of_cast(%arg0 : tensor<10x10xf32>, %arg1 : index,
// CHECK: %[[RES:.+]] = tensor.cast %[[EXPAND]]
// CHECK-SAME: tensor<?x?x10xf32> to tensor<?x?x?xf32>
// CHECK: return %[[RES]]
+
+// -----
+
+// CHECK-LABEL: func.func @pad_reification
+func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>)
+ -> tensor<1x?x64xf32> {
+ %pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
+ %es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+
+// CHECK: tensor.pad
+// CHECK: : tensor<1x?x64xf32> to tensor<1x256x64xf32>
+ %padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
+ ^bb0(%a: index, %b: index, %c: index):
+ tensor.yield %cst : f32
+ } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
+
+ return %padded : tensor<1x?x64xf32>
+}
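As a reading aid (not part of the patch), the following is a rough sketch of the IR that the new PadOp::reifyResultShapes materializes for a dynamic result dimension; the SSA names %source, %low1 and %high1 are hypothetical and only illustrate the d0 + d1 + d2 map built through makeComposedFoldedAffineApply:

// Hypothetical pad whose result dimension 1 is dynamic.
%c1 = arith.constant 1 : index
// reifyResultShapes queries the source size of the dynamic dimension...
%d1 = tensor.dim %source, %c1 : tensor<1x?x64xf32>
// ...and reifies the result size as source size + low pad + high pad.
%sz1 = affine.apply affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>(%d1, %low1, %high1)

Static result dimensions are returned directly as index attributes, so no IR is created for them.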
LGTM. Pls update the commit message.
See comment in FoldReifiedShape
void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
              FoldOrthogonalPaddings, FoldStaticPadding,
-             FoldConsecutiveConstantPadding>(context);
+             FoldConsecutiveConstantPadding, FoldReifiedShape>(context);
Please don't add this as a canonicalization. I think we decided that we could just use dim op folding patterns as and when needed instead of running it as a canonicalization without control. I think this pattern is unnecessary really.
If I'm not mistaken (I might be missing something), the test added by Nico is not covered by any transforms/passes/canonicalizers upstream.
Ok, --linalg-fold-unit-extent-dims --canonicalize without the folder is getting there. So I might be missing another transform.
Ok no, --linalg-fold-unit-extent-dims --canonicalize doesn't work in all cases. I can't find a transform to simplify without Nico's canonicalization pattern.
Try --resolve-ranked-shaped-type-result-dims.
I also tried that one and it doesn't do the trick.
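For context, here is a hedged sketch (the SSA names are made up, and this describes expected behavior rather than verified output) of what -resolve-ranked-shaped-type-result-dims can do once PadOp implements the interface: it folds tensor.dim ops taken of the pad result into values computed from the source and the pad amounts, but it does not refine the pad's result type, which is what the test in this PR needs.

// Before: a dim query on the pad result.
%d = tensor.dim %padded, %c1 : tensor<1x?x64xf32>
// After dim-op folding (conceptually): the size is recomputed from the source,
// while %padded itself keeps its tensor<1x?x64xf32> result type.
%src_d = tensor.dim %es, %c1 : tensor<1x?x64xf32>
%d_folded = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%src_d, %pad_amt)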
Tomorrow I'll take a look at whether I can augment the pass and get the desired behavior, because I agree this doesn't really fit as a canonicalization. The main thing we want from this transform is that, in some circumstances, when we tile we can get static shapes after applying it; this is shown in the test and sketched below.
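Concretely, a sketch of the before/after on the pad_reification test (the SSA names of the rewritten ops are illustrative; conceptually, the 256 arises from composing the source size %idx with the high pad -%idx + 256):

// Before FoldReifiedShape: the reified size of dim 1 is %idx + 0 + (256 - %idx),
// i.e. 256, but the result type is still dynamic.
%padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
^bb0(%a: index, %b: index, %c: index):
  tensor.yield %cst : f32
} : tensor<1x?x64xf32> to tensor<1x?x64xf32>

// After: the cloned pad carries the static result type and a tensor.cast
// restores the original dynamic type for existing users.
%new_pad = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
^bb0(%a: index, %b: index, %c: index):
  tensor.yield %cst : f32
} : tensor<1x?x64xf32> to tensor<1x256x64xf32>
%padded_cast = tensor.cast %new_pad : tensor<1x256x64xf32> to tensor<1x?x64xf32>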
@MaheshRavishankar is there any written discussion about that decision? I'd like to better understand the needs here, and why we keep adding specialized passes, and hopefully figure out a more general upstream solution.
@MaheshRavishankar @matthias-springer I think we may need a longer alignment session here, but TL;DR my position atm is:

If I wanted to be creative I could slash at the APIs to make this not generate IR and just pop out the constants (I am unsure the API changes will be desired though). However, taking the extreme opposite view, I would ask: why is a tensor/memref.CastOp ever considered a canonicalization?

In general, looking at these issues more deeply after 2y, I like how the reification APIs and interfaces have evolved. This is all "modulo compiler cost"; I did not go as far as making ValueBoundsOpInterface part of canonicalization, but I would claim we should have an "expensive canonicalization" or something, rather than have to discover a set of transforms that need to be applied together for this purpose of finding more static information.

To further supplement my point, I'll claim that a lot of the static behavior that we enjoy today comes from …

Anyway, not trying to start an RFC in a PR or an ODM, but a discussion on whether "making shapes more static" is a canonicalization and on "creating IR that can be reverted" would be useful; these are principles on which it would be good to have general agreement. @joker-eph @rengolin @ftynse for their take too.
I would say @llvm/mlir-area-team if we are at this point.
Just listened to https://www.youtube.com/watch?v=929JVLlbtRU
Extracted out #145867 for the immediate functionality and this one can be put on the backburner.
Don't recall all the details, but normally we shouldn't have non-deterministic behavior in the compiler. Is this due to something like the map of dialects or ops being a hashmap or something like that?
I generally think the canonical form in MLIR is moot, given that we can add arbitrary operations and run transformations in vastly different paradigms (trivial example: LICM/CSE/GVN were harming affine transformations so much that they had to undo them in Polly: https://dl.acm.org/doi/abs/10.1145/3168815, but these are generally considered "useful simplifications" for regular SSA; upstream example: we push constants and some operations back into gpu kernels to undo CSE) that may have different prerequisites. We should rather move to a set of composable "normal forms" that have well-defined, programmatically verifiable invariants, and passes can declare which of the "normal forms" they expect and produce.
I think this is precisely a discussion that should happen at an ODM. This is about design, and we need high-bandwidth discussion. ODM is not a seminar where people present finished work; or rather, it should not be just that. It is also worth surfacing this discussion to the forum; there was a related topic recently. I feel like I've written the text above at least twice in the past month or so.
Yeah, happy to discuss this in an informal ODM; what I meant is that I am not trying to litigate this as an RFC or ODM in this PR :p