Skip to content

Commit 6eecc9d

Browse files
committed
rename and add test
1 parent f51f7bf commit 6eecc9d

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

examples/dequantize_gemm/example_dequant_groupgemm_bf16_mxfp4_hopper.py renamed to examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -206,7 +206,6 @@ def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k):
206206
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
207207
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype)
208208

209-
bx = T.get_block_binding(0) # noqa: F841
210209
T.copy(B_shared, B_local)
211210
for i, j in T.Parallel(block_N, block_K):
212211
B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16(
@@ -244,7 +243,7 @@ def main(
244243
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
245244
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
246245
topk_weights_shared = T.alloc_shared((block_M), out_dtype)
247-
sorted_token_ids_shared = T.alloc_shared((block_M), "int32") # todo: frag?
246+
sorted_token_ids_shared = T.alloc_shared((block_M), "int32")
248247
expert_id = T.alloc_local((1), "int32") # the expert id for the current block
249248
# To use 1D TMA, the last dim of Scale_shared must have stride=1
250249
# May use much more shared memory than necessary
@@ -462,4 +461,4 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
462461
scale_size = 32
463462
topk = 4
464463
E = 32
465-
main(M, N, K, scale_size, fast_dequant=True, with_bias=True, topk=topk, E=E)
464+
main(M, N, K, scale_size, fast_dequant=True, with_bias=True, topk=topk, E=E)

examples/dequantize_gemm/test_example_dequantize_gemm.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import example_dequant_gemm_fp4_hopper
55
import example_dequant_gemm_bf16_mxfp4_hopper
66
import example_dequant_gemm_bf16_mxfp4_hopper_tma
7+
import example_dequant_groupedgemm_bf16_mxfp4_hopper
78

89

910
@tilelang.testing.requires_cuda
@@ -29,5 +30,11 @@ def test_example_dequant_gemm_bf16_mxfp4_hopper_tma():
2930
example_dequant_gemm_bf16_mxfp4_hopper_tma.main()
3031

3132

33+
@tilelang.testing.requires_cuda
34+
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
35+
def test_example_dequant_groupedgemm_bf16_mxfp4_hopper():
36+
example_dequant_groupedgemm_bf16_mxfp4_hopper.main()
37+
38+
3239
if __name__ == "__main__":
3340
tilelang.testing.main()

0 commit comments

Comments
 (0)