Commit 41d3071

jiahanc and mgoin authored
[NVIDIA] [Perf] Update to leverage flashinfer trtllm FP4 MOE throughput kernel (#26714)
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
1 parent fb5e10d commit 41d3071

File tree: 7 files changed (+25, -96 lines)


docker/Dockerfile

Lines changed: 2 additions & 2 deletions
@@ -359,8 +359,8 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
 # Install FlashInfer pre-compiled kernel cache and binaries
 # https://docs.flashinfer.ai/installation.html
 RUN --mount=type=cache,target=/root/.cache/uv \
-    uv pip install --system flashinfer-cubin==0.4.0 \
-    && uv pip install --system flashinfer-jit-cache==0.4.0 \
+    uv pip install --system flashinfer-cubin==0.4.1 \
+    && uv pip install --system flashinfer-jit-cache==0.4.1 \
     --extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
     && flashinfer show-config

docker/Dockerfile.nightly_torch

Lines changed: 2 additions & 2 deletions
@@ -246,15 +246,15 @@ RUN pip install setuptools==75.6.0 packaging==23.2 ninja==1.11.1.3 build==1.2.2.
 
 
 # build flashinfer for torch nightly from source around 10 mins
-# release version: v0.4.0
+# release version: v0.4.1
 # todo(elainewy): cache flashinfer build result for faster build
 ENV CCACHE_DIR=/root/.cache/ccache
 RUN --mount=type=cache,target=/root/.cache/ccache \
     --mount=type=cache,target=/root/.cache/uv \
     echo "git clone flashinfer..." \
     && git clone --recursive https://github.com/flashinfer-ai/flashinfer.git \
     && cd flashinfer \
-    && git checkout v0.4.0 \
+    && git checkout v0.4.1 \
     && git submodule update --init --recursive \
     && echo "finish git clone flashinfer..." \
     && rm -rf build \

requirements/cuda.txt

Lines changed: 1 addition & 1 deletion
@@ -12,4 +12,4 @@ torchvision==0.23.0 # Required for phi3v processor. See https://github.com/pytor
 # https://github.com/facebookresearch/xformers/releases/tag/v0.0.32.post1
 xformers==0.0.32.post1; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.8
 # FlashInfer should be updated together with the Dockerfile
-flashinfer-python==0.4.0
+flashinfer-python==0.4.1
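
The requirements pin and the two Dockerfiles must move together, as the comment above notes. A quick standard-library check (a sketch, not part of this commit) can confirm an environment actually picked up the 0.4.1 pins; the cubin and jit-cache distributions are only installed in the Docker images, hence the not-installed fallback:

from importlib.metadata import PackageNotFoundError, version

# Distribution names taken from requirements/cuda.txt and docker/Dockerfile above.
for dist in ("flashinfer-python", "flashinfer-cubin", "flashinfer-jit-cache"):
    try:
        print(dist, version(dist))
    except PackageNotFoundError:
        print(dist, "not installed")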

tests/kernels/moe/test_ocp_mx_moe.py

Lines changed: 7 additions & 7 deletions
@@ -37,7 +37,7 @@
     trtllm_fp4_block_scale_moe,
 )
 from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
-from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices
+from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
 
 
 @dataclass
@@ -319,7 +319,7 @@ def tg_mxfp4_moe(
     if transpose_optimized:
         for i in range(num_experts):
             # w13 weight shuffling
-            permute_indices = _maybe_get_cached_w2_permute_indices(
+            permute_indices = get_w2_permute_indices_with_cache(
                 _cache_permute_indices,
                 w13_weight[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -330,7 +330,7 @@ def tg_mxfp4_moe(
                 .contiguous()
             )
             # w13 scale shuffling
-            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+            permute_sf_indices = get_w2_permute_indices_with_cache(
                 _cache_permute_indices,
                 w13_weight_scale[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -344,7 +344,7 @@ def tg_mxfp4_moe(
                 )
             )
             # w13 bias shuffling
-            permute_bias_indices = _maybe_get_cached_w2_permute_indices(
+            permute_bias_indices = get_w2_permute_indices_with_cache(
                 _cache_permute_indices,
                 w13_bias[i].clone().reshape(-1, 1),
                 epilogue_tile_m,
@@ -356,7 +356,7 @@ def tg_mxfp4_moe(
                 .contiguous()
             )
             # w2 weight shuffling
-            permute_indices = _maybe_get_cached_w2_permute_indices(
+            permute_indices = get_w2_permute_indices_with_cache(
                 _cache_permute_indices,
                 w2_weight[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -367,7 +367,7 @@ def tg_mxfp4_moe(
                 .contiguous()
             )
             # w2 scale shuffling
-            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+            permute_sf_indices = get_w2_permute_indices_with_cache(
                 _cache_permute_indices,
                 w2_weight_scale[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -381,7 +381,7 @@ def tg_mxfp4_moe(
                 )
             )
             # w2 bias shuffling
-            permute_indices = _maybe_get_cached_w2_permute_indices(
+            permute_indices = get_w2_permute_indices_with_cache(
                 _cache_permute_indices,
                 w2_bias[i].clone().reshape(-1, 1),
                 epilogue_tile_m,
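
The only functional change in this test is the rename of the permute-index helper in flashinfer.fused_moe.core: the underscore-prefixed _maybe_get_cached_w2_permute_indices becomes get_w2_permute_indices_with_cache. For code that still has to import against FlashInfer 0.4.0, a minimal import shim along these lines would bridge the rename (a sketch, assuming the two names are otherwise drop-in equivalents, which is what this diff implies):

# Hypothetical compatibility shim, not part of this commit.
try:
    # FlashInfer >= 0.4.1
    from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
except ImportError:
    # FlashInfer 0.4.0 exposed the same logic under an underscore-prefixed name.
    from flashinfer.fused_moe.core import (
        _maybe_get_cached_w2_permute_indices as get_w2_permute_indices_with_cache,
    )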

vllm/model_executor/layers/fused_moe/trtllm_moe.py

Lines changed: 1 addition & 28 deletions
@@ -11,7 +11,6 @@
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceNoOP,
 )
-from vllm.utils import next_power_of_2
 
 
 class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
@@ -65,30 +64,6 @@ def workspace_shapes(
         output = (M, K)
         return (workspace1, workspace2, output)
 
-    def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int, local_num_experts: int):
-        # Number of tokens in the input tensor.
-        num_tokens = x.shape[0]
-        # Factor to account for the imbalance of the experts.
-        # factor equals to the
-        # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
-        # 1.0 means perfect expert distribution.
-        # > 1.0 means some experts have more tokens than the perfect
-        # distribution.
-        # < 1.0 does not make sense.
-        imbalance_factor = 1.3
-        # Calculate the number of tokens per expert assuming perfect
-        # distribution.
-        num_tokens_per_expert = (num_tokens * top_k) // local_num_experts
-        # Apply the imbalance factor.
-        num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
-        # And pad the number to the next power of 2.
-        tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
-        # Cap to 8-64 tokens per CTA tile as it's the range supported by the
-        # kernel.
-        tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-
-        return tile_tokens_dim
-
     def apply(
         self,
         output: torch.Tensor,
@@ -148,9 +123,7 @@ def apply(
             "local_expert_offset": local_expert_offset,
             "local_num_experts": local_num_experts,
             "routed_scaling_factor": None,
-            "tile_tokens_dim": self._get_tile_tokens_dim(
-                x_quant, topk, local_num_experts
-            ),
+            "tile_tokens_dim": None,
             "routing_method_type": 1,
             "do_finalize": True,
             "output": output,

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 4 additions & 17 deletions
@@ -72,7 +72,6 @@
 )
 from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
 from vllm.scalar_type import scalar_types
-from vllm.utils import next_power_of_2
 from vllm.utils.flashinfer import (
     flashinfer_scaled_fp4_mm,
     has_flashinfer,
@@ -1125,16 +1124,6 @@ def apply(
         return out.view(*output_shape)
 
 
-def _get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
-    # Guess tokens per expert assuming perfect expert distribution first.
-    num_tokens_per_expert = (num_tokens * top_k) // num_experts
-    # And pad the number to the next power of 2.
-    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
-    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
-    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-    return tile_tokens_dim
-
-
 class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
     """
     MoE Method for FP4 Quantization.
@@ -1332,8 +1321,8 @@ def prepare_static_weights_for_trtllm_fp4_moe(
     ):
         from flashinfer import nvfp4_block_scale_interleave
         from flashinfer.fused_moe.core import (
-            _maybe_get_cached_w2_permute_indices,
             _maybe_get_cached_w3_w1_permute_indices,
+            get_w2_permute_indices_with_cache,
         )
 
         """Prepare quantized weights for kernel (done offline with weights)."""
@@ -1394,7 +1383,7 @@ def prepare_static_weights_for_trtllm_fp4_moe(
                 )
             )
 
-            permute_indices = _maybe_get_cached_w2_permute_indices(
+            permute_indices = get_w2_permute_indices_with_cache(
                 self._cache_permute_indices,
                 gemm2_weights_fp4[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -1405,7 +1394,7 @@ def prepare_static_weights_for_trtllm_fp4_moe(
                 .contiguous()
             )
 
-            permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+            permute_sf_indices = get_w2_permute_indices_with_cache(
                 self._cache_permute_indices,
                 gemm2_scales_linear_fp4[i].view(torch.uint8),
                 epilogue_tile_m,
@@ -1664,9 +1653,7 @@ def apply(
             local_expert_offset=layer.ep_rank * layer.local_num_experts,
             local_num_experts=layer.local_num_experts,
             routed_scaling_factor=None,
-            tile_tokens_dim=_get_tile_tokens_dim(
-                x.shape[0], top_k, layer.local_num_experts
-            ),
+            tile_tokens_dim=None,
             routing_method_type=routing_method_type,
             do_finalize=True,
         )[0]

vllm/model_executor/layers/quantization/mxfp4.py

Lines changed: 8 additions & 39 deletions
@@ -50,7 +50,6 @@
 from vllm.utils import (
     has_triton_kernels,
     is_torch_equal_or_newer,
-    next_power_of_2,
     round_up,
 )
 from vllm.utils.flashinfer import has_flashinfer
@@ -97,12 +96,6 @@ def get_mxfp4_backend():
         and has_flashinfer()
         and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
     ):
-        logger.info_once(
-            "Using FlashInfer MXFP4 MXFP8 TRTLLM backend for SM100, "
-            "for high concurrency throughput workloads consider setting "
-            "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS=1 for better "
-            "performance"
-        )
         return Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
     elif current_platform.is_device_capability(100) and has_flashinfer():
         logger.info_once(
@@ -357,7 +350,7 @@ def process_weights_after_loading(self, layer):
             or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
         ):
             from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
-            from flashinfer.fused_moe.core import _maybe_get_cached_w2_permute_indices
+            from flashinfer.fused_moe.core import get_w2_permute_indices_with_cache
 
             layer.gemm1_alpha = Parameter(
                 torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
@@ -449,7 +442,7 @@ def swap_every_two_rows(x, axis=-1):
             epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
             for i in range(self.num_experts):
                 # w13 weight shuffling
-                permute_indices = _maybe_get_cached_w2_permute_indices(
+                permute_indices = get_w2_permute_indices_with_cache(
                     self._cache_permute_indices,
                     w13_weight[i].view(torch.uint8),
                     epilogue_tile_m,
@@ -460,7 +453,7 @@ def swap_every_two_rows(x, axis=-1):
                     .contiguous()
                 )
                 # w13 scale shuffling
-                permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                permute_sf_indices = get_w2_permute_indices_with_cache(
                     self._cache_permute_indices,
                     w13_weight_scale[i].view(torch.uint8),
                     epilogue_tile_m,
@@ -476,7 +469,7 @@ def swap_every_two_rows(x, axis=-1):
                     )
                 )
                 # w13 bias shuffling
-                permute_bias_indices = _maybe_get_cached_w2_permute_indices(
+                permute_bias_indices = get_w2_permute_indices_with_cache(
                     self._cache_permute_indices,
                     w13_bias[i].clone().reshape(-1, 1),
                     epilogue_tile_m,
@@ -488,7 +481,7 @@ def swap_every_two_rows(x, axis=-1):
                     .contiguous()
                 )
                 # w2 weight shuffling
-                permute_indices = _maybe_get_cached_w2_permute_indices(
+                permute_indices = get_w2_permute_indices_with_cache(
                     self._cache_permute_indices,
                     w2_weight[i].view(torch.uint8),
                     epilogue_tile_m,
@@ -499,7 +492,7 @@ def swap_every_two_rows(x, axis=-1):
                     .contiguous()
                 )
                 # w2 scale shuffling
-                permute_sf_indices = _maybe_get_cached_w2_permute_indices(
+                permute_sf_indices = get_w2_permute_indices_with_cache(
                     self._cache_permute_indices,
                     w2_weight_scale[i].view(torch.uint8),
                     epilogue_tile_m,
@@ -515,7 +508,7 @@ def swap_every_two_rows(x, axis=-1):
                     )
                 )
                 # w2 bias shuffling
-                permute_indices = _maybe_get_cached_w2_permute_indices(
+                permute_indices = get_w2_permute_indices_with_cache(
                     self._cache_permute_indices,
                     w2_bias[i].clone().reshape(-1, 1),
                     epilogue_tile_m,
@@ -735,30 +728,6 @@ def _interleave_mxfp4_cutlass_sm90(w):
         else:
             raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
 
-    def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int):
-        # Number of tokens in the input tensor.
-        num_tokens = x.shape[0]
-        # Factor to account for the imbalance of the experts.
-        # factor equals to the
-        # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
-        # - 1.0 means perfect expert distribution.
-        # - > 1.0 means some experts have more
-        # tokens than the perfect distribution.
-        # - < 1.0 does not make sense.
-        imbalance_factor = 1.3
-        # Calculate the number of tokens per expert
-        # assuming perfect distribution.
-        num_tokens_per_expert = (num_tokens * top_k) // self.num_experts
-        # Apply the imbalance factor.
-        num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
-        # And pad the number to the next power of 2.
-        tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
-        # Cap to 8-64 tokens per CTA tile
-        # as it's the range supported by the kernel.
-        tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-
-        return tile_tokens_dim
-
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
@@ -1037,7 +1006,7 @@ def apply(
                 layer.ep_rank * layer.local_num_experts,  # local_expert_offset
                 self.num_experts,  # local num experts
                 None,
-                self._get_tile_tokens_dim(x, top_k),
+                None,
                 1 if renormalize else 0,  # routing_method_type, renormalize
                 True,  # do finalize
                 tune_max_num_tokens=self.max_capture_size,
