
Commit 0862760

[mlir][tensor] Make tensor::PadOp a ReifyRankedShapedTypeOpInterface and add a PadOp::FoldReifiedShape canonicalization
1 parent 6dad1e8 commit 0862760

4 files changed, +89 −1 lines changed

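At a glance: for every dimension i, the result extent of a tensor.pad is

    size(result, i) = low[i] + size(source, i) + high[i]

This commit teaches PadOp to reify that expression through ReifyRankedShapedTypeOpInterface, and adds a FoldReifiedShape canonicalization pattern that runs the reification, constant-folds it, and, whenever a previously dynamic extent turns out to be a known constant, rewrites the op into a pad with a tighter static result type followed by a tensor.cast back to the original type.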

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 1 addition & 0 deletions
@@ -1256,6 +1256,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
 
 def Tensor_PadOp : Tensor_Op<"pad", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
     AttrSizedOperandSegments,
     Pure,
     SingleBlockImplicitTerminator<"mlir::tensor::YieldOp">]> {
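ODS note: DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface> makes the generated PadOp class declare the interface's reifyResultShapes hook, so it can be defined out of line; that definition is the one added to TensorOps.cpp below.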

mlir/include/mlir/Dialect/Utils/StaticValueUtils.h

Lines changed: 3 additions & 0 deletions
@@ -98,6 +98,9 @@ OpFoldResult getAsOpFoldResult(Value val);
 SmallVector<OpFoldResult> getAsOpFoldResult(ValueRange values);
 /// Convert `arrayAttr` to a vector of OpFoldResult.
 SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr);
+// TODO: implement a mixed form of this and deprecate getMixedPadImpl.
+// SmallVector<OpFoldResult> getAsOpFoldResult(ArrayAttr arrayAttr, ValueRange
+// values);
 
 /// Convert int64_t to integer attributes of index type and return them as
 /// OpFoldResult.

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 67 additions & 1 deletion
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
@@ -3791,13 +3792,78 @@ struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
   }
 };
 
+struct FoldReifiedShape : public OpRewritePattern<tensor::PadOp> {
+  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    if (padOp.getNofold()) {
+      return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
+    }
+
+    ReifiedRankedShapedTypeDims reifiedResultShapes;
+    if (failed(reifyResultShapes(rewriter, padOp, reifiedResultShapes)))
+      return failure();
+
+    SmallVector<int64_t> newShape;
+    for (const auto &[s, ofr] : llvm::zip_equal(
+             padOp.getResultType().getShape(), reifiedResultShapes.front())) {
+      std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
+      // Reification does not add static information, just use existing shape.
+      if (!maybeCst.has_value()) {
+        newShape.push_back(s);
+        continue;
+      }
+      int64_t cst = *maybeCst;
+      assert((ShapedType::isDynamic(s) || s == cst) && "constants must agree!");
+      newShape.push_back(cst);
+    }
+    if (newShape == padOp.getResultType().getShape())
+      return failure();
+
+    Type oldType = padOp.getResultType();
+    Type newType =
+        RankedTensorType::Builder(padOp.getResultType()).setShape(newShape);
+    Location loc = padOp->getLoc();
+    Operation *newPad = rewriter.clone(*padOp);
+    newPad->getResult(0).setType(newType);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(padOp, oldType,
+                                                newPad->getResult(0));
+    return success();
+  }
+};
+
 } // namespace
 
+LogicalResult
+PadOp::reifyResultShapes(OpBuilder &b,
+                         ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
+  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
+  SmallVector<OpFoldResult> lp = getMixedLowPad();
+  SmallVector<OpFoldResult> hp = getMixedHighPad();
+  for (int64_t i = 0; i < getResultType().getRank(); ++i) {
+    if (!getType().isDynamicDim(i)) {
+      reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
+      continue;
+    }
+    Location loc = getLoc();
+    Value dim = b.createOrFold<tensor::DimOp>(
+        loc, getSource(), b.create<arith::ConstantIndexOp>(loc, i));
+
+    affine::AffineBuilder ab(b, loc);
+    AffineExpr d0, d1, d2;
+    bindDims(b.getContext(), d0, d1, d2);
+    reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
+        b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
+  }
+  return success();
+}
+
 void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
   results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
               FoldOrthogonalPaddings, FoldStaticPadding,
-              FoldConsecutiveConstantPadding>(context);
+              FoldConsecutiveConstantPadding, FoldReifiedShape>(context);
 }
 
 /// Return the padding value of the PadOp if it constant. In this context,
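To see what reifyResultShapes materializes for a dynamic dimension, here is a minimal hand-written sketch (hypothetical values %src, %lo, and %hi; not part of this commit). The static-dim branch just produces an index attribute, while the dynamic branch builds a tensor.dim on the source plus a composed affine.apply:

    %c0 = arith.constant 0 : index
    %d  = tensor.dim %src, %c0 : tensor<?xf32>
    %sz = affine.apply affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>(%d, %lo, %hi)

Because makeComposedFoldedAffineApply composes any affine.apply producers of its operands into the map before folding, FoldReifiedShape can see through chains of affine arithmetic and recover a constant extent even when the pad amounts are dynamic SSA values.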

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 18 additions & 0 deletions
@@ -2543,3 +2543,21 @@ func.func @partial_sink_expand_of_cast(%arg0 : tensor<10x10xf32>, %arg1 : index,
 // CHECK: %[[RES:.+]] = tensor.cast %[[EXPAND]]
 // CHECK-SAME: tensor<?x?x10xf32> to tensor<?x?x?xf32>
 // CHECK: return %[[RES]]
+
+// -----
+
+// CHECK-LABEL: func.func @pad_reification
+func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>)
+    -> tensor<1x?x64xf32> {
+  %pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
+  %es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+
+  // CHECK: tensor.pad
+  // CHECK: : tensor<1x?x64xf32> to tensor<1x256x64xf32>
+  %padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
+  ^bb0(%a: index, %b: index, %c: index):
+    tensor.yield %cst : f32
+  } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
+
+  return %padded : tensor<1x?x64xf32>
+}
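Why the CHECK expects tensor<1x256x64xf32>: dimension 1 of %es has size %idx, the low pad is 0, and the high pad is -%idx + 256, so the reified extent composes to %idx + 0 + (-%idx + 256) = 256. FoldReifiedShape therefore rewrites the pad to produce tensor<1x256x64xf32>, and the tensor.cast back to tensor<1x?x64xf32> keeps the function signature unchanged.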
