[BACKEND] Fix getElemsPerThread for mmav3 dot operand #5189
Changes from 3 commits
@@ -11,7 +11,7 @@

 import triton
 import triton.language as tl
-from triton._internal_testing import is_hip_mi300, is_cuda
+from triton._internal_testing import is_hip_mi300, is_cuda, is_hip

 input_dtypes = ["float16", "float32", "float64"]
 if is_cuda():
@@ -77,16 +77,19 @@ def matmul_kernel(A, B, C, M, N, K,  #
     tl.store(C, acc, mask=mask)


-@pytest.mark.parametrize("M, K, N, BLOCK_K, w_dtype, x_dtype, out_dtype",
-                         [(M, K, N, BLOCK_K, w, x, o)  #
+@pytest.mark.parametrize("M, K, N, BLOCK_K, BLOCK_M, w_dtype, x_dtype, out_dtype",
+                         [(M, K, N, BLOCK_K, BLOCK_M, w, x, o)  #
                           for BLOCK_K in [16, 32]  #
+                          for BLOCK_M in [16, 64]  #
                           for (M, K, N) in [(128, 128, 128), (768, 768, 1024)]  #
                           for w in input_dtypes
                           for x in input_dtypes  #
                           for o in out_dtypes])
-def test_cast_matmul(M, K, N, BLOCK_K, w_dtype, x_dtype, out_dtype):
+def test_cast_matmul(M, K, N, BLOCK_K, BLOCK_M, w_dtype, x_dtype, out_dtype):
     if x_dtype == w_dtype:
         pytest.skip("skip the same input dtype")
+    if is_hip() and BLOCK_M == 64 and w_dtype in ["float8_e5m2", "float8_e4m3fnuz"]:
+        pytest.skip("skip due to bug on HIP path")

Comment on lines +91 to +92

FYI: @antiagainst

@ThomasRaoux I'll try to repro tomorrow, but do you remember what the bug/error was here?

Thanks! It was a mismatch (miscompile). If you comment this out you should see the failures.

     device = torch.cuda.current_device()
     x_dtype: torch.dtype = getattr(torch, x_dtype)
     w_dtype: torch.dtype = getattr(torch, w_dtype)
@@ -109,7 +112,7 @@ def init_tensor(dtype, shape):
     out_triton = torch.empty((M, N), device=device, dtype=torch_dtype)

     # launch kernel
-    block_m, block_n, block_k = 16, 16, BLOCK_K
+    block_m, block_n, block_k = BLOCK_M, 16, BLOCK_K
     grid = ((triton.cdiv(M, block_m) * triton.cdiv(N, block_n)), 1)

     matmul_kernel[grid](
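
Following up on the inline comment above about the HIP miscompile, here is a hypothetical way to rerun only the configurations the new skip guards, after temporarily commenting the skip out. The test file path and the -k selection pattern are assumptions, not taken from this PR.

```python
# Hypothetical reproduction sketch -- the path and -k pattern are assumptions.
# With the is_hip() / BLOCK_M == 64 skip commented out, this selects only the
# float8 weight cases with BLOCK_M == 64 so the mismatch can be observed.
import pytest

if __name__ == "__main__":
    pytest.main([
        "python/test/regression/test_cast_matmul.py",  # assumed test location
        "-k", "float8 and 64",                         # assumed parametrize id filter
        "-v",
    ])
```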

I don't think this is correct for Hopper. The number of reps does not depend on kWidth, as kWidth is part of the size, and thus the tile.

Right

rep specifies how many mma/wgmma instructions should be used.

Well, the idea is that for mmav3 the mma operation will have to take a dot_operand where kWidth and bitwidth match, so we want the layout to match this way. Like I said, I think kWidth is the wrong parameter here, but that's why it feels like the number of reps should depend on kWidth.
I'm happy to fix it a different way if you have a different suggestion. WDYT?

Can you explain this a bit more? I thought the number of repeats = total size / tile size for each dimension. Indeed, this is what the function looks like it's computing to me.

Yes, it is indeed that, I miswrote. For reference, see triton/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp, lines 552 to 556 in 5acbbb7, where callMma issues one mma instruction. Then, the rest of the computation of getElemsPerThread for Ampere just computes how many elements there are in the K dimension and multiplies by it (perhaps broadcasting the layout). I think we should just be able to share the logic from Ampere for Hopper here, as the layouts are equivalent when it comes to these computations.

I meant it does not depend on bitwidth.
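
A minimal sketch of the repeat arithmetic described in this thread, for illustration only: the function names, tile shapes, and per-thread element counts below are assumptions, not Triton's actual getElemsPerThread implementation.

```python
# Illustrative sketch only -- not Triton's real getElemsPerThread.
# Assumption: each rep issues one mma/wgmma instruction over a fixed
# instruction tile, and kWidth is already folded into that tile, so the
# rep count gets no extra kWidth (or bitwidth) factor.


def num_reps(shape, instr_shape, warps_per_cta):
    # reps[d] = total size / tile size, per dimension (at least 1).
    return [max(1, shape[d] // (instr_shape[d] * warps_per_cta[d]))
            for d in range(len(shape))]


def elems_per_thread(shape, instr_shape, warps_per_cta, elems_per_instr_per_thread):
    # Every thread holds a fixed number of operand elements per instruction,
    # so the total is the product of the reps times that per-instruction count.
    total = elems_per_instr_per_thread
    for r in num_reps(shape, instr_shape, warps_per_cta):
        total *= r
    return total


# Example: a 64x64 A operand, a 16x16 (M x K) instruction tile, one warp per
# dimension, and 8 A elements per thread per mma -> 4 * 4 * 8 = 128 elements.
print(elems_per_thread([64, 64], [16, 16], [1, 1], 8))  # 128
```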