Skip to content

Commit e90d2f2

Browse files
committed
Address comments
1 parent f9ca93b commit e90d2f2

File tree

4 files changed

+58
-82
lines changed

4 files changed

+58
-82
lines changed

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,15 @@ SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc, Value src,
6969
/// function is used to combine multiple values into a single value.
7070
Value composeValue(OpBuilder &builder, Location loc, ValueRange src,
7171
Type dstType);
72+
73+
/// Performs the index computation to get to the element at `indices` of the
74+
/// memory pointed to by `memRefDesc`, using the layout map of `type`.
75+
/// The indices are linearized as:
76+
/// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
77+
Value getStridedElementPtr(OpBuilder &builder, Location loc,
78+
const LLVMTypeConverter &converter, MemRefType type,
79+
Value memRefDesc, ValueRange indices,
80+
LLVM::GEPNoWrapFlags noWrapFlags = LLVM::GEPNoWrapFlags::none);
7281
} // namespace LLVM
7382

7483
/// Base class for operation conversions targeting the LLVM IR dialect. It
@@ -107,8 +116,8 @@ class ConvertToLLVMPattern : public ConversionPattern {
107116
static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
108117
Type resultType, int64_t value);
109118

110-
// This is a strided getElementPtr variant that linearizes subscripts as:
111-
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
119+
/// Convenience wrapper for the corresponding helper utility.
120+
/// This is a strided getElementPtr variant with linearized subscripts.
112121
Value getStridedElementPtr(
113122
ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
114123
Value memRefDesc, ValueRange indices,

mlir/include/mlir/Dialect/AMX/AMX.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
//
2626
//===----------------------------------------------------------------------===//
2727

28-
#ifndef AMX_OPS
29-
#define AMX_OPS
28+
#ifndef AMX
29+
#define AMX
3030

3131
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
3232
include "mlir/Dialect/AMX/AMXInterfaces.td"
@@ -371,4 +371,4 @@ def TileMulIOp : AMX_Op<"tile_muli", [Pure,
371371
let hasVerifier = 1;
372372
}
373373

374-
#endif // AMX_OPS
374+
#endif // AMX

mlir/lib/Conversion/LLVMCommon/Pattern.cpp

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -62,38 +62,8 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
6262
ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
6363
Value memRefDesc, ValueRange indices,
6464
LLVM::GEPNoWrapFlags noWrapFlags) const {
65-
66-
auto [strides, offset] = type.getStridesAndOffset();
67-
68-
MemRefDescriptor memRefDescriptor(memRefDesc);
69-
// Use a canonical representation of the start address so that later
70-
// optimizations have a longer sequence of instructions to CSE.
71-
// If we don't do that we would sprinkle the memref.offset in various
72-
// position of the different address computations.
73-
Value base =
74-
memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), type);
75-
76-
Type indexType = getIndexType();
77-
Value index;
78-
for (int i = 0, e = indices.size(); i < e; ++i) {
79-
Value increment = indices[i];
80-
if (strides[i] != 1) { // Skip if stride is 1.
81-
Value stride =
82-
ShapedType::isDynamic(strides[i])
83-
? memRefDescriptor.stride(rewriter, loc, i)
84-
: createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
85-
increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
86-
}
87-
index =
88-
index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
89-
}
90-
91-
Type elementPtrType = memRefDescriptor.getElementPtrType();
92-
return index ? rewriter.create<LLVM::GEPOp>(
93-
loc, elementPtrType,
94-
getTypeConverter()->convertType(type.getElementType()),
95-
base, index, noWrapFlags)
96-
: base;
65+
return LLVM::getStridedElementPtr(rewriter, loc, *getTypeConverter(), type,
66+
memRefDesc, indices, noWrapFlags);
9767
}
9868

9969
// Check if the MemRefType `type` is supported by the lowering. We currently
@@ -513,3 +483,41 @@ Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src,
513483

514484
return res;
515485
}
486+
487+
Value mlir::LLVM::getStridedElementPtr(OpBuilder &builder, Location loc,
488+
const LLVMTypeConverter &converter,
489+
MemRefType type, Value memRefDesc,
490+
ValueRange indices,
491+
LLVM::GEPNoWrapFlags noWrapFlags) {
492+
auto [strides, offset] = type.getStridesAndOffset();
493+
494+
MemRefDescriptor memRefDescriptor(memRefDesc);
495+
// Use a canonical representation of the start address so that later
496+
// optimizations have a longer sequence of instructions to CSE.
497+
// If we don't do that we would sprinkle the memref.offset in various
498+
// position of the different address computations.
499+
Value base = memRefDescriptor.bufferPtr(builder, loc, converter, type);
500+
501+
Type indexType = converter.getIndexType();
502+
Value index;
503+
for (int i = 0, e = indices.size(); i < e; ++i) {
504+
Value increment = indices[i];
505+
if (strides[i] != 1) { // Skip if stride is 1.
506+
Value stride =
507+
ShapedType::isDynamic(strides[i])
508+
? memRefDescriptor.stride(builder, loc, i)
509+
: builder.create<LLVM::ConstantOp>(
510+
loc, indexType, builder.getIndexAttr(strides[i]));
511+
increment = builder.create<LLVM::MulOp>(loc, increment, stride);
512+
}
513+
index =
514+
index ? builder.create<LLVM::AddOp>(loc, index, increment) : increment;
515+
}
516+
517+
Type elementPtrType = memRefDescriptor.getElementPtrType();
518+
return index ? builder.create<LLVM::GEPOp>(
519+
loc, elementPtrType,
520+
converter.convertType(type.getElementType()), base, index,
521+
noWrapFlags)
522+
: base;
523+
}

mlir/lib/Dialect/AMX/IR/AMXDialect.cpp

Lines changed: 4 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -64,46 +64,6 @@ static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
6464
return success();
6565
}
6666

67-
/// Get pointer to a memref descriptor.
68-
/// Optionally, the base pointer can be offset using linearized index computed
69-
/// from the given indices.
70-
static Value getBufferPtr(Location loc, MemRefType type, Value buffer,
71-
ValueRange indices,
72-
const LLVMTypeConverter &typeConverter,
73-
RewriterBase &rewriter) {
74-
auto [strides, offset] = type.getStridesAndOffset();
75-
76-
MemRefDescriptor memRefDescriptor(buffer);
77-
Value base = memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type);
78-
79-
int numIndices = indices.size();
80-
if (numIndices == 0)
81-
return base;
82-
83-
assert(type.getRank() == numIndices &&
84-
"expects number of indices equal to memref rank");
85-
Value index;
86-
Type indexType = typeConverter.getIndexType();
87-
for (int i = 0; i < numIndices; ++i) {
88-
Value increment = indices[i];
89-
if (strides[i] != 1) { // Skip if stride is 1.
90-
Value stride =
91-
ShapedType::isDynamic(strides[i])
92-
? memRefDescriptor.stride(rewriter, loc, i)
93-
: rewriter.create<LLVM::ConstantOp>(
94-
loc, indexType, rewriter.getIndexAttr(strides[i]));
95-
increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
96-
}
97-
index =
98-
index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
99-
}
100-
101-
Type elementPtrType = memRefDescriptor.getElementPtrType();
102-
return rewriter.create<LLVM::GEPOp>(
103-
loc, elementPtrType, typeConverter.convertType(type.getElementType()),
104-
base, index);
105-
}
106-
10767
/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first
10868
/// dimension directly translates into the number of rows of the tiles.
10969
/// The second dimensions needs to be scaled by the number of bytes.
@@ -122,7 +82,6 @@ static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
12282

12383
/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
12484
/// shape may "envelop" the actual tile shape, and may be dynamically sized.
125-
/// Returns failure if proper stride couldn't be found.
12685
static Value getStride(Location loc, MemRefType mType, Value base,
12786
RewriterBase &rewriter) {
12887
assert(mType.getRank() >= 2 && "Invalid shape for AMX strides");
@@ -184,8 +143,8 @@ amx::TileLoadOp::getIntrinsicOperands(ArrayRef<Value> operands,
184143
SmallVector<Value> intrinsicOperands;
185144
intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
186145
intrinsicOperands.push_back(
187-
getBufferPtr(loc, getMemRefType(), adaptor.getBase(),
188-
adaptor.getIndices(), typeConverter, rewriter));
146+
LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
147+
adaptor.getBase(), adaptor.getIndices()));
189148
intrinsicOperands.push_back(
190149
getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
191150

@@ -217,8 +176,8 @@ amx::TileStoreOp::getIntrinsicOperands(ArrayRef<Value> operands,
217176
SmallVector<Value> intrinsicOperands;
218177
intrinsicOperands.append(getTileSizes(loc, getTileType(), rewriter));
219178
intrinsicOperands.push_back(
220-
getBufferPtr(loc, getMemRefType(), adaptor.getBase(),
221-
adaptor.getIndices(), typeConverter, rewriter));
179+
LLVM::getStridedElementPtr(rewriter, loc, typeConverter, getMemRefType(),
180+
adaptor.getBase(), adaptor.getIndices()));
222181
intrinsicOperands.push_back(
223182
getStride(loc, getMemRefType(), adaptor.getBase(), rewriter));
224183
intrinsicOperands.push_back(adaptor.getVal());

0 commit comments

Comments
 (0)