@@ -206,7 +206,6 @@ def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale_shared, k):
     B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
     B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype)

-    bx = T.get_block_binding(0)  # noqa: F841
     T.copy(B_shared, B_local)
     for i, j in T.Parallel(block_N, block_K):
         B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16(
@@ -244,7 +243,7 @@ def main(
     C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
     C_shared = T.alloc_shared((block_M, block_N), out_dtype)
     topk_weights_shared = T.alloc_shared((block_M), out_dtype)
-    sorted_token_ids_shared = T.alloc_shared((block_M), "int32")  # todo: frag?
+    sorted_token_ids_shared = T.alloc_shared((block_M), "int32")
     expert_id = T.alloc_local((1), "int32")  # the expert id for the current block
     # To use 1D TMA, the last dim of Scale_shared must have stride=1
     # May use much more shared memory than necessary
@@ -462,4 +461,4 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
     scale_size = 32
     topk = 4
     E = 32
-    main(M, N, K, scale_size, fast_dequant=True, with_bias=True, topk=topk, E=E)
+    main(M, N, K, scale_size, fast_dequant=True, with_bias=True, topk=topk, E=E)
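
For context on the dequantization path touched above: a minimal reference sketch, in plain Python, of the per-element conversion that `_tir_u8_to_f4_to_bf16` performs, assuming FP4 is E2M1 (1 sign bit, 2 exponent bits with bias 1, 1 mantissa bit) packed two values per uint8. The nibble order and helper names are illustrative assumptions, not taken from this commit, and the per-group Scale multiply done in the kernel is omitted here.

def fp4_e2m1_to_float(nibble: int) -> float:
    # Assumed E2M1 layout: bit 3 = sign, bits 2-1 = exponent (bias 1), bit 0 = mantissa.
    sign = -1.0 if (nibble >> 3) & 0x1 else 1.0
    exp = (nibble >> 1) & 0x3
    man = nibble & 0x1
    if exp == 0:
        return sign * man * 0.5  # subnormal: 0.0 or +/-0.5
    return sign * (1.0 + man * 0.5) * 2.0 ** (exp - 1)

def unpack_fp4_pair(byte: int) -> tuple:
    # Assumed packing: low nibble holds element 2*i, high nibble holds element 2*i+1.
    return fp4_e2m1_to_float(byte & 0xF), fp4_e2m1_to_float(byte >> 4)

# The representable E2M1 magnitudes are exactly {0, 0.5, 1, 1.5, 2, 3, 4, 6};
# every magnitude fits losslessly in bf16, so the widening conversion is exact.
assert sorted({abs(fp4_e2m1_to_float(v)) for v in range(16)}) == [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]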