Skip to content

Commit f51f7bf

Browse files
committed
lint
1 parent 93a8b60 commit f51f7bf

File tree

1 file changed

+9
-16
lines changed

1 file changed

+9
-16
lines changed

examples/dequantize_gemm/example_dequant_groupgemm_bf16_mxfp4_hopper.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,6 @@ def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale_shared,
193193
B_dequantize_shared[index // block_K,
194194
index % block_K] = B_dequantize_local_thread[v]
195195

196-
197196
return fast_dequant_bf16_fp4_twiddling
198197

199198
def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"):
@@ -260,8 +259,7 @@ def main(
260259
if threads == 512:
261260
T.disable_warp_group_reg_alloc()
262261

263-
T.copy(sorted_token_ids[by * block_M:(by + 1) * block_M],
264-
sorted_token_ids_shared)
262+
T.copy(sorted_token_ids[by * block_M:(by + 1) * block_M], sorted_token_ids_shared)
265263
expert_id[0] = expert_ids[by]
266264

267265
# Get the topk weights of each token in the current block
@@ -287,7 +285,8 @@ def main(
287285
if sorted_token_ids_shared[i] != -1:
288286
A_shared[i, j] = A[sorted_token_ids_shared[i] // topk, k * block_K + j]
289287
if fast_dequant:
290-
get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, k)
288+
get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared,
289+
k)
291290
else:
292291
get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k)
293292

@@ -300,7 +299,7 @@ def main(
300299
for i, j in T.Parallel(block_M, block_N):
301300
if sorted_token_ids_shared[i] != -1:
302301
C[sorted_token_ids_shared[i] // topk, sorted_token_ids_shared[i] % topk,
303-
bx * block_N + j] = C_shared[i, j]
302+
bx * block_N + j] = C_shared[i, j]
304303

305304
return main
306305

@@ -397,20 +396,13 @@ def get_data(m, n, k, qk, scale_size, topk, E, block_M):
397396
return A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M
398397

399398

400-
def main(m=256,
401-
n=256,
402-
k=256,
403-
scale_size=32,
404-
fast_dequant=True,
405-
with_bias=False,
406-
topk=4,
407-
E=32):
399+
def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, topk=4, E=32):
408400
# Tunable parameters
409401
block_M, block_N, block_K = 128, 128, 256
410402
num_stages = 2
411403
threads = 512
412404
split = 1
413-
405+
414406
total_flops = 2 * m * n * k
415407
num_bits = 4
416408
num_elems_per_byte = 8 // num_bits
@@ -453,7 +445,8 @@ def main(m=256,
453445
A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, block_M=block_M)
454446

455447
print("All checks pass. ✅")
456-
latency = tilelang.profiler.do_bench(lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), warmup=500)
448+
latency = tilelang.profiler.do_bench(
449+
lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), warmup=500)
457450
print("Tile-lang: {:.2f} ms".format(latency))
458451
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
459452

@@ -463,7 +456,7 @@ def main(m=256,
463456
print(f"max abs diff: {max_val} at index: {max_idx}")
464457
assert_similar(output, ref_output, name="output", eps=1e-5)
465458

466-
459+
467460
if __name__ == "__main__":
468461
M, N, K = 1024, 2944, 3072 # From gpt-oss-20b
469462
scale_size = 32

0 commit comments

Comments
 (0)