 from vllm.utils import (
     has_triton_kernels,
     is_torch_equal_or_newer,
-    next_power_of_2,
     round_up,
 )
 from vllm.utils.flashinfer import has_flashinfer
@@ -97,12 +96,6 @@ def get_mxfp4_backend():
             and has_flashinfer()
             and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
         ):
-            logger.info_once(
-                "Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100, "
-                "for high concurrency throughput workloads consider setting "
-                "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=1 for better "
-                "performance"
-            )
             return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
         elif current_platform.is_device_capability(100) and has_flashinfer():
             logger.info_once(
@@ -357,7 +350,7 @@ def process_weights_after_loading(self, layer):
             or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
         ):
             from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
-            from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices
+            from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
 
             layer.gemm1_alpha = Parameter(
                 torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
@@ -449,7 +442,7 @@ def swap_every_two_rows(x, axis=-1):
             epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
             for i in range(self.num_experts):
                 # w13 weight shuffling
-                permute_indices = _maybe_get_cached_w2_permute_indices(
+                permute_indices = get_w2_permute_indices_with_cache(
                     self._cache_permute_indices,
                     w13_weight[i].view(torch.uint8),
                     epilogue_tile_m,
@@ -460,7 +453,7 @@ def swap_every_two_rows(x, axis=-1):
                     .contiguous()
                 )
                 # w13 scale shuffling
-                permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                permute_sf_indices = get_w2_permute_indices_with_cache(
                     self._cache_permute_indices,
                     w13_weight_scale[i].view(torch.uint8),
                     epilogue_tile_m,
@@ -476,7 +469,7 @@ def swap_every_two_rows(x, axis=-1):
                     )
                 )
                 # w13 bias shuffling
-                permute_bias_indices = _maybe_get_cached_w2_permute_indices(
+                permute_bias_indices = get_w2_permute_indices_with_cache(
                     self._cache_permute_indices,
                     w13_bias[i].clone().reshape(-1, 1),
                     epilogue_tile_m,
@@ -488,7 +481,7 @@ def swap_every_two_rows(x, axis=-1):
                     .contiguous()
                 )
                 # w2 weight shuffling
-                permute_indices = _maybe_get_cached_w2_permute_indices(
+                permute_indices = get_w2_permute_indices_with_cache(
                     self._cache_permute_indices,
                     w2_weight[i].view(torch.uint8),
                     epilogue_tile_m,
@@ -499,7 +492,7 @@ def swap_every_two_rows(x, axis=-1):
                     .contiguous()
                 )
                 # w2 scale shuffling
-                permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                permute_sf_indices = get_w2_permute_indices_with_cache(
                     self._cache_permute_indices,
                     w2_weight_scale[i].view(torch.uint8),
                     epilogue_tile_m,
@@ -515,7 +508,7 @@ def swap_every_two_rows(x, axis=-1):
                     )
                 )
                 # w2 bias shuffling
-                permute_indices = _maybe_get_cached_w2_permute_indices(
+                permute_indices = get_w2_permute_indices_with_cache(
                     self._cache_permute_indices,
                     w2_bias[i].clone().reshape(-1, 1),
                     epilogue_tile_m,
@@ -735,30 +728,6 @@ def _interleave_mxfp4_cutlass_sm90(w):
         else:
             raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
 
-    def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
-        # Number of tokens in the input tensor.
-        num_tokens = x.shape[0]
-        # Factor to account for the imbalance of the experts.
-        # factor equals to the
-        # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
-        # - 1.0 means perfect expert distribution.
-        # - > 1.0 means some experts have more
-        #   tokens than the perfect distribution.
-        # - < 1.0 does not make sense.
-        imbalance_factor = 1.3
-        # Calculate the number of tokens per expert
-        # assuming perfect distribution.
-        num_tokens_per_expert = (num_tokens * top_k) // self.num_experts
-        # Apply the imbalance factor.
-        num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
-        # And pad the number to the next power of 2.
-        tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
-        # Cap to 8-64 tokens per CTA tile
-        # as it's the range supported by the kernel.
-        tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-
-        return tile_tokens_dim
-
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
@@ -1037,7 +1006,7 @@ def apply(
                 layer.ep_rank * layer.local_num_experts,  # local_expert_offset
                 self.num_experts,  # local num experts
                 None,
-                self._get_tile_tokens_dim(x, top_k),
+                None,
                 1 if renormalize else 0,  # routing_method_type, renormalize
                 True,  # do finalize
                 tune_max_num_tokens=self.max_capture_size,