Skip to content

Commit 24a2b97

Browse files
committed
Address comments
1 parent 3e827e0 commit 24a2b97

File tree

4 files changed

+55
-82
lines changed

4 files changed

+55
-82
lines changed

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,14 @@ SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc, Value src,
6969
/// function is used to combine multiple values into a single value.
7070
Value composeValue(OpBuilder &builder, Location loc, ValueRange src,
7171
Type dstType);
72+
73+
/// Performs the index computation to get to the element at `indices` of the
74+
/// memory pointed to by `memRefDesc`, using the layout map of `type`.
75+
/// The indices are linearized as:
76+
/// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
77+
Value getStridedElementPtr(OpBuilder &builder, Location loc,
78+
const LLVMTypeConverter &converter, MemRefType type,
79+
Value memRefDesc, ValueRange indices);
7280
} // namespace LLVM
7381

7482
/// Base class for operation conversions targeting the LLVM IR dialect. It
@@ -107,8 +115,8 @@ class ConvertToLLVMPattern : public ConversionPattern {
107115
static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
108116
Type resultType, int64_t value);
109117

110-
// This is a strided getElementPtr variant that linearizes subscripts as:
111-
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
118+
/// Convenience wrapper for the corresponding helper utility.
119+
/// This is a strided getElementPtr variant with linearized subscripts.
112120
Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
113121
ValueRange indices,
114122
ConversionPatternRewriter &rewriter) const;

mlir/include/mlir/Dialect/AMX/AMX.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
//
2626
//===----------------------------------------------------------------------===//
2727

28-
#ifndef AMX_OPS
29-
#define AMX_OPS
28+
#ifndef AMX
29+
#define AMX
3030

3131
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
3232
include "mlir/Dialect/AMX/AMXInterfaces.td"
@@ -371,4 +371,4 @@ def TileMulIOp : AMX_Op<"tile_muli", [Pure,
371371
let hasVerifier = 1;
372372
}
373373

374-
#endif // AMX_OPS
374+
#endif // AMX

mlir/lib/Conversion/LLVMCommon/Pattern.cpp

Lines changed: 38 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -61,38 +61,8 @@ Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder,
6161
Value ConvertToLLVMPattern::getStridedElementPtr(
6262
Location loc, MemRefType type, Value memRefDesc, ValueRange indices,
6363
ConversionPatternRewriter &rewriter) const {
64-
65-
auto [strides, offset] = type.getStridesAndOffset();
66-
67-
MemRefDescriptor memRefDescriptor(memRefDesc);
68-
// Use a canonical representation of the start address so that later
69-
// optimizations have a longer sequence of instructions to CSE.
70-
// If we don't do that we would sprinkle the memref.offset in various
71-
// positions of the different address computations.
72-
Value base =
73-
memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), type);
74-
75-
Type indexType = getIndexType();
76-
Value index;
77-
for (int i = 0, e = indices.size(); i < e; ++i) {
78-
Value increment = indices[i];
79-
if (strides[i] != 1) { // Skip if stride is 1.
80-
Value stride =
81-
ShapedType::isDynamic(strides[i])
82-
? memRefDescriptor.stride(rewriter, loc, i)
83-
: createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
84-
increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
85-
}
86-
index =
87-
index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
88-
}
89-
90-
Type elementPtrType = memRefDescriptor.getElementPtrType();
91-
return index ? rewriter.create<LLVM::GEPOp>(
92-
loc, elementPtrType,
93-
getTypeConverter()->convertType(type.getElementType()),
94-
base, index)
95-
: base;
64+
return LLVM::getStridedElementPtr(rewriter, loc, *getTypeConverter(), type,
65+
memRefDesc, indices);
9666
}
9767

9868
// Check if the MemRefType `type` is supported by the lowering. We currently
@@ -512,3 +482,39 @@ Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src,
512482

513483
return res;
514484
}
485+
486+
Value mlir::LLVM::getStridedElementPtr(OpBuilder &builder, Location loc,
487+
const LLVMTypeConverter &converter,
488+
MemRefType type, Value memRefDesc,
489+
ValueRange indices) {
490+
auto [strides, offset] = type.getStridesAndOffset();
491+
492+
MemRefDescriptor memRefDescriptor(memRefDesc);
493+
// Use a canonical representation of the start address so that later
494+
// optimizations have a longer sequence of instructions to CSE.
495+
// If we don't do that we would sprinkle the memref.offset in various
496+
// positions of the different address computations.
497+
Value base = memRefDescriptor.bufferPtr(builder, loc, converter, type);
498+
499+
Type indexType = converter.getIndexType();
500+
Value index;
501+
for (int i = 0, e = indices.size(); i < e; ++i) {
502+
Value increment = indices[i];
503+
if (strides[i] != 1) { // Skip if stride is 1.
504+
Value stride =
505+
ShapedType::isDynamic(strides[i])
506+
? memRefDescriptor.stride(builder, loc, i)
507+
: builder.create<LLVM::ConstantOp>(
508+
loc, indexType, builder.getIndexAttr(strides[i]));
509+
increment = builder.create<LLVM::MulOp>(loc, increment, stride);
510+
}
511+
index =
512+
index ? builder.create<LLVM::AddOp>(loc, index, increment) : increment;
513+
}
514+
515+
Type elementPtrType = memRefDescriptor.getElementPtrType();
516+
return index ? builder.create<LLVM::GEPOp>(
517+
loc, elementPtrType,
518+
converter.convertType(type.getElementType()), base, index)
519+
: base;
520+
}

mlir/lib/Dialect/AMX/IR/AMXDialect.cpp

Lines changed: 4 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -64,46 +64,6 @@ static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
6464
return success();
6565
}
6666

67-
/// Get pointer to a memref descriptor.
68-
/// Optionally, the base pointer can be offset using linearized index computed
69-
/// from the given indices.
70-
static Value getBufferPtr(Location loc, MemRefType type, Value buffer,
71-
ValueRange indices,
72-
const LLVMTypeConverter &typeConverter,
73-
RewriterBase &rewriter) {
74-
auto [strides, offset] = type.getStridesAndOffset();
75-
76-
MemRefDescriptor memRefDescriptor(buffer);
77-
Value base = memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type);
78-
79-
int numIndices = indices.size();
80-
if (numIndices == 0)
81-
return base;
82-
83-
assert(type.getRank() == numIndices &&
84-
"expects number of indices equal to memref rank");
85-
Value index;
86-
Type indexType = typeConverter.getIndexType();
87-
for (int i = 0; i < numIndices; ++i) {
88-
Value increment = indices[i];
89-
if (strides[i] != 1) { // Skip if stride is 1.
90-
Value stride =
91-
ShapedType::isDynamic(strides[i])
92-
? memRefDescriptor.stride(rewriter, loc, i)
93-
: rewriter.create<LLVM::ConstantOp>(
94-
loc, indexType, rewriter.getIndexAttr(strides[i]));
95-
increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
96-
}
97-
index =
98-
index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
99-
}
100-
101-
Type elementPtrType = memRefDescriptor.getElementPtrType();
102-
return rewriter.create<LLVM::GEPOp>(
103-
loc, elementPtrType, typeConverter.convertType(type.getElementType()),
104-
base, index);
105-
}
106-
10767
/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
10868
/// dimension directly translates into the number of rows of the tiles.
10969
/// The second dimension needs to be scaled by the number of bytes.
@@ -122,7 +82,6 @@ static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
12282

12383
/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
12484
/// shape may "envelop" the actual tile shape, and may be dynamically sized.
125-
/// Returns failure if proper stride couldn't be found.
12685
static Value getStride(Location loc, MemRefType mType, Value base,
12786
RewriterBase &rewriter) {
12887
assert(mType.getRank() >= 2 && "Invalid shape for AMX strides");
@@ -184,8 +143,8 @@ amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
184143
SmallVector<Value> intrinsicOperands;
185144
intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
186145
intrinsicOperands.push_back(
187-
getBufferPtr(loc, getMemRefType(), adaptor.getBase(),
188-
adaptor.getIndices(), typeConverter, rewriter));
146+
LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
147+
adaptor.getBase(), adaptor.getIndices()));
189148
intrinsicOperands.push_back(
190149
getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
191150

@@ -217,8 +176,8 @@ amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
217176
SmallVector<Value> intrinsicOperands;
218177
intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
219178
intrinsicOperands.push_back(
220-
getBufferPtr(loc, getMemRefType(), adaptor.getBase(),
221-
adaptor.getIndices(), typeConverter, rewriter));
179+
LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
180+
adaptor.getBase(), adaptor.getIndices()));
222181
intrinsicOperands.push_back(
223182
getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
224183
intrinsicOperands.push_back(adaptor.getVal());

0 commit comments

Comments
 (0)