Merged
34 commits
7cdb800
Flashinfer cutlass moe backend for TP/DP + EP.
wenscarl Jul 2, 2025
11f9136
cutlass moe path work
wenscarl Jul 4, 2025
92994d5
flashinfer tp pass
wenscarl Jul 4, 2025
93a3a0a
dp work
wenscarl Jul 5, 2025
c1a74eb
Clean up
wenscarl Jul 5, 2025
f4dc86b
Fix NoEP class
wenscarl Jul 7, 2025
7df49e1
Fix NoEP class
wenscarl Jul 7, 2025
5af3db9
chunking work
wenscarl Jul 9, 2025
4c5fa6d
Address comments
wenscarl Jul 9, 2025
f99cf65
minor fix
wenscarl Jul 9, 2025
dbefd52
cutlass_moe_fp4 support TP chunking
wenscarl Jul 10, 2025
3f043dd
flashinfer cutlass moe support TP chunking
wenscarl Jul 10, 2025
7b5e203
Address comment and clean Up
wenscarl Jul 10, 2025
248fff3
Merge remote-tracking branch 'origin/main' into flashinfer_fused_moe
wenscarl Jul 11, 2025
131f141
Upd
wenscarl Jul 11, 2025
e5403e5
Merge remote-tracking branch 'origin/main' into flashinfer_fused_moe
wenscarl Jul 12, 2025
3bdbeb1
Fix lint
wenscarl Jul 12, 2025
8fdb3dc
Merge Pynccl ag/rs
wenscarl Jul 12, 2025
5bcd50b
Recover apply_router_weight_on_input
wenscarl Jul 15, 2025
b842245
Merge remote-tracking branch 'origin/main' into flashinfer_fused_moe
wenscarl Jul 15, 2025
6fed494
remove comment
wenscarl Jul 15, 2025
20f3417
fix lint
wenscarl Jul 15, 2025
728275b
Add autotune
wenscarl Jul 15, 2025
bbb505f
Add flashinfer wrapper and fix pre-commit
mgoin Jul 16, 2025
841fcd1
Move switch w13 to modelopt
wenscarl Jul 16, 2025
716621a
Address comments.
wenscarl Jul 17, 2025
f31be5a
Upd
wenscarl Jul 17, 2025
6225670
Upd
wenscarl Jul 17, 2025
a08b47f
fix interface changes
wenscarl Jul 17, 2025
881684f
Merge remote-tracking branch 'origin/main' into flashinfer_fused_moe
wenscarl Jul 17, 2025
24ce9aa
Upd
wenscarl Jul 17, 2025
e5d14ea
Fix run_cutlass_moe_fp4 for Llama 4
mgoin Jul 17, 2025
4b3ee2e
Ensure lazy imports for flashinfer
mgoin Jul 17, 2025
cc0e87d
Just use has_flashinfer in fused_moe/layer.py
mgoin Jul 17, 2025
22 changes: 9 additions & 13 deletions vllm/_custom_ops.py
@@ -957,11 +957,11 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
c_strides, per_act_token, per_out_ch)


def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor,
a_scales: torch.Tensor, b_scales: torch.Tensor,
alphas: torch.Tensor, problem_sizes: torch.Tensor,
expert_offsets: torch.Tensor, sf_offsets: torch.Tensor,
out_dtype: torch.dtype, device: torch.device):
def cutlass_fp4_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor,
b_tensors: torch.Tensor, a_scales: torch.Tensor,
b_scales: torch.Tensor, alphas: torch.Tensor,
problem_sizes: torch.Tensor,
expert_offsets: torch.Tensor, sf_offsets: torch.Tensor):
"""
An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs
the gemms for each combination based on the specified problem sizes.
@@ -978,14 +978,10 @@ def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor,
- problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
MMs used in the fused MoE operation.
"""
m_topk = a_tensors.shape[0]
n = b_tensors.shape[1]
c_shape = (m_topk, n)
c = torch.empty(c_shape, device=device, dtype=out_dtype)
torch.ops._C.cutlass_fp4_group_mm(c, a_tensors, b_tensors, a_scales,
b_scales, alphas, problem_sizes,
expert_offsets, sf_offsets)
return c.to(out_dtype)
return torch.ops._C.cutlass_fp4_group_mm(out_tensors, a_tensors, b_tensors,
a_scales, b_scales, alphas,
problem_sizes, expert_offsets,
sf_offsets)


# aqlm
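Below is a minimal caller-side sketch of the updated cutlass_fp4_moe_mm convention, where the output buffer is now preallocated by the caller instead of inside the helper. The sizes are illustrative assumptions, and the FP4-packed inputs and scale/offset tensors are assumed to come from the fused-MoE dispatch path; running it requires a vLLM build with the CUTLASS FP4 grouped GEMM compiled and a GPU that supports NVFP4.

import torch
from vllm import _custom_ops as ops  # assumes a vLLM build with CUTLASS FP4 support

# The helper no longer allocates its own output; the caller passes a
# preallocated buffer as the first argument.
m_topk, n = 128, 4096  # illustrative sizes: (token, expert) rows after top-k, output dim
out_tensors = torch.empty((m_topk, n), device="cuda", dtype=torch.half)

# With FP4-packed (uint8) a_tensors/b_tensors and the scale/offset tensors
# prepared by the fused-MoE dispatch code (not shown here), the call becomes:
#   ops.cutlass_fp4_moe_mm(out_tensors, a_tensors, b_tensors, a_scales,
#                          b_scales, alphas, problem_sizes, expert_offsets,
#                          sf_offsets)
# out_tensors is then filled with the grouped GEMM results.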
5 changes: 5 additions & 0 deletions vllm/envs.py
@@ -121,6 +121,7 @@
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
VLLM_USE_DEEP_GEMM: bool = False
VLLM_USE_FLASHINFER_MOE: bool = False
VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
@@ -868,6 +869,10 @@ def get_vllm_port() -> Optional[int]:
"VLLM_USE_DEEP_GEMM":
lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),

# Allow use of FlashInfer CUTLASS kernels for fused moe ops.
"VLLM_USE_FLASHINFER_MOE":
lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE", "0"))),

# Control the cache sized used by the xgrammar compiler. The default
# of 512 MB should be enough for roughly 1000 JSON schemas.
# It can be changed with this variable if needed for some reason.
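A quick sanity check of the new flag, assuming the lazy os.getenv lookup shown in the diff (the lambda is evaluated when the attribute is read, so setting the variable in-process before first access is sufficient):

import os
import vllm.envs as envs

# Opt in to the FlashInfer CUTLASS fused-MoE kernels.
os.environ["VLLM_USE_FLASHINFER_MOE"] = "1"
assert envs.VLLM_USE_FLASHINFER_MOE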
36 changes: 13 additions & 23 deletions vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Any, Optional

import torch

@@ -255,28 +255,18 @@ def workspace_shapes(
output = (num_experts, max_num_tokens * num_dispatchers, K)
return (workspace13, workspace2, output, a.dtype)

def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
):
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, activation: str, global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens

vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Any, Optional

import torch

@@ -142,12 +142,13 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
apply_router_weight_on_input: bool):
apply_router_weight_on_input: bool,
extra_expert_args: Optional[dict[str, Any]]):
experts = (self.batched_deep_gemm_experts
if self.allow_deep_gemm else self.batched_triton_experts)
assert experts is not None
experts.apply(output, hidden_states, w1, w2, topk_weights, topk_ids,
activation, global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
workspace2, expert_tokens_meta,
apply_router_weight_on_input)
apply_router_weight_on_input, extra_expert_args)
16 changes: 16 additions & 0 deletions vllm/model_executor/layers/fused_moe/config.py
@@ -15,6 +15,7 @@
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.utils import cdiv
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe

logger = init_logger(__name__)

@@ -188,6 +189,11 @@ def use_deepep_ll_kernels(self):
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")

@property
def use_flashinfer_cutlass_kernels(self):
return (envs.VLLM_USE_FLASHINFER_MOE
and has_flashinfer_cutlass_fused_moe())

@staticmethod
def make(tp_size_: int, dp_size_: int,
vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
@@ -392,6 +398,10 @@ def use_deepep_ht_kernels(self):
def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels

@property
def use_flashinfer_cutlass_kernels(self):
return self.moe_parallel_config.use_flashinfer_cutlass_kernels

@staticmethod
def make(
num_experts: int,
@@ -435,6 +445,12 @@ def make(
if quant_dtype is None and isinstance(quant_config, Fp8Config):
quant_dtype = torch.float8_e4m3fn

from vllm.model_executor.layers.quantization.modelopt import (
ModelOptNvFp4Config)
if quant_dtype is None and isinstance(quant_config,
ModelOptNvFp4Config):
quant_dtype = torch.uint8

if weight_quant is not None:
per_out_ch_quant = (
weight_quant.strategy == QuantizationStrategy.CHANNEL)
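The new use_flashinfer_cutlass_kernels property gates the backend on both the env flag and kernel availability. A standalone sketch of the same check follows; the wrapper function name is an assumption for illustration, while the imports mirror the diff.

import vllm.envs as envs
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe

def flashinfer_cutlass_moe_enabled() -> bool:
    # Mirrors the logic of use_flashinfer_cutlass_kernels: the env flag opts
    # in, and has_flashinfer_cutlass_fused_moe() confirms that a usable
    # FlashInfer CUTLASS fused-MoE kernel is actually available.
    return bool(envs.VLLM_USE_FLASHINFER_MOE
                and has_flashinfer_cutlass_fused_moe())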