[BACKEND] Fix getElemsPerThread for mmav3 dot operand #5189

Merged
merged 4 commits into from
Nov 19, 2024
Changes from 3 commits
3 changes: 2 additions & 1 deletion include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
Expand Up @@ -1129,7 +1129,8 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
bool isHopper() const;

   SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> shape,
-                                        int bitwidth, int opIdx) const;
+                                        int bitwidth, int kWidth,
+                                        int opIdx) const;
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;

bool supportReduction() const {
7 changes: 4 additions & 3 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Expand Up @@ -940,7 +940,7 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
} else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
if (mma.isAmpere() || mma.isHopper()) {
auto bitwidth = getPointeeType(eltTy).getIntOrFloatBitWidth();
-      auto rep = mma.getRepForOperand(shape, bitwidth, idx);
+      auto rep = mma.getRepForOperand(shape, bitwidth, kWidth, idx);
auto sizePerThread = getSizePerThread();
auto elemsPerKRep = mma.isHopper() ? (kWidth * 2) : (32 / bitwidth * 2);
if (rank == 3)
Expand Down Expand Up @@ -1974,14 +1974,15 @@ NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {

SmallVector<int64_t>
NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
-                                        int opIdx) const {
+                                        int kWidth, int opIdx) const {
auto rank = shape.size();
auto warpsPerCTA = getWarpsPerCTA();

// {batch, m, n, k}
// Hopper path never uses the n value, since this method is only invoked
// for in-RF (dotOpEnc) operands, but WGMMA only supports A being in RF
-  SmallVector<int> shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth};
+  SmallVector<int> shapePerWarp = {
+      1, 16, 8, isHopper() ? 4 * 2 * kWidth : 4 * 64 / bitwidth};
Comment on lines +1987 to +1988
Contributor

I don't think this is correct for Hopper. The number of reps does not depend on kWidth, as kWidth is part of the size, and thus the tile.

Contributor

> I don't think this is correct for Hopper

Right

Contributor

rep specifies how many mma/wgmma instructions should be used

Collaborator Author

Well, the idea is that for mmav3 the mma operation has to take a dot_operand whose kWidth and bitwidth match, so we want the layout to match in the same way. Like I said, I think kWidth is the wrong parameter here, but that's why it feels like the number of reps should depend on kWidth.
I'm happy to fix it a different way if you have a different suggestion. WDYT?
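
For readers following the thread, here is a small sketch (not code from the PR) comparing the two K-tile formulas that appear in the diff, `4 * 64 / bitwidth` and `4 * 2 * kWidth`. The helper names are hypothetical, and reading "kWidth and bitwidth match" as `kWidth == 32 / bitwidth` is an assumption made for illustration.

```python
def k_tile_from_bitwidth(bitwidth: int) -> int:
    """K extent of one tile using the bitwidth-based formula from the diff."""
    return 4 * 64 // bitwidth


def k_tile_from_kwidth(k_width: int) -> int:
    """K extent of one tile using the kWidth-based formula from the diff."""
    return 4 * 2 * k_width


# Assumption: a "matching" kWidth packs 32 bits per thread, i.e. 32 / bitwidth.
for bitwidth in (8, 16, 32):
    matching_k_width = 32 // bitwidth
    assert k_tile_from_bitwidth(bitwidth) == k_tile_from_kwidth(matching_k_width)

# Mismatched case, as in the lit test of this PR: an fp8 operand with kWidth = 2.
mismatch = (k_tile_from_bitwidth(8), k_tile_from_kwidth(2))
print(mismatch)  # (32, 16)
```

When the two parameters match, both formulas give the same tile. They only diverge for an operand like the fp8 tensor in the new lit test, which carries the kWidth of the fp16 tensor it is converted to.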

Contributor

> I don't think this is correct for Hopper. The number of reps does not depend on kWidth, as kWidth is part of the size, and thus the tile.

Can you explain this a bit more. I thought the number of repeats = total size / tile size for each dimension. Indeed this is what the function looks like it's computing to me.
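
A rough sketch of the "repeats = total size / tile size" reading, for a rank-2 A operand. This is not the Triton implementation: the function name is hypothetical, the M tile of 16 comes from `shapePerWarp` in the diff, and the assumptions that M is split across `warpsPerCTA[0]` while K is not split across warps are mine, modeled on the `numRepBatch` line shown in the diff.

```python
def rep_for_operand_a(shape, warps_per_cta, k_tile):
    """Return (batch, M, K) repeats for a rank-2 A operand (illustrative only)."""
    m, k = shape
    m_tile = 16 * warps_per_cta[0]  # 16 rows per warp, times warps along M
    rep_m = max(1, m // m_tile)
    rep_k = max(1, k // k_tile)     # assumption: K is not split across warps
    return (1, rep_m, rep_k)


shape = (128, 32)   # tensor<128x32xf8E5M2> from the lit test in this PR
warps = (4, 1)      # warpsPerCTA = [4, 1]

old_reps = rep_for_operand_a(shape, warps, k_tile=4 * 64 // 8)  # bitwidth = 8
new_reps = rep_for_operand_a(shape, warps, k_tile=4 * 2 * 2)    # kWidth = 2
print(old_reps, new_reps)  # (1, 2, 1) (1, 2, 2)
```

Under these assumptions the only difference between the two formulas is the K repeat count, which doubles for the fp8 operand with kWidth = 2.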

Contributor

Yes, it is indeed that, I miswrote. For reference, see

for (int b = 0; b < repBatch; ++b)
  for (int k = 0; k < repK; ++k)
    for (int m = 0; m < repM; ++m)
      for (int n = 0; n < repN; ++n)
        callMma(b, 2 * m, n, 2 * k);

where callMma issues one mma instruction.

Then, the rest of the computation of getElemsPerThread for Ampere just computes how many elements there are in the K dimension and multiplies by that count (perhaps broadcasting the layout). I think we should be able to share the Ampere logic for Hopper here, as the layouts are equivalent when it comes to these computations.
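
The loop nest quoted above can be mirrored in a few lines of Python to see how the repeat counts translate into instruction counts. `mma_calls` is a hypothetical stand-in that records the arguments `callMma` would receive; it does not issue anything.

```python
def mma_calls(rep_batch, rep_m, rep_n, rep_k):
    """Record the (b, 2*m, n, 2*k) argument tuples in the quoted loop order."""
    calls = []
    for b in range(rep_batch):
        for k in range(rep_k):
            for m in range(rep_m):
                for n in range(rep_n):
                    calls.append((b, 2 * m, n, 2 * k))
    return calls


calls = mma_calls(rep_batch=1, rep_m=2, rep_n=2, rep_k=2)
print(len(calls))           # 8
print(calls[0], calls[-1])  # (0, 0, 0, 0) (0, 2, 1, 2)
```

One instruction is issued per tuple, so the total is simply the product of the four repeat counts.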

Collaborator Author

> > I don't think this is correct for Hopper. The number of reps does not depend on kWidth, as kWidth is part of the size, and thus the tile.
>
> Can you explain this a bit more. I thought the number of repeats = total size / tile size for each dimension. Indeed this is what the function looks like it's computing to me.

I meant that it does not depend on bitwidth

int numRepBatch =
rank == 3
? std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0]))
13 changes: 8 additions & 5 deletions python/test/regression/test_cast_matmul.py
Expand Up @@ -11,7 +11,7 @@

import triton
import triton.language as tl
-from triton._internal_testing import is_hip_mi300, is_cuda
+from triton._internal_testing import is_hip_mi300, is_cuda, is_hip

input_dtypes = ["float16", "float32", "float64"]
if is_cuda():
Expand Down Expand Up @@ -77,16 +77,19 @@ def matmul_kernel(A, B, C, M, N, K, #
tl.store(C, acc, mask=mask)


-@pytest.mark.parametrize("M, K, N, BLOCK_K, w_dtype, x_dtype, out_dtype",
-                         [(M, K, N, BLOCK_K, w, x, o)  #
+@pytest.mark.parametrize("M, K, N, BLOCK_K, BLOCK_M, w_dtype, x_dtype, out_dtype",
+                         [(M, K, N, BLOCK_K, BLOCK_M, w, x, o)  #
                           for BLOCK_K in [16, 32]  #
+                          for BLOCK_M in [16, 64]  #
                           for (M, K, N) in [(128, 128, 128), (768, 768, 1024)]  #
                           for w in input_dtypes
                           for x in input_dtypes  #
                           for o in out_dtypes])
-def test_cast_matmul(M, K, N, BLOCK_K, w_dtype, x_dtype, out_dtype):
+def test_cast_matmul(M, K, N, BLOCK_K, BLOCK_M, w_dtype, x_dtype, out_dtype):
     if x_dtype == w_dtype:
         pytest.skip("skip the same input dtype")
+    if is_hip() and BLOCK_M == 64 and w_dtype in ["float8_e5m2", "float8_e4m3fnuz"]:
+        pytest.skip("skip due to bug on HIP path")
Comment on lines +91 to +92
Collaborator Author

Contributor

@ThomasRaoux I'll try to repro tomorrow but do you remember what the bug/error was here?

Collaborator Author

@ThomasRaoux, Nov 20, 2024

Thanks! It was a mismatch (miscompile). If you comment out this skip, you should see the failures.

device = torch.cuda.current_device()
x_dtype: torch.dtype = getattr(torch, x_dtype)
w_dtype: torch.dtype = getattr(torch, w_dtype)
Expand All @@ -109,7 +112,7 @@ def init_tensor(dtype, shape):
out_triton = torch.empty((M, N), device=device, dtype=torch_dtype)

# launch kernel
-    block_m, block_n, block_k = 16, 16, BLOCK_K
+    block_m, block_n, block_k = BLOCK_M, 16, BLOCK_K
grid = ((triton.cdiv(M, block_m) * triton.cdiv(N, block_n)), 1)

matmul_kernel[grid](
12 changes: 12 additions & 0 deletions test/Conversion/tritongpu_to_llvm_hopper.mlir
Expand Up @@ -293,3 +293,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
tt.return
}
}

// -----

#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>
module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: test_fp8_to_fp16_dot_operand
// CHECK-COUNT-16: cvt.rn.f16x2.e5m2x2
tt.func @test_fp8_to_fp16_dot_operand(%arg: tensor<128x32xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>) {
%r = tt.fp_to_fp %arg : tensor<128x32xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
tt.return
}
}
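
A back-of-the-envelope check of the `CHECK-COUNT-16` line above, written as a sketch rather than anything taken from the PR. It assumes the 128x32 operand is spread evenly and without duplication over the 4 warps named in `warpsPerCTA`, with 32 threads per warp, and that each `cvt.rn.f16x2.e5m2x2` converts two packed values.

```python
elements = 128 * 32                 # tensor<128x32xf8E5M2>
threads = 4 * 1 * 32                # warpsPerCTA = [4, 1], 32 threads per warp
elems_per_thread = elements // threads
cvt_count = elems_per_thread // 2   # two values per packed conversion

print(elems_per_thread, cvt_count)  # 32 16
```

Under the same assumptions, a K repeat count of 1 instead of 2 would halve the per-thread element count and yield only 8 conversions, which is the kind of undercount this test guards against.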
Expand Up @@ -667,8 +667,8 @@ Value loadArg(ConversionPatternRewriter &rewriter, Location loc,
int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / mmaBitwidth;

int kWidth = encoding.getKWidth();
-  auto numRep =
-      mmaLayout.getRepForOperand(shapePerCTA, mmaBitwidth, encoding.getOpIdx());
+  auto numRep = mmaLayout.getRepForOperand(shapePerCTA, mmaBitwidth, kWidth,
+                                           encoding.getOpIdx());

auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
auto warpOrder = mmaLayout.getWarpOrder();
Expand Up @@ -480,11 +480,14 @@ LogicalResult convertDot(const LLVMTypeConverter *typeConverter,

int bitwidth = aTensorTy.getElementType().getIntOrFloatBitWidth();
auto dotOpA = cast<DotOperandEncodingAttr>(aTensorTy.getEncoding());
-  auto repA = cast<NvidiaMmaEncodingAttr>(dotOpA.getParent())
-                  .getRepForOperand(aShapePerCTA, bitwidth, dotOpA.getOpIdx());
+  int kWidth = dotOpA.getKWidth();
+  auto repA =
+      cast<NvidiaMmaEncodingAttr>(dotOpA.getParent())
+          .getRepForOperand(aShapePerCTA, bitwidth, kWidth, dotOpA.getOpIdx());
   auto dotOpB = cast<DotOperandEncodingAttr>(bTensorTy.getEncoding());
-  auto repB = cast<NvidiaMmaEncodingAttr>(dotOpB.getParent())
-                  .getRepForOperand(bShapePerCTA, bitwidth, dotOpB.getOpIdx());
+  auto repB =
+      cast<NvidiaMmaEncodingAttr>(dotOpB.getParent())
+          .getRepForOperand(bShapePerCTA, bitwidth, kWidth, dotOpB.getOpIdx());

assert(repA[2] == repB[1]);
assert(repA[0] == repB[0]);