diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 22a25fd1a5af8..c3147e297e2ff 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3791,6 +3791,50 @@ struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
   }
 };
 
+/// Canonicalize a tensor.pad whose reified result shape is more static than
+/// its declared result type: clone the pad with the refined static type and
+/// insert a tensor.cast back to the original type so existing users still
+/// type-check.
+struct FoldReifiedShape : public OpRewritePattern<tensor::PadOp> {
+  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    // A nofold pad must be preserved verbatim; do not touch it.
+    if (padOp.getNofold()) {
+      return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
+    }
+
+    ReifiedRankedShapedTypeDims reifiedResultShapes;
+    if (failed(reifyResultShapes(rewriter, padOp, reifiedResultShapes)))
+      return failure();
+
+    SmallVector<int64_t> newShape;
+    for (const auto &[s, ofr] : llvm::zip_equal(
+             padOp.getResultType().getShape(), reifiedResultShapes.front())) {
+      std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
+      // Reification does not add static information, just use existing shape.
+      if (!maybeCst.has_value()) {
+        newShape.push_back(s);
+        continue;
+      }
+      int64_t cst = *maybeCst;
+      assert((ShapedType::isDynamic(s) || s == cst) && "constants must agree!");
+      newShape.push_back(cst);
+    }
+    // Nothing was refined; avoid rewriting to an identical op.
+    if (newShape == padOp.getResultType().getShape())
+      return failure();
+
+    Type oldType = padOp.getResultType();
+    Type newType =
+        RankedTensorType::Builder(padOp.getResultType()).setShape(newShape);
+    Operation *newPad = rewriter.clone(*padOp);
+    newPad->getResult(0).setType(newType);
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(padOp, oldType,
+                                                newPad->getResult(0));
+    return success();
+  }
+};
+
 } // namespace
 
 LogicalResult
@@ -3820,7 +3864,7 @@ void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
   results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
               FoldOrthogonalPaddings, FoldStaticPadding,
-              FoldConsecutiveConstantPadding>(context);
+              FoldConsecutiveConstantPadding, FoldReifiedShape>(context);
 }
 
 /// Return the padding value of the PadOp if it constant.
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 3f9236095138b..2a42a9a810ec4 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2561,3 +2561,21 @@ func.func @partial_sink_expand_of_cast(%arg0 : tensor<10x10xf32>, %arg1 : index,
 // CHECK: %[[RES:.+]] = tensor.cast %[[EXPAND]]
 // CHECK-SAME: tensor<?x?x10xf32> to tensor<?x?x?xf32>
 // CHECK: return %[[RES]]
+
+// -----
+
+// CHECK-LABEL: func.func @pad_reification
+func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>)
+    -> tensor<1x?x64xf32> {
+  %pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
+  %es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1] : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+
+// CHECK: tensor.pad
+// CHECK: : tensor<1x?x64xf32> to tensor<1x256x64xf32>
+  %padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
+    ^bb0(%a: index, %b: index, %c: index):
+      tensor.yield %cst : f32
+  } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
+
+  return %padded : tensor<1x?x64xf32>
+}