Skip to content

Commit c988bf8

Browse files
authored
[mlir][memref] Canonicalize memref.reinterpret_cast when offset/sizes/strides are constants. (#163505)
Implement folding logic to canonicalize memref.reinterpret_cast ops when offset, sizes and strides are compile-time constants. This removes dynamic shape annotations and produces a static memref form, allowing further lowering and backend optimizations.
1 parent 50d65a5 commit c988bf8

File tree

2 files changed

+56
-10
lines changed

2 files changed

+56
-10
lines changed

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2158,11 +2158,45 @@ struct ReinterpretCastOpExtractStridedMetadataFolder
21582158
return success();
21592159
}
21602160
};
2161+
2162+
struct ReinterpretCastOpConstantFolder
2163+
: public OpRewritePattern<ReinterpretCastOp> {
2164+
public:
2165+
using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
2166+
2167+
LogicalResult matchAndRewrite(ReinterpretCastOp op,
2168+
PatternRewriter &rewriter) const override {
2169+
unsigned srcStaticCount = llvm::count_if(
2170+
llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
2171+
op.getMixedStrides()),
2172+
[](OpFoldResult ofr) { return isa<Attribute>(ofr); });
2173+
2174+
SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()};
2175+
SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
2176+
SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();
2177+
2178+
// TODO: Using counting comparison instead of direct comparison because
2179+
// getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns
2180+
// IntegerAttrs, while constifyIndexValues (and therefore
2181+
// ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs.
2182+
if (srcStaticCount ==
2183+
llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
2184+
[](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
2185+
return failure();
2186+
2187+
auto newReinterpretCast = ReinterpretCastOp::create(
2188+
rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);
2189+
2190+
rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast);
2191+
return success();
2192+
}
2193+
};
21612194
} // namespace
21622195

21632196
void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
21642197
MLIRContext *context) {
2165-
results.add<ReinterpretCastOpExtractStridedMetadataFolder>(context);
2198+
results.add<ReinterpretCastOpExtractStridedMetadataFolder,
2199+
ReinterpretCastOpConstantFolder>(context);
21662200
}
21672201

21682202
FailureOr<std::optional<SmallVector<Value>>>

mlir/test/Dialect/MemRef/canonicalize.mlir

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -911,6 +911,21 @@ func.func @reinterpret_noop(%arg : memref<2x3x4xf32>) -> memref<2x3x4xf32> {
911911

912912
// -----
913913

914+
// CHECK-LABEL: func @reinterpret_constant_fold
915+
// CHECK-SAME: (%[[ARG:.*]]: memref<f32>)
916+
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [100, 100], strides: [100, 1]
917+
// CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
918+
// CHECK: return %[[CAST]]
919+
func.func @reinterpret_constant_fold(%arg0: memref<f32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
920+
%c0 = arith.constant 0 : index
921+
%c1 = arith.constant 1 : index
922+
%c100 = arith.constant 100 : index
923+
%reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [%c0], sizes: [%c100, %c100], strides: [%c100, %c1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
924+
return %reinterpret_cast : memref<?x?xf32, strided<[?, ?], offset: ?>>
925+
}
926+
927+
// -----
928+
914929
// CHECK-LABEL: func @reinterpret_of_reinterpret
915930
// CHECK-SAME: (%[[ARG:.*]]: memref<?xi8>, %[[SIZE1:.*]]: index, %[[SIZE2:.*]]: index)
916931
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [%[[SIZE2]]], strides: [1]
@@ -996,10 +1011,9 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref<?x?x
9961011
// when the strides don't match.
9971012
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_stride
9981013
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
999-
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
1000-
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
1001-
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [%[[C0]]], sizes: [4, 2, 2], strides: [1, 1, %[[C1]]]
1002-
// CHECK: return %[[RES]]
1014+
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [0], sizes: [4, 2, 2], strides: [1, 1, 1]
1015+
// CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
1016+
// CHECK: return %[[CAST]]
10031017
func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>> {
10041018
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
10051019
%m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref<f32> to memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
@@ -1011,11 +1025,9 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me
10111025
// when the offset doesn't match.
10121026
// CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_different_offset
10131027
// CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>)
1014-
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
1015-
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
1016-
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
1017-
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]]
1018-
// CHECK: return %[[RES]]
1028+
// CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [8, 2], strides: [2, 1]
1029+
// CHECK: %[[CAST:.*]] = memref.cast %[[RES]]
1030+
// CHECK: return %[[CAST]]
10191031
func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref<?x?xf32, strided<[?, ?], offset: ?>> {
10201032
%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref<f32>, index, index, index, index, index
10211033
%m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>

0 commit comments

Comments
 (0)