@@ -64,46 +64,6 @@ static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
64
64
return success ();
65
65
}
66
66
67
- // / Get pointer to a memref descriptor.
68
- // / Optionally, the base pointer can be offset using linearized index computed
69
- // / from the given indices.
70
- static Value getBufferPtr (Location loc, MemRefType type, Value buffer,
71
- ValueRange indices,
72
- const LLVMTypeConverter &typeConverter,
73
- RewriterBase &rewriter) {
74
- auto [strides, offset] = type.getStridesAndOffset ();
75
-
76
- MemRefDescriptor memRefDescriptor (buffer);
77
- Value base = memRefDescriptor.bufferPtr (rewriter, loc, typeConverter, type);
78
-
79
- int numIndices = indices.size ();
80
- if (numIndices == 0 )
81
- return base;
82
-
83
- assert (type.getRank () == numIndices &&
84
- " expects number of indices equal to memref rank" );
85
- Value index;
86
- Type indexType = typeConverter.getIndexType ();
87
- for (int i = 0 ; i < numIndices; ++i) {
88
- Value increment = indices[i];
89
- if (strides[i] != 1 ) { // Skip if stride is 1.
90
- Value stride =
91
- ShapedType::isDynamic (strides[i])
92
- ? memRefDescriptor.stride (rewriter, loc, i)
93
- : rewriter.create <LLVM::ConstantOp>(
94
- loc, indexType, rewriter.getIndexAttr (strides[i]));
95
- increment = rewriter.create <LLVM::MulOp>(loc, increment, stride);
96
- }
97
- index =
98
- index ? rewriter.create <LLVM::AddOp>(loc, index, increment) : increment;
99
- }
100
-
101
- Type elementPtrType = memRefDescriptor.getElementPtrType ();
102
- return rewriter.create <LLVM::GEPOp>(
103
- loc, elementPtrType, typeConverter.convertType (type.getElementType ()),
104
- base, index);
105
- }
106
-
107
67
// / Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
108
68
// / dimension directly translates into the number of rows of the tiles.
109
69
// / The second dimensions needs to be scaled by the number of bytes.
@@ -122,7 +82,6 @@ static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
122
82
123
83
// / Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
124
84
// / shape may "envelop" the actual tile shape, and may be dynamically sized.
125
- // / Returns failure if proper stride couldn't be found.
126
85
static Value getStride (Location loc, MemRefType mType , Value base,
127
86
RewriterBase &rewriter) {
128
87
assert (mType .getRank () >= 2 && " Invalid shape for AMX strides" );
@@ -184,8 +143,8 @@ amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
184
143
SmallVector<Value> intrinsicOperands;
185
144
intrinsicOperands.append (getTileSizes (loc, getTileType (), rewriter));
186
145
intrinsicOperands.push_back (
187
- getBufferPtr ( loc, getMemRefType (), adaptor. getBase (),
188
- adaptor.getIndices (), typeConverter, rewriter ));
146
+ LLVM::getStridedElementPtr (rewriter, loc, typeConverter, getMemRefType (),
147
+ adaptor.getBase (), adaptor. getIndices () ));
189
148
intrinsicOperands.push_back (
190
149
getStride (loc, getMemRefType (), adaptor.getBase (), rewriter));
191
150
@@ -217,8 +176,8 @@ amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
217
176
SmallVector<Value> intrinsicOperands;
218
177
intrinsicOperands.append (getTileSizes (loc, getTileType (), rewriter));
219
178
intrinsicOperands.push_back (
220
- getBufferPtr ( loc, getMemRefType (), adaptor. getBase (),
221
- adaptor.getIndices (), typeConverter, rewriter ));
179
+ LLVM::getStridedElementPtr (rewriter, loc, typeConverter, getMemRefType (),
180
+ adaptor.getBase (), adaptor. getIndices () ));
222
181
intrinsicOperands.push_back (
223
182
getStride (loc, getMemRefType (), adaptor.getBase (), rewriter));
224
183
intrinsicOperands.push_back (adaptor.getVal ());
0 commit comments