Skip to content

Commit

Permalink
Simplify Ampere/Hopper paths introduced in #5189
Browse files Browse the repository at this point in the history
The tile shapes in Ampere and Hopper are the same, and as such, we can
reuse the logic.

More generally, we should try to minimise the calls to `isAmpere` and
`isHopper` throughout the codebase. I'll do a pass fixing many of these
once we land LLs for `ldmatrix` and Hopper.
  • Loading branch information
lezcano committed Nov 20, 2024
1 parent aaf64d6 commit 6bd85c0
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 18 deletions.
3 changes: 1 addition & 2 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Original file line number Diff line number Diff line change
Expand Up @@ -1129,8 +1129,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
bool isHopper() const;

SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> shape,
int bitwidth, int kWidth,
int opIdx) const;
int bitwidth, int opIdx) const;
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;

bool supportReduction() const {
Expand Down
12 changes: 4 additions & 8 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -940,9 +940,9 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
} else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
if (mma.isAmpere() || mma.isHopper()) {
auto bitwidth = getPointeeType(eltTy).getIntOrFloatBitWidth();
auto rep = mma.getRepForOperand(shape, bitwidth, kWidth, idx);
auto rep = mma.getRepForOperand(shape, bitwidth, idx);
auto sizePerThread = getSizePerThread();
auto elemsPerKRep = mma.isHopper() ? (kWidth * 2) : (32 / bitwidth * 2);
auto elemsPerKRep = 32 / bitwidth * 2;
if (rank == 3)
elemsPerThread[0] = rep[0];
elemsPerThread[rank - 2] =
Expand Down Expand Up @@ -1974,18 +1974,14 @@ NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {

SmallVector<int64_t>
NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
int kWidth, int opIdx) const {
int opIdx) const {
auto rank = shape.size();
auto warpsPerCTA = getWarpsPerCTA();

// {batch, m, n, k}
// Hopper path never uses the n value, since this method is only invoked
// for in-RF (dotOpEnc) operands, but WGMMA only supports operand A in RF
// TODO: rep per operand is not accurate for Hopper. It is currently done that
// way to allow us to get the correct total number of elements. this will be
// fixed when moving to linear layout.
SmallVector<int> shapePerWarp = {
1, 16, 8, isHopper() ? 4 * 2 * kWidth : 4 * 64 / bitwidth};
SmallVector<int> shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth};
int numRepBatch =
rank == 3
? std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0]))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -667,8 +667,8 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc,
int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / mmaBitwidth;

int kWidth = encoding.getKWidth();
auto numRep = mmaLayout.getRepForOperand(shapePerCTA, mmaBitwidth, kWidth,
encoding.getOpIdx());
auto numRep =
mmaLayout.getRepForOperand(shapePerCTA, mmaBitwidth, encoding.getOpIdx());

auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
auto warpOrder = mmaLayout.getWarpOrder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -481,13 +481,11 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter,
int bitwidth = aTensorTy.getElementType().getIntOrFloatBitWidth();
auto dotOpA = cast<DotOperandEncodingAttr>(aTensorTy.getEncoding());
int kWidth = dotOpA.getKWidth();
auto repA =
cast<NvidiaMmaEncodingAttr>(dotOpA.getParent())
.getRepForOperand(aShapePerCTA, bitwidth, kWidth, dotOpA.getOpIdx());
auto repA = cast<NvidiaMmaEncodingAttr>(dotOpA.getParent())
.getRepForOperand(aShapePerCTA, bitwidth, dotOpA.getOpIdx());
auto dotOpB = cast<DotOperandEncodingAttr>(bTensorTy.getEncoding());
auto repB =
cast<NvidiaMmaEncodingAttr>(dotOpB.getParent())
.getRepForOperand(bShapePerCTA, bitwidth, kWidth, dotOpB.getOpIdx());
auto repB = cast<NvidiaMmaEncodingAttr>(dotOpB.getParent())
.getRepForOperand(bShapePerCTA, bitwidth, dotOpB.getOpIdx());

assert(repA[2] == repB[1]);
assert(repA[0] == repB[0]);
Expand Down

0 comments on commit 6bd85c0

Please sign in to comment.