@@ -1770,13 +1770,39 @@ class StridedSliceConstantMaskFolder final
17701770 }
17711771};
17721772
1773+ // Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
1774+ class StridedSliceConstantFolder final
1775+ : public OpRewritePattern<ExtractStridedSliceOp> {
1776+ public:
1777+ using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
1778+
1779+ LogicalResult matchAndRewrite (ExtractStridedSliceOp extractStridedSliceOp,
1780+ PatternRewriter &rewriter) const override {
1781+ // Return if 'extractStridedSliceOp' operand is not defined by a
1782+ // ConstantOp.
1783+ auto constantOp =
1784+ extractStridedSliceOp.vector ().getDefiningOp <ConstantOp>();
1785+ if (!constantOp)
1786+ return failure ();
1787+ auto dense = constantOp.value ().dyn_cast <SplatElementsAttr>();
1788+ if (!dense)
1789+ return failure ();
1790+ auto newAttr = DenseElementsAttr::get (
1791+ extractStridedSliceOp.getType ().cast <VectorType>(),
1792+ dense.getSplatValue ());
1793+ rewriter.replaceOpWithNewOp <ConstantOp>(extractStridedSliceOp, newAttr);
1794+ return success ();
1795+ }
1796+ };
1797+
17731798} // end anonymous namespace
17741799
17751800void ExtractStridedSliceOp::getCanonicalizationPatterns (
17761801 OwningRewritePatternList &results, MLIRContext *context) {
17771802 // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
1778- // ConstantMaskOp.
1779- results.insert <StridedSliceConstantMaskFolder>(context);
1803+ // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
1804+ results.insert <StridedSliceConstantMaskFolder, StridedSliceConstantFolder>(
1805+ context);
17801806}
17811807
17821808// ===----------------------------------------------------------------------===//
@@ -2560,6 +2586,36 @@ OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
25602586 return {};
25612587}
25622588
2589+ namespace {
2590+ // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
2591+ class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
2592+ public:
2593+ using OpRewritePattern<ShapeCastOp>::OpRewritePattern;
2594+
2595+ LogicalResult matchAndRewrite (ShapeCastOp shapeCastOp,
2596+ PatternRewriter &rewriter) const override {
2597+ auto constantOp = shapeCastOp.source ().getDefiningOp <ConstantOp>();
2598+ if (!constantOp)
2599+ return failure ();
2600+ // Only handle splat for now.
2601+ auto dense = constantOp.value ().dyn_cast <SplatElementsAttr>();
2602+ if (!dense)
2603+ return failure ();
2604+ auto newAttr = DenseElementsAttr::get (
2605+ shapeCastOp.getType ().cast <VectorType>(), dense.getSplatValue ());
2606+ rewriter.replaceOpWithNewOp <ConstantOp>(shapeCastOp, newAttr);
2607+ return success ();
2608+ }
2609+ };
2610+
2611+ } // namespace
2612+
2613+ void ShapeCastOp::getCanonicalizationPatterns (OwningRewritePatternList &results,
2614+ MLIRContext *context) {
2615+ // Pattern to rewrite a ShapeCastOp(ConstantOp) -> ConstantOp.
2616+ results.insert <ShapeCastConstantFolder>(context);
2617+ }
2618+
25632619// ===----------------------------------------------------------------------===//
25642620// VectorBitCastOp
25652621// ===----------------------------------------------------------------------===//
0 commit comments