Skip to content

Commit 6086c27

Browse files
[mlir][memref] Fix out-of-bounds crash when reifying result dims (#70774)
Do not crash when the input IR is invalid, i.e., when the index of the dimension operand of a `tensor.dim`/`memref.dim` is out-of-bounds. This fixes #70180.
1 parent dbd4a0d commit 6086c27

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

Diff for: mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ struct DimOfReifyRankedShapedTypeOpInterface : public OpRewritePattern<OpTy> {
9494
reifiedResultShapes)))
9595
return failure();
9696
unsigned resultNumber = dimValue.getResultNumber();
97+
// Do not apply pattern if the IR is invalid (dim out of bounds).
98+
if (*dimIndex >= reifiedResultShapes[resultNumber].size())
99+
return rewriter.notifyMatchFailure(dimOp, "dimension is out of bounds");
97100
Value replacement = getValueOrCreateConstantIndexOp(
98101
rewriter, dimOp.getLoc(), reifiedResultShapes[resultNumber][*dimIndex]);
99102
rewriter.replaceOp(dimOp, replacement);

Diff for: mlir/test/Dialect/MemRef/resolve-dim-ops.mlir

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// RUN: mlir-opt --resolve-ranked-shaped-type-result-dims --split-input-file %s | FileCheck %s
2+
3+
// CHECK-LABEL: func @dim_out_of_bounds(
4+
// CHECK-NEXT: arith.constant
5+
// CHECK-NEXT: memref.dim
6+
// CHECK-NEXT: return
7+
func.func @dim_out_of_bounds(%m : memref<7x8xf32>) -> index {
8+
%idx = arith.constant 7 : index
9+
%0 = memref.dim %m, %idx : memref<7x8xf32>
10+
return %0 : index
11+
}
12+
13+
// -----
14+
15+
// CHECK-LABEL: func @dim_out_of_bounds_2(
16+
// CHECK-NEXT: arith.constant
17+
// CHECK-NEXT: arith.constant
18+
// CHECK-NEXT: bufferization.alloc_tensor
19+
// CHECK-NEXT: tensor.dim
20+
// CHECK-NEXT: return
21+
func.func @dim_out_of_bounds_2(%idx1 : index, %idx2 : index) -> index {
22+
%idx = arith.constant 7 : index
23+
%sz = arith.constant 5 : index
24+
%alloc = bufferization.alloc_tensor(%sz, %sz) : tensor<?x?xf32>
25+
%0 = tensor.dim %alloc, %idx : tensor<?x?xf32>
26+
return %0 : index
27+
}

0 commit comments

Comments
 (0)