Skip to content

Commit 59ca35d

Browse files
[mlir][tensor] Make tensor::PadOp a ReifyRankedShapedTypeOpInterface and add a PadOp::FoldReifiedShape canonicalization
1 parent 6dad1e8 commit 59ca35d

File tree

4 files changed

+90
-1
lines changed

4 files changed

+90
-1
lines changed

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,6 +1256,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
12561256

12571257
def Tensor_PadOp : Tensor_Op<"pad", [
12581258
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
1259+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
12591260
AttrSizedOperandSegments,
12601261
Pure,
12611262
SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {

mlir/include/mlir/Dialect/Utils/StaticValueUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ OpFoldResult getAsOpFoldResult(Value val);
9898
SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values);
9999
/// Convert `arrayAttr` to a vector of OpFoldResult.
100100
SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr);
101+
// TODO: implement a mixed form of this and deprecate getMixedPadImpl.
102+
// SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr, ValueRange
103+
// values);
101104

102105
/// Convert int64_t to integer attributes of index type and return them as
103106
/// OpFoldResult.

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/Affine/IR/AffineOps.h"
10+
#include "mlir/Dialect/Affine/Utils.h"
1011
#include "mlir/Dialect/Arith/IR/Arith.h"
1112
#include "mlir/Dialect/Arith/Utils/Utils.h"
1213
#include "mlir/Dialect/Complex/IR/Complex.h"
@@ -38,6 +39,7 @@
3839
#include "llvm/Support/LogicalResult.h"
3940
#include "llvm/Support/MathExtras.h"
4041
#include <algorithm>
42+
#include <cstdint>
4143
#include <optional>
4244
#include <vector>
4345

@@ -3791,13 +3793,78 @@ struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
37913793
}
37923794
};
37933795

3796+
/// Canonicalization pattern that tightens the result type of a tensor.pad.
///
/// Reifies the result shape of the pad; whenever reification yields a
/// constant size for a dimension that the result type left dynamic, the pad
/// is rebuilt with the more static result type and a tensor.cast restores the
/// original (more dynamic) type for existing users.
struct FoldReifiedShape : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    // `nofold` pads must be preserved exactly; do not rewrite them.
    if (padOp.getNofold())
      return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");

    ReifiedRankedShapedTypeDims reifiedResultShapes;
    if (failed(reifyResultShapes(rewriter, padOp, reifiedResultShapes)))
      return failure();

    // Combine the reified sizes with the existing result shape, keeping the
    // most static information available per dimension.
    SmallVector<int64_t> newShape;
    for (const auto &[s, ofr] : llvm::zip_equal(
             padOp.getResultType().getShape(), reifiedResultShapes.front())) {
      std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
      // Reification does not add static information, just use existing shape.
      if (!maybeCst.has_value()) {
        newShape.push_back(s);
        continue;
      }
      int64_t cst = *maybeCst;
      assert((ShapedType::isDynamic(s) || s == cst) && "constants must agree!");
      newShape.push_back(cst);
    }
    // Nothing was refined; avoid an infinite rewrite loop.
    if (newShape == padOp.getResultType().getShape())
      return failure();

    Type oldType = padOp.getResultType();
    Type newType =
        RankedTensorType::Builder(padOp.getResultType()).setShape(newShape);
    // Clone the pad with the refined type and cast back to the old type so
    // all users remain valid.
    Operation *newPad = rewriter.clone(*padOp);
    newPad->getResult(0).setType(newType);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(padOp, oldType,
                                                newPad->getResult(0));
    return success();
  }
};
3836+
37943837
} // namespace
37953838

3839+
/// Implements ReifyRankedShapedTypeOpInterface for tensor.pad.
///
/// For each result dimension: static sizes are emitted directly as index
/// attributes; dynamic sizes are materialized as
/// `dim(source, i) + low_pad[i] + high_pad[i]` via a composed affine apply.
LogicalResult
PadOp::reifyResultShapes(OpBuilder &b,
                         ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
  SmallVector<OpFoldResult> lp = getMixedLowPad();
  SmallVector<OpFoldResult> hp = getMixedHighPad();
  Location loc = getLoc();
  for (int64_t i = 0; i < getResultType().getRank(); ++i) {
    // Known static dimension: no IR needs to be built.
    if (!getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
      continue;
    }
    Value dim = b.createOrFold<tensor::DimOp>(
        loc, getSource(), b.create<arith::ConstantIndexOp>(loc, i));
    // result_size[i] = source_dim + low_pad + high_pad.
    AffineExpr d0, d1, d2;
    bindDims(b.getContext(), d0, d1, d2);
    reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
        b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
  }
  return success();
}
3862+
37963863
void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
37973864
MLIRContext *context) {
37983865
results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
37993866
FoldOrthogonalPaddings, FoldStaticPadding,
3800-
FoldConsecutiveConstantPadding>(context);
3867+
FoldConsecutiveConstantPadding, FoldReifiedShape>(context);
38013868
}
38023869

38033870
/// Return the padding value of the PadOp if it constant. In this context,

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2543,3 +2543,21 @@ func.func @partial_sink_expand_of_cast(%arg0 : tensor<10x10xf32>, %arg1 : index,
25432543
// CHECK: %[[RES:.+]] = tensor.cast %[[EXPAND]]
25442544
// CHECK-SAME: tensor<?x?x10xf32> to tensor<?x?x?xf32>
25452545
// CHECK: return %[[RES]]
2546+
2547+
// -----
2548+
2549+
// Verifies that FoldReifiedShape refines the dynamic pad result type: the
// high pad amount (-%idx + 256) plus the sliced size (%idx) folds to the
// constant 256, so the pad is rebuilt as tensor<1x256x64xf32> and cast back.
// CHECK-LABEL: func.func @pad_reification
func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>)
    -> tensor<1x?x64xf32> {
  %pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
  %es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>

  // CHECK: tensor.pad
  // CHECK: : tensor<1x?x64xf32> to tensor<1x256x64xf32>
  %padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
    ^bb0(%a: index, %b: index, %c: index):
      tensor.yield %cst : f32
  } : tensor<1x?x64xf32> to tensor<1x?x64xf32>

  return %padded : tensor<1x?x64xf32>
}

0 commit comments

Comments
 (0)