|
7 | 7 | //===----------------------------------------------------------------------===//
|
8 | 8 |
|
9 | 9 | #include "mlir/Dialect/Affine/IR/AffineOps.h"
|
| 10 | +#include "mlir/Dialect/Affine/Utils.h" |
10 | 11 | #include "mlir/Dialect/Arith/IR/Arith.h"
|
11 | 12 | #include "mlir/Dialect/Arith/Utils/Utils.h"
|
12 | 13 | #include "mlir/Dialect/Complex/IR/Complex.h"
|
|
38 | 39 | #include "llvm/Support/LogicalResult.h"
|
39 | 40 | #include "llvm/Support/MathExtras.h"
|
40 | 41 | #include <algorithm>
|
| 42 | +#include <cstdint> |
41 | 43 | #include <optional>
|
42 | 44 | #include <vector>
|
43 | 45 |
|
@@ -3791,13 +3793,78 @@ struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
|
3791 | 3793 | }
|
3792 | 3794 | };
|
3793 | 3795 |
|
| 3796 | +struct FoldReifiedShape : public OpRewritePattern<tensor::PadOp> { |
| 3797 | + using OpRewritePattern<tensor::PadOp>::OpRewritePattern; |
| 3798 | + |
| 3799 | + LogicalResult matchAndRewrite(tensor::PadOp padOp, |
| 3800 | + PatternRewriter &rewriter) const override { |
| 3801 | + if (padOp.getNofold()) { |
| 3802 | + return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad"); |
| 3803 | + } |
| 3804 | + |
| 3805 | + ReifiedRankedShapedTypeDims reifiedResultShapes; |
| 3806 | + if (failed(reifyResultShapes(rewriter, padOp, reifiedResultShapes))) |
| 3807 | + return failure(); |
| 3808 | + |
| 3809 | + SmallVector<int64_t> newShape; |
| 3810 | + for (const auto &[s, ofr] : llvm::zip_equal( |
| 3811 | + padOp.getResultType().getShape(), reifiedResultShapes.front())) { |
| 3812 | + std::optional<int64_t> maybeCst = getConstantIntValue(ofr); |
| 3813 | + // Reification does not add static information, just use existing shape. |
| 3814 | + if (!maybeCst.has_value()) { |
| 3815 | + newShape.push_back(s); |
| 3816 | + continue; |
| 3817 | + } |
| 3818 | + int64_t cst = *maybeCst; |
| 3819 | + assert((ShapedType::isDynamic(s) || s == cst) && "constants must agree!"); |
| 3820 | + newShape.push_back(cst); |
| 3821 | + } |
| 3822 | + if (newShape == padOp.getResultType().getShape()) |
| 3823 | + return failure(); |
| 3824 | + |
| 3825 | + Type oldType = padOp.getResultType(); |
| 3826 | + Type newType = |
| 3827 | + RankedTensorType::Builder(padOp.getResultType()).setShape(newShape); |
| 3828 | + Location loc = padOp->getLoc(); |
| 3829 | + Operation *newPad = rewriter.clone(*padOp); |
| 3830 | + newPad->getResult(0).setType(newType); |
| 3831 | + rewriter.replaceOpWithNewOp<tensor::CastOp>(padOp, oldType, |
| 3832 | + newPad->getResult(0)); |
| 3833 | + return success(); |
| 3834 | + } |
| 3835 | +}; |
| 3836 | + |
3794 | 3837 | } // namespace
|
3795 | 3838 |
|
| 3839 | +LogicalResult |
| 3840 | +PadOp::reifyResultShapes(OpBuilder &b, |
| 3841 | + ReifiedRankedShapedTypeDims &reifiedReturnShapes) { |
| 3842 | + reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank())); |
| 3843 | + SmallVector<OpFoldResult> lp = getMixedLowPad(); |
| 3844 | + SmallVector<OpFoldResult> hp = getMixedHighPad(); |
| 3845 | + for (int64_t i = 0; i < getResultType().getRank(); ++i) { |
| 3846 | + if (!getType().isDynamicDim(i)) { |
| 3847 | + reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i)); |
| 3848 | + continue; |
| 3849 | + } |
| 3850 | + Location loc = getLoc(); |
| 3851 | + Value dim = b.createOrFold<tensor::DimOp>( |
| 3852 | + loc, getSource(), b.create<arith::ConstantIndexOp>(loc, i)); |
| 3853 | + |
| 3854 | + affine::AffineBuilder ab(b, loc); |
| 3855 | + AffineExpr d0, d1, d2; |
| 3856 | + bindDims(b.getContext(), d0, d1, d2); |
| 3857 | + reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply( |
| 3858 | + b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]}); |
| 3859 | + } |
| 3860 | + return success(); |
| 3861 | +} |
| 3862 | + |
3796 | 3863 | void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
3797 | 3864 | MLIRContext *context) {
|
3798 | 3865 | results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
|
3799 | 3866 | FoldOrthogonalPaddings, FoldStaticPadding,
|
3800 |
| - FoldConsecutiveConstantPadding>(context); |
| 3867 | + FoldConsecutiveConstantPadding, FoldReifiedShape>(context); |
3801 | 3868 | }
|
3802 | 3869 |
|
3803 | 3870 | /// Return the padding value of the PadOp if it constant. In this context,
|
|
0 commit comments