Skip to content

Commit 8e7e9d4

Browse files
authored
[MemoryLocation] Support strided matrix loads / stores (#163368)
This patch provides an approximation of the memory locations touched by `llvm.matrix.column.major.load` and `llvm.matrix.column.major.store`, enabling GVN to remove redundant loads and dead store elimination (DSE) to remove dead stores. PR: #163368
1 parent 8c04420 commit 8e7e9d4

File tree

3 files changed

+33
-15
lines changed

3 files changed

+33
-15
lines changed

llvm/lib/Analysis/MemoryLocation.cpp

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -288,6 +288,34 @@ MemoryLocation MemoryLocation::getForArgument(const CallBase *Call,
288288
LocationSize::precise(DL.getTypeStoreSize(
289289
II->getArgOperand(1)->getType())),
290290
AATags);
291+
case Intrinsic::matrix_column_major_load:
292+
case Intrinsic::matrix_column_major_store: {
293+
bool IsLoad = II->getIntrinsicID() == Intrinsic::matrix_column_major_load;
294+
assert(ArgIdx == (IsLoad ? 0 : 1) && "Invalid argument index");
295+
296+
auto *Stride = dyn_cast<ConstantInt>(II->getArgOperand(IsLoad ? 1 : 2));
297+
uint64_t Rows =
298+
cast<ConstantInt>(II->getArgOperand(IsLoad ? 3 : 4))->getZExtValue();
299+
uint64_t Cols =
300+
cast<ConstantInt>(II->getArgOperand(IsLoad ? 4 : 5))->getZExtValue();
301+
302+
// The stride is dynamic, so there's nothing we can say.
303+
if (!Stride)
304+
return MemoryLocation(Arg, LocationSize::afterPointer(), AATags);
305+
306+
uint64_t ConstStride = Stride->getZExtValue();
307+
auto *VT = cast<VectorType>(IsLoad ? II->getType()
308+
: II->getArgOperand(0)->getType());
309+
assert(Cols != 0 && "Matrix cannot have 0 columns");
310+
TypeSize Size = DL.getTypeAllocSize(VT->getScalarType()) *
311+
(ConstStride * (Cols - 1) + Rows);
312+
313+
// In the unstrided case, we have a precise size, ...
314+
if (ConstStride == Rows)
315+
return MemoryLocation(Arg, LocationSize::precise(Size), AATags);
316+
// otherwise we merely obtain an upper bound.
317+
return MemoryLocation(Arg, LocationSize::upperBound(Size), AATags);
318+
}
291319
}
292320

293321
assert(

llvm/test/Transforms/DeadStoreElimination/matrix-intrinsics.ll

Lines changed: 1 addition & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -5,8 +5,8 @@ define void @dead_unstrided_store_non_matrix_load(ptr noalias %src, ptr noalias
55
; CHECK-LABEL: define void @dead_unstrided_store_non_matrix_load(
66
; CHECK-SAME: ptr noalias [[SRC:%.*]], ptr noalias [[DST:%.*]]) {
77
; CHECK-NEXT: [[ENTRY:.*:]]
8-
; CHECK-NEXT: call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> zeroinitializer, ptr [[DST]], i32 4, i1 false, i32 4, i32 2)
98
; CHECK-NEXT: [[L:%.*]] = load double, ptr [[SRC]], align 8
9+
; CHECK-NEXT: call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> zeroinitializer, ptr [[DST]], i32 4, i1 false, i32 4, i32 2)
1010
; CHECK-NEXT: ret void
1111
;
1212
entry:
@@ -173,7 +173,6 @@ define void @dead_unstrided_store(ptr noalias %src, ptr noalias %dst) {
173173
; CHECK-LABEL: define void @dead_unstrided_store(
174174
; CHECK-SAME: ptr noalias [[SRC:%.*]], ptr noalias [[DST:%.*]]) {
175175
; CHECK-NEXT: [[ENTRY:.*:]]
176-
; CHECK-NEXT: call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> zeroinitializer, ptr [[DST]], i32 4, i1 false, i32 4, i32 2)
177176
; CHECK-NEXT: [[L:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC]], i32 4, i1 false, i32 4, i32 2)
178177
; CHECK-NEXT: call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> [[L]], ptr [[DST]], i32 4, i1 false, i32 4, i32 2)
179178
; CHECK-NEXT: ret void
@@ -241,7 +240,6 @@ define void @dead_matrix_store_non_matrix_overwrite_unstrided(ptr noalias %src,
241240
; CHECK-LABEL: define void @dead_matrix_store_non_matrix_overwrite_unstrided(
242241
; CHECK-SAME: ptr noalias [[SRC:%.*]], ptr noalias [[DST:%.*]]) {
243242
; CHECK-NEXT: [[ENTRY:.*:]]
244-
; CHECK-NEXT: call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> zeroinitializer, ptr [[DST]], i32 4, i1 false, i32 4, i32 2)
245243
; CHECK-NEXT: [[L:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC]], i32 4, i1 false, i32 4, i32 2)
246244
; CHECK-NEXT: store <8 x double> zeroinitializer, ptr [[DST]], align 64
247245
; CHECK-NEXT: ret void
@@ -257,7 +255,6 @@ define void @dead_matrix_store_non_matrix_overwrite_strided(ptr noalias %src, pt
257255
; CHECK-LABEL: define void @dead_matrix_store_non_matrix_overwrite_strided(
258256
; CHECK-SAME: ptr noalias [[SRC:%.*]], ptr noalias [[DST:%.*]]) {
259257
; CHECK-NEXT: [[ENTRY:.*:]]
260-
; CHECK-NEXT: call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> zeroinitializer, ptr [[DST]], i32 4, i1 false, i32 4, i32 2)
261258
; CHECK-NEXT: [[L:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC]], i32 8, i1 false, i32 4, i32 2)
262259
; CHECK-NEXT: store <16 x double> zeroinitializer, ptr [[DST]], align 128
263260
; CHECK-NEXT: ret void
@@ -289,7 +286,6 @@ define void @live_matrix_store_non_matrix_overwrite_strided(ptr noalias %src, pt
289286
; CHECK-LABEL: define void @live_matrix_store_non_matrix_overwrite_strided(
290287
; CHECK-SAME: ptr noalias [[SRC:%.*]], ptr noalias [[DST:%.*]]) {
291288
; CHECK-NEXT: [[ENTRY:.*:]]
292-
; CHECK-NEXT: call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> zeroinitializer, ptr [[DST]], i32 4, i1 false, i32 4, i32 2)
293289
; CHECK-NEXT: [[L:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC]], i32 8, i1 false, i32 4, i32 2)
294290
; CHECK-NEXT: store <8 x double> zeroinitializer, ptr [[DST]], align 64
295291
; CHECK-NEXT: ret void
@@ -305,8 +301,6 @@ define void @dead_matrix_store_dimension_change(ptr noalias %src, ptr noalias %d
305301
; CHECK-LABEL: define void @dead_matrix_store_dimension_change(
306302
; CHECK-SAME: ptr noalias [[SRC:%.*]], ptr noalias [[DST:%.*]]) {
307303
; CHECK-NEXT: [[ENTRY:.*:]]
308-
; CHECK-NEXT: [[L:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC]], i32 8, i1 false, i32 4, i32 2)
309-
; CHECK-NEXT: call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> [[L]], ptr [[DST]], i32 4, i1 false, i32 4, i32 2)
310304
; CHECK-NEXT: call void @llvm.matrix.column.major.store.v9f64.i32(<9 x double> zeroinitializer, ptr [[DST]], i32 3, i1 false, i32 3, i32 3)
311305
; CHECK-NEXT: ret void
312306
;

llvm/test/Transforms/GVN/matrix-intrinsics.ll

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -8,9 +8,8 @@ define void @redundant_unstrided_load(ptr %src) {
88
; CHECK-NEXT: [[SRC_OFFSET:%.*]] = getelementptr inbounds double, ptr [[SRC]], i32 8
99
; CHECK-NEXT: [[L:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC_OFFSET]], i32 4, i1 false, i32 4, i32 2)
1010
; CHECK-NEXT: call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> [[L]], ptr [[SRC]], i32 4, i1 false, i32 4, i32 2)
11-
; CHECK-NEXT: [[L_2:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC_OFFSET]], i32 4, i1 false, i32 4, i32 2)
1211
; CHECK-NEXT: call void @use(<8 x double> [[L]])
13-
; CHECK-NEXT: call void @use(<8 x double> [[L_2]])
12+
; CHECK-NEXT: call void @use(<8 x double> [[L]])
1413
; CHECK-NEXT: ret void
1514
;
1615
entry:
@@ -30,9 +29,8 @@ define void @redundant_unstrided_load_non_matrix_store(ptr %src) {
3029
; CHECK-NEXT: [[SRC_OFFSET:%.*]] = getelementptr inbounds double, ptr [[SRC]], i32 1
3130
; CHECK-NEXT: [[L:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC_OFFSET]], i32 4, i1 false, i32 4, i32 2)
3231
; CHECK-NEXT: store double 4.200000e+01, ptr [[SRC]], align 8
33-
; CHECK-NEXT: [[L_2:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC_OFFSET]], i32 4, i1 false, i32 4, i32 2)
3432
; CHECK-NEXT: call void @use(<8 x double> [[L]])
35-
; CHECK-NEXT: call void @use(<8 x double> [[L_2]])
33+
; CHECK-NEXT: call void @use(<8 x double> [[L]])
3634
; CHECK-NEXT: ret void
3735
;
3836
entry:
@@ -52,9 +50,8 @@ define void @redundant_strided_load(ptr %src) {
5250
; CHECK-NEXT: [[SRC_OFFSET:%.*]] = getelementptr inbounds double, ptr [[SRC]], i32 16
5351
; CHECK-NEXT: [[L:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC_OFFSET]], i32 8, i1 false, i32 4, i32 2)
5452
; CHECK-NEXT: call void @llvm.matrix.column.major.store.v8f64.i32(<8 x double> [[L]], ptr [[SRC]], i32 8, i1 false, i32 4, i32 2)
55-
; CHECK-NEXT: [[L_2:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC_OFFSET]], i32 8, i1 false, i32 4, i32 2)
5653
; CHECK-NEXT: call void @use(<8 x double> [[L]])
57-
; CHECK-NEXT: call void @use(<8 x double> [[L_2]])
54+
; CHECK-NEXT: call void @use(<8 x double> [[L]])
5855
; CHECK-NEXT: ret void
5956
;
6057
entry:
@@ -75,9 +72,8 @@ define void @redundant_strided_load_non_matrix_store(ptr %src) {
7572
; CHECK-NEXT: [[SRC_OFFSET:%.*]] = getelementptr inbounds double, ptr [[SRC]], i32 16
7673
; CHECK-NEXT: [[L:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC_OFFSET]], i32 8, i1 false, i32 4, i32 2)
7774
; CHECK-NEXT: store double 4.200000e+01, ptr [[SRC]], align 8
78-
; CHECK-NEXT: [[L_2:%.*]] = call <8 x double> @llvm.matrix.column.major.load.v8f64.i32(ptr [[SRC_OFFSET]], i32 8, i1 false, i32 4, i32 2)
7975
; CHECK-NEXT: call void @use(<8 x double> [[L]])
80-
; CHECK-NEXT: call void @use(<8 x double> [[L_2]])
76+
; CHECK-NEXT: call void @use(<8 x double> [[L]])
8177
; CHECK-NEXT: ret void
8278
;
8379
entry:

0 commit comments

Comments
 (0)