Skip to content

Commit 3648065

Browse files
committed
[mlir][vector] Add canonicalization patterns for ExtractStride/ShapeCast + Splat constant
Differential Revision: https://reviews.llvm.org/D90567
1 parent e969ab4 commit 3648065

File tree

3 files changed

+93
-2
lines changed

3 files changed

+93
-2
lines changed

mlir/include/mlir/Dialect/Vector/VectorOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1649,6 +1649,7 @@ def Vector_ShapeCastOp :
16491649
}];
16501650
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
16511651
let hasFolder = 1;
1652+
let hasCanonicalizer = 1;
16521653
}
16531654

16541655
def Vector_BitCastOp :

mlir/lib/Dialect/Vector/VectorOps.cpp

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1770,13 +1770,39 @@ class StridedSliceConstantMaskFolder final
17701770
}
17711771
};
17721772

1773+
// Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
1774+
class StridedSliceConstantFolder final
1775+
: public OpRewritePattern<ExtractStridedSliceOp> {
1776+
public:
1777+
using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
1778+
1779+
LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
1780+
PatternRewriter &rewriter) const override {
1781+
// Return if 'extractStridedSliceOp' operand is not defined by a
1782+
// ConstantOp.
1783+
auto constantOp =
1784+
extractStridedSliceOp.vector().getDefiningOp<ConstantOp>();
1785+
if (!constantOp)
1786+
return failure();
1787+
auto dense = constantOp.value().dyn_cast<SplatElementsAttr>();
1788+
if (!dense)
1789+
return failure();
1790+
auto newAttr = DenseElementsAttr::get(
1791+
extractStridedSliceOp.getType().cast<VectorType>(),
1792+
dense.getSplatValue());
1793+
rewriter.replaceOpWithNewOp<ConstantOp>(extractStridedSliceOp, newAttr);
1794+
return success();
1795+
}
1796+
};
1797+
17731798
} // end anonymous namespace
17741799

17751800
void ExtractStridedSliceOp::getCanonicalizationPatterns(
17761801
OwningRewritePatternList &results, MLIRContext *context) {
17771802
// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
1778-
// ConstantMaskOp.
1779-
results.insert<StridedSliceConstantMaskFolder>(context);
1803+
// ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
1804+
results.insert<StridedSliceConstantMaskFolder, StridedSliceConstantFolder>(
1805+
context);
17801806
}
17811807

17821808
//===----------------------------------------------------------------------===//
@@ -2560,6 +2586,36 @@ OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
25602586
return {};
25612587
}
25622588

2589+
namespace {
2590+
// Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
2591+
class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
2592+
public:
2593+
using OpRewritePattern<ShapeCastOp>::OpRewritePattern;
2594+
2595+
LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
2596+
PatternRewriter &rewriter) const override {
2597+
auto constantOp = shapeCastOp.source().getDefiningOp<ConstantOp>();
2598+
if (!constantOp)
2599+
return failure();
2600+
// Only handle splat for now.
2601+
auto dense = constantOp.value().dyn_cast<SplatElementsAttr>();
2602+
if (!dense)
2603+
return failure();
2604+
auto newAttr = DenseElementsAttr::get(
2605+
shapeCastOp.getType().cast<VectorType>(), dense.getSplatValue());
2606+
rewriter.replaceOpWithNewOp<ConstantOp>(shapeCastOp, newAttr);
2607+
return success();
2608+
}
2609+
};
2610+
2611+
} // namespace
2612+
2613+
void ShapeCastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2614+
MLIRContext *context) {
2615+
// Pattern to rewrite a ShapeCastOp(ConstantOp) -> ConstantOp.
2616+
results.insert<ShapeCastConstantFolder>(context);
2617+
}
2618+
25632619
//===----------------------------------------------------------------------===//
25642620
// VectorBitCastOp
25652621
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,3 +580,37 @@ func @broadcast_folding2() -> vector<4x16xi32> {
580580
%2 = vector.broadcast %1 : vector<16xi32> to vector<4x16xi32>
581581
return %2 : vector<4x16xi32>
582582
}
583+
584+
// -----
585+
586+
// CHECK-LABEL: shape_cast_constant
587+
// CHECK: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<20x2xf32>
588+
// CHECK: %[[CST1:.*]] = constant dense<1> : vector<3x4x2xi32>
589+
// CHECK: return %[[CST0]], %[[CST1]] : vector<20x2xf32>, vector<3x4x2xi32>
590+
func @shape_cast_constant() -> (vector<20x2xf32>, vector<3x4x2xi32>) {
591+
%cst = constant dense<2.000000e+00> : vector<5x4x2xf32>
592+
%cst_1 = constant dense<1> : vector<12x2xi32>
593+
%0 = vector.shape_cast %cst : vector<5x4x2xf32> to vector<20x2xf32>
594+
%1 = vector.shape_cast %cst_1 : vector<12x2xi32> to vector<3x4x2xi32>
595+
return %0, %1 : vector<20x2xf32>, vector<3x4x2xi32>
596+
}
597+
598+
// -----
599+
600+
// CHECK-LABEL: extract_strided_constant
601+
// CHECK: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<12x2xf32>
602+
// CHECK: %[[CST1:.*]] = constant dense<1> : vector<2x13x3xi32>
603+
// CHECK: return %[[CST0]], %[[CST1]] : vector<12x2xf32>, vector<2x13x3xi32>
604+
func @extract_strided_constant() -> (vector<12x2xf32>, vector<2x13x3xi32>) {
605+
%cst = constant dense<2.000000e+00> : vector<29x7xf32>
606+
%cst_1 = constant dense<1> : vector<4x37x9xi32>
607+
%0 = vector.extract_strided_slice %cst
608+
{offsets = [2, 3], sizes = [12, 2], strides = [1, 1]}
609+
: vector<29x7xf32> to vector<12x2xf32>
610+
%1 = vector.extract_strided_slice %cst_1
611+
{offsets = [1, 2, 5], sizes = [2, 13, 3], strides = [1, 1, 1]}
612+
: vector<4x37x9xi32> to vector<2x13x3xi32>
613+
return %0, %1 : vector<12x2xf32>, vector<2x13x3xi32>
614+
}
615+
616+

0 commit comments

Comments
 (0)