10 changes: 8 additions & 2 deletions csrc/moe/moe_align_sum_kernels.cu
@@ -44,6 +44,9 @@ __global__ void moe_align_block_size_kernel(

for (size_t i = tid; i < numel; i += stride) {
int expert_id = topk_ids[i];
if (expert_id >= num_experts) {
continue;
}
int warp_idx = expert_id / experts_per_warp;
int expert_offset = expert_id % experts_per_warp;
atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], 1);
@@ -95,12 +98,15 @@ template <typename scalar_t>
__global__ void count_and_sort_expert_tokens_kernel(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer,
size_t numel) {
size_t numel, int32_t num_experts) {
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;

for (size_t i = tid; i < numel; i += stride) {
int32_t expert_id = topk_ids[i];
if (expert_id >= num_experts) {
continue;
}
int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
sorted_token_ids[rank_post_pad] = i;
}
@@ -269,7 +275,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel(), num_experts);
}
});
}
1 change: 1 addition & 0 deletions docs/models/supported_models.md
@@ -428,6 +428,7 @@ th {
| `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | | ✅︎ |
| `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | ✅︎ |
| `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | ✅︎ |
| `LongcatFlashForCausalLM` | LongCat-Flash | `meituan-longcat/LongCat-Flash-Chat`, `meituan-longcat/LongCat-Flash-Chat-FP8` | ✅︎ | ✅︎ | ✅︎ |

Some models are supported only via the [Transformers backend](#transformers). The purpose of the table below is to acknowledge models which we officially support in this way. The logs will say that the Transformers backend is being used, and you will see no warning that this is fallback behaviour. This means that, if you have issues with any of the models listed below, please [make an issue](https://github.com/vllm-project/vllm/issues/new/choose) and we'll do our best to fix it!

4 changes: 2 additions & 2 deletions tests/kernels/moe/test_flashinfer.py
@@ -138,7 +138,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)

score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=td.hidden_states,
router_logits=score,
use_grouped_topk=False,
@@ -206,7 +206,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False)

score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=td.hidden_states,
router_logits=score,
use_grouped_topk=False,
6 changes: 6 additions & 0 deletions tests/models/registry.py
@@ -273,6 +273,8 @@ def check_available_online(
is_available_online=False),
"Llama4ForCausalLM": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct", # noqa: E501
is_available_online=False),
"LongcatFlashForCausalLM": _HfExamplesInfo
("meituan-longcat/LongCat-Flash-Chat", trust_remote_code=True),
"MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
"Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1",
min_transformers_version="4.55.3",
@@ -639,6 +641,10 @@ def check_available_online(
speculative_model="zai-org/GLM-4.5",
min_transformers_version="4.54",
is_available_online=False),
"LongCatFlashMTPModel": _HfExamplesInfo(
"meituan-longcat/LongCat-Flash-Chat",
trust_remote_code=True,
speculative_model="meituan-longcat/LongCat-Flash-Chat"),
"MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
trust_remote_code=True,
speculative_model="XiaomiMiMo/MiMo-7B-RL"),
12 changes: 9 additions & 3 deletions tests/models/utils.py
@@ -423,9 +423,8 @@ def dummy_hf_overrides(
num_hidden_layers = (3 if model_arch
== "Gemma3nForConditionalGeneration" else 1)

text_config.update({
update_dict = {
"num_layers": num_layers,
"num_hidden_layers": num_hidden_layers,
"num_experts": num_experts,
"num_experts_per_tok": 2,
"num_local_experts": num_experts,
@@ -435,7 +434,14 @@ def dummy_hf_overrides(
"n_routed_experts": num_experts,
# For Gemma-3n
"num_kv_shared_layers": 1,
})
}

# Update num_hidden_layers for non-Longcat architectures
if model_arch != "LongcatFlashForCausalLM" \
and model_arch != "LongCatFlashMTPModel":
update_dict["num_hidden_layers"] = num_hidden_layers

text_config.update(update_dict)

if hasattr(hf_config, "vision_config"):
hf_config.vision_config.update({
2 changes: 1 addition & 1 deletion tests/test_routing_simulator.py
@@ -96,7 +96,7 @@ def test_routing_strategy_integration(monkeypatch, device):
envs.environment_variables[env_name] = lambda s=strategy: s

# Test the select_experts method
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids, _ = FusedMoE.select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=top_k,
6 changes: 5 additions & 1 deletion vllm/config/model.py
@@ -1149,7 +1149,8 @@ def is_deepseek_mla(self) -> bool:
if not hasattr(self.hf_text_config, "model_type"):
return False
elif self.hf_text_config.model_type in \
('deepseek_v2', 'deepseek_v3', 'deepseek_mtp', 'kimi_k2'):
('deepseek_v2', 'deepseek_v3', 'deepseek_mtp',
'kimi_k2', 'longcat_flash'):
return self.hf_text_config.kv_lora_rank is not None
elif self.hf_text_config.model_type == 'eagle':
# if the model is an EAGLE module, check for the
@@ -1275,6 +1276,9 @@ def get_layers_start_end_indices(
or self.hf_config.model_type == "qwen3_next_mtp"):
total_num_hidden_layers = getattr(self.hf_text_config,
"num_nextn_predict_layers", 0)
elif (self.hf_config.model_type == "longcat_flash_mtp"):
total_num_hidden_layers = getattr(self.hf_text_config,
"num_nextn_predict_layers", 1)
else:
total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0)
21 changes: 19 additions & 2 deletions vllm/config/speculative.py
@@ -31,7 +31,8 @@

SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
"mlp_speculator", "draft_model", "deepseek_mtp",
"ernie_mtp", "qwen3_next_mtp", "mimo_mtp"]
"ernie_mtp", "qwen3_next_mtp", "mimo_mtp",
"longcat_flash_mtp"]


@config
@@ -186,6 +187,13 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
"n_predict": n_predict,
"architectures": ["Qwen3NextMTP"]
})
if hf_config.model_type == "longcat_flash":
hf_config.model_type = "longcat_flash_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
hf_config.update({
"n_predict": n_predict,
"architectures": ["LongCatFlashMTPModel"]
})

return hf_config

@@ -334,6 +342,15 @@ def __post_init__(self):
"one layer. Might need some code changes " \
"to support multiple layers."
)
elif (self.draft_model_config.hf_config.model_type
in ("longcat_flash_mtp",)):
self.method = "longcat_flash_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"LongCat MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
else:
self.method = "draft_model"
raise NotImplementedError(
@@ -550,7 +567,7 @@ def num_lookahead_slots(self) -> int:

def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
"qwen3_next_mtp")
"qwen3_next_mtp", "longcat_flash_mtp")

def __repr__(self) -> str:
method = self.method
89 changes: 89 additions & 0 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -664,6 +664,76 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
)


@triton.jit
def compute_identity_kernel(
top_k: int,
hidden_states_ptr: tl.tensor,
expert_scales_ptr: tl.tensor,
num_tokens: int,
output_ptr: tl.tensor,
hidden_dim: int,
scales_stride: int,
BLOCK_SIZE: tl.constexpr,
) -> None:
pid = tl.program_id(0)

batch_id = pid // (hidden_dim // BLOCK_SIZE)
dim_offset = pid % (hidden_dim // BLOCK_SIZE) * BLOCK_SIZE

if batch_id >= num_tokens or dim_offset >= hidden_dim:
return

h = tl.load(hidden_states_ptr + batch_id * hidden_dim + dim_offset +
tl.arange(0, BLOCK_SIZE),
mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim)

result = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for i in range(top_k):
scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i)
result += h * scale

tl.store(output_ptr + batch_id * hidden_dim + dim_offset +
tl.arange(0, BLOCK_SIZE),
result,
mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim)


def zero_experts_compute_triton(expert_indices: torch.Tensor,
expert_scales: torch.Tensor, num_experts: int,
zero_expert_type: str,
hidden_states: torch.Tensor) -> torch.Tensor:
N = expert_indices.numel()
top_k = expert_indices.size(-1)
grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )

if zero_expert_type == "identity":
zero_expert_mask = expert_indices < num_experts
zero_expert_scales = expert_scales.clone()
zero_expert_scales[zero_expert_mask] = 0.0

normal_expert_mask = expert_indices >= num_experts
expert_indices[normal_expert_mask] = 0
expert_scales[normal_expert_mask] = 0.0
Comment on lines +715 to +716
Contributor

high

The function zero_experts_compute_triton modifies its input tensors expert_indices and expert_scales in-place. This side effect can be unexpected and lead to bugs if the caller reuses these tensors assuming they are unchanged. While this might be an intentional optimization to avoid extra memory allocations, it makes the code harder to reason about and maintain. To improve clarity and safety, consider returning the modified tensors instead of modifying them in-place. This would make the data flow explicit.
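
For reference, a minimal sketch of the out-of-place variant this comment suggests. The helper name `split_zero_experts` is hypothetical and not part of this PR; it only mirrors the masking that `zero_experts_compute_triton` performs, but returns fresh tensors instead of mutating the caller's `expert_indices` and `expert_scales`:

```python
import torch

def split_zero_experts(expert_indices: torch.Tensor,
                       expert_scales: torch.Tensor,
                       num_experts: int):
    """Return new tensors rather than editing the inputs in-place."""
    # Slots routed to "zero" experts carry ids >= num_experts.
    zero_mask = expert_indices >= num_experts

    # Scales used by the identity (zero-expert) path: keep only zero-expert slots.
    zero_scales = torch.where(zero_mask, expert_scales,
                              torch.zeros_like(expert_scales))

    # Indices/scales fed to the regular fused-MoE path: blank out zero-expert slots.
    routed_indices = torch.where(zero_mask,
                                 torch.zeros_like(expert_indices), expert_indices)
    routed_scales = torch.where(zero_mask,
                                torch.zeros_like(expert_scales), expert_scales)
    return routed_indices, routed_scales, zero_scales
```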


output = torch.zeros_like(hidden_states).to(hidden_states.device)
hidden_dim = hidden_states.size(-1)
num_tokens = hidden_states.size(0)

grid = lambda meta: (num_tokens * (hidden_dim // meta['BLOCK_SIZE']), )
compute_identity_kernel[grid](
top_k,
hidden_states,
zero_expert_scales,
num_tokens,
output,
hidden_dim,
zero_expert_scales.stride(0),
BLOCK_SIZE=256,
)

return output


# Adapted from: https://github.com/sgl-project/sglang/pull/2628
def get_config_file_name(E: int,
N: int,
@@ -940,6 +1010,25 @@ def fused_topk(
return topk_weights, topk_ids, token_expert_indices


def fused_topk_bias(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
e_score_correction_bias: torch.Tensor,
topk: int,
renormalize: bool,
):
n_routed_experts = gating_output.shape[-1]
scores = gating_output.softmax(dim=-1)
scores_for_choice = scores.view(
-1, n_routed_experts) + e_score_correction_bias.unsqueeze(0)
topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1,
sorted=False)[1]
topk_weights = scores.gather(1, topk_indices)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights.to(torch.float32), topk_indices.to(torch.int32)


# This is used by the Deepseek-V2 and Deepseek-V3 model
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def grouped_topk(
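
A usage note on the new `fused_topk_bias` helper added in `fused_moe.py` above: the correction bias only steers which experts are selected, while the returned weights are gathered from the unbiased softmax scores. A rough sketch of the expected behaviour, assuming the import path of the file touched in this diff; shapes and values are illustrative:

```python
import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk_bias

num_tokens, num_experts, top_k = 4, 8, 2
hidden_states = torch.randn(num_tokens, 16)   # not referenced in the helper body shown above
gating_output = torch.randn(num_tokens, num_experts)

# Strongly favor expert 3 during selection only.
e_score_correction_bias = torch.zeros(num_experts)
e_score_correction_bias[3] = 5.0

weights, ids = fused_topk_bias(hidden_states, gating_output,
                               e_score_correction_bias, top_k,
                               renormalize=True)

assert (ids == 3).any(dim=-1).all()            # expert 3 is picked for every token
assert torch.allclose(weights.sum(dim=-1), torch.ones(num_tokens))  # weights renormalized
```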