[BACKEND] Fix getElemsPerThread for mmav3 dot operand (#5189)
In the MMAv3 case the number of elements per thread should be independent of the element type; only kWidth should be considered.
TODO: the same should hold for MMAv2, but the logic there is a bit more complicated.

Also enable a larger block_m in the mixed-mode tests to exercise the MMAv3 case.
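A minimal sketch of what changes (assuming the per-warp k-extent expressions shown in the Dialect.cpp hunk below; the helper name is illustrative):

```python
# Per-warp k extent of the MMA tile, mirroring the ternary added to getRepForOperand.
def k_per_warp(bitwidth: int, k_width: int, is_hopper: bool) -> int:
    # Hopper (MMAv3): depends only on kWidth, so fp8 and fp16 operands with the
    # same kWidth yield the same element count per thread.
    # Ampere (and the previous Hopper behavior): depends on the element bitwidth.
    return 4 * 2 * k_width if is_hopper else 4 * 64 // bitwidth

assert k_per_warp(8, 2, is_hopper=True) == k_per_warp(16, 2, is_hopper=True) == 16
assert k_per_warp(8, 2, is_hopper=False) != k_per_warp(16, 2, is_hopper=False)
```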
ThomasRaoux authored Nov 19, 2024
1 parent a6bb57d commit aaf64d6
Showing 6 changed files with 38 additions and 15 deletions.
include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td (3 changes: 2 additions & 1 deletion)
@@ -1129,7 +1129,8 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
     bool isHopper() const;

     SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> shape,
-                                          int bitwidth, int opIdx) const;
+                                          int bitwidth, int kWidth,
+                                          int opIdx) const;
     SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;

     bool supportReduction() const {
lib/Dialect/TritonGPU/IR/Dialect.cpp (10 changes: 7 additions & 3 deletions)
@@ -940,7 +940,7 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
   } else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
     if (mma.isAmpere() || mma.isHopper()) {
       auto bitwidth = getPointeeType(eltTy).getIntOrFloatBitWidth();
-      auto rep = mma.getRepForOperand(shape, bitwidth, idx);
+      auto rep = mma.getRepForOperand(shape, bitwidth, kWidth, idx);
       auto sizePerThread = getSizePerThread();
       auto elemsPerKRep = mma.isHopper() ? (kWidth * 2) : (32 / bitwidth * 2);
       if (rank == 3)
@@ -1974,14 +1974,18 @@ NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {

 SmallVector<int64_t>
 NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
-                                        int opIdx) const {
+                                        int kWidth, int opIdx) const {
   auto rank = shape.size();
   auto warpsPerCTA = getWarpsPerCTA();

   // {batch, m, n, k}
   // Hopper path never uses the n value, since this method is only invoked
   // for in-RF (dotOpEnc) operands, but WGMMA only supports in A to be in RF
-  SmallVector<int> shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth};
+  // TODO: rep per operand is not accurate for Hopper. It is currently done that
+  // way to allow us to get the correct total number of elements. this will be
+  // fixed when moving to linear layout.
+  SmallVector<int> shapePerWarp = {
+      1, 16, 8, isHopper() ? 4 * 2 * kWidth : 4 * 64 / bitwidth};
   int numRepBatch =
       rank == 3
           ? std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0]))
python/test/regression/test_cast_matmul.py (13 changes: 8 additions & 5 deletions)
@@ -11,7 +11,7 @@

 import triton
 import triton.language as tl
-from triton._internal_testing import is_hip_mi300, is_cuda
+from triton._internal_testing import is_hip_mi300, is_cuda, is_hip

 input_dtypes = ["float16", "float32", "float64"]
 if is_cuda():
@@ -77,16 +77,19 @@ def matmul_kernel(A, B, C, M, N, K, #
     tl.store(C, acc, mask=mask)


-@pytest.mark.parametrize("M, K, N, BLOCK_K, w_dtype, x_dtype, out_dtype",
-                         [(M, K, N, BLOCK_K, w, x, o)  #
+@pytest.mark.parametrize("M, K, N, BLOCK_K, BLOCK_M, w_dtype, x_dtype, out_dtype",
+                         [(M, K, N, BLOCK_K, BLOCK_M, w, x, o)  #
                           for BLOCK_K in [16, 32]  #
+                          for BLOCK_M in [16, 64]  #
                           for (M, K, N) in [(128, 128, 128), (768, 768, 1024)]  #
                           for w in input_dtypes
                           for x in input_dtypes  #
                           for o in out_dtypes])
-def test_cast_matmul(M, K, N, BLOCK_K, w_dtype, x_dtype, out_dtype):
+def test_cast_matmul(M, K, N, BLOCK_K, BLOCK_M, w_dtype, x_dtype, out_dtype):
     if x_dtype == w_dtype:
         pytest.skip("skip the same input dtype")
+    if is_hip() and BLOCK_M == 64 and w_dtype in ["float8_e5m2", "float8_e4m3fnuz"]:
+        pytest.skip("skip due to bug on HIP path")
     device = torch.cuda.current_device()
     x_dtype: torch.dtype = getattr(torch, x_dtype)
     w_dtype: torch.dtype = getattr(torch, w_dtype)
@@ -109,7 +112,7 @@ def init_tensor(dtype, shape):
     out_triton = torch.empty((M, N), device=device, dtype=torch_dtype)

     # launch kernel
-    block_m, block_n, block_k = 16, 16, BLOCK_K
+    block_m, block_n, block_k = BLOCK_M, 16, BLOCK_K
     grid = ((triton.cdiv(M, block_m) * triton.cdiv(N, block_n)), 1)

     matmul_kernel[grid](
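Why BLOCK_M = 64 is what exercises MMAv3: WGMMA executes per warp group (4 warps) and covers 64 rows of the A tile, so a 16-row block cannot take that path. The sketch below only illustrates that gating under the stated assumption; it is not the exact compiler check.

```python
# Illustrative gating (assumption: the MMAv3/WGMMA path needs an M block of at
# least 64 rows, i.e. one warp group of 4 warps x 16 rows each).
def could_use_mmav3(block_m: int) -> bool:
    return block_m >= 64

assert not could_use_mmav3(16) and could_use_mmav3(64)
```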
test/Conversion/tritongpu_to_llvm_hopper.mlir (12 changes: 12 additions & 0 deletions)
@@ -293,3 +293,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
     tt.return
   }
 }
+
+// -----
+
+#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>
+module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
+  // CHECK-LABEL: test_fp8_to_fp16_dot_operand
+  // CHECK-COUNT-16: cvt.rn.f16x2.e5m2x2
+  tt.func @test_fp8_to_fp16_dot_operand(%arg: tensor<128x32xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>) {
+    %r = tt.fp_to_fp %arg : tensor<128x32xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
+    tt.return
+  }
+}
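A back-of-the-envelope check of the CHECK-COUNT above (an assumption-based estimate, not the exact layout math): with warpsPerCTA = [4, 1], the 128x32 fp8 operand is spread over 4 warps of 32 threads with no replication, and each cvt.rn.f16x2.e5m2x2 converts two values.

```python
# Sanity check for CHECK-COUNT-16 (assumes an even, non-replicated distribution
# of the 128x32 operand over 4 warps x 32 threads).
elems_per_thread = (128 * 32) // (4 * 32)  # 32 fp8 values held by each thread
cvt_count = elems_per_thread // 2          # one cvt.rn.f16x2.e5m2x2 handles 2 values
assert cvt_count == 16
```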
@@ -667,8 +667,8 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc,
   int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / mmaBitwidth;

   int kWidth = encoding.getKWidth();
-  auto numRep =
-      mmaLayout.getRepForOperand(shapePerCTA, mmaBitwidth, encoding.getOpIdx());
+  auto numRep = mmaLayout.getRepForOperand(shapePerCTA, mmaBitwidth, kWidth,
+                                           encoding.getOpIdx());

   auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
   auto warpOrder = mmaLayout.getWarpOrder();
@@ -480,11 +480,14 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter,

   int bitwidth = aTensorTy.getElementType().getIntOrFloatBitWidth();
   auto dotOpA = cast<DotOperandEncodingAttr>(aTensorTy.getEncoding());
-  auto repA = cast<NvidiaMmaEncodingAttr>(dotOpA.getParent())
-                  .getRepForOperand(aShapePerCTA, bitwidth, dotOpA.getOpIdx());
+  int kWidth = dotOpA.getKWidth();
+  auto repA =
+      cast<NvidiaMmaEncodingAttr>(dotOpA.getParent())
+          .getRepForOperand(aShapePerCTA, bitwidth, kWidth, dotOpA.getOpIdx());
   auto dotOpB = cast<DotOperandEncodingAttr>(bTensorTy.getEncoding());
-  auto repB = cast<NvidiaMmaEncodingAttr>(dotOpB.getParent())
-                  .getRepForOperand(bShapePerCTA, bitwidth, dotOpB.getOpIdx());
+  auto repB =
+      cast<NvidiaMmaEncodingAttr>(dotOpB.getParent())
+          .getRepForOperand(bShapePerCTA, bitwidth, kWidth, dotOpB.getOpIdx());

   assert(repA[2] == repB[1]);
   assert(repA[0] == repB[0]);
