
Commit 9ed75dd

OftenDreamy authored and yangxurui committed
[Model] Add LongCat-Flash (vllm-project#23991)
Signed-off-by: yangxurui <yangxurui@meituan.com>
Co-authored-by: yangxurui <yangxurui@meituan.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 4d805d4 commit 9ed75dd

31 files changed: +1357 −66 lines

csrc/moe/moe_align_sum_kernels.cu

Lines changed: 8 additions & 2 deletions
@@ -44,6 +44,9 @@ __global__ void moe_align_block_size_kernel(

   for (size_t i = tid; i < numel; i += stride) {
     int expert_id = topk_ids[i];
+    if (expert_id >= num_experts) {
+      continue;
+    }
     int warp_idx = expert_id / experts_per_warp;
     int expert_offset = expert_id % experts_per_warp;
     atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], 1);
@@ -95,12 +98,15 @@ template <typename scalar_t>
 __global__ void count_and_sort_expert_tokens_kernel(
     const scalar_t* __restrict__ topk_ids,
     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer,
-    size_t numel) {
+    size_t numel, int32_t num_experts) {
   const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
   const size_t stride = blockDim.x * gridDim.x;

   for (size_t i = tid; i < numel; i += stride) {
     int32_t expert_id = topk_ids[i];
+    if (expert_id >= num_experts) {
+      continue;
+    }
     int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
     sorted_token_ids[rank_post_pad] = i;
   }
@@ -269,7 +275,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
       sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
           topk_ids.data_ptr<scalar_t>(),
           sorted_token_ids.data_ptr<int32_t>(),
-          cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
+          cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel(), num_experts);
     }
   });
 }
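Context for the guard above: LongCat-Flash's router can emit "zero expert" ids at or beyond num_experts, and those slots must be skipped rather than counted or sorted. A minimal PyTorch sketch of the same masking, for illustration only (the helper name and toy shapes are assumptions, not part of this commit):

import torch

def count_valid_expert_tokens(topk_ids: torch.Tensor, num_experts: int) -> torch.Tensor:
    # Mirrors the CUDA guard: ids >= num_experts (e.g. zero experts) are ignored
    # instead of contributing to the per-expert token counts.
    valid = topk_ids < num_experts
    return torch.bincount(topk_ids[valid].flatten().long(), minlength=num_experts)

counts = count_valid_expert_tokens(torch.tensor([[0, 5], [2, 7]]), num_experts=4)
# counts == tensor([1, 0, 1, 0]); ids 5 and 7 fall outside the real experts and are skipped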

docs/models/supported_models.md

Lines changed: 1 addition & 0 deletions
@@ -428,6 +428,7 @@ th {
 | `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | | ✅︎ |
 | `MiniMaxText01ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-Text-01`, etc. | | | ✅︎ |
 | `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | ✅︎ |
+| `LongcatFlashForCausalLM` | LongCat-Flash | `meituan-longcat/LongCat-Flash-Chat`, `meituan-longcat/LongCat-Flash-Chat-FP8` | ✅︎ | ✅︎ | ✅︎ |

 Some models are supported only via the [Transformers backend](#transformers). The purpose of the table below is to acknowledge models which we officially support in this way. The logs will say that the Transformers backend is being used, and you will see no warning that this is fallback behaviour. This means that, if you have issues with any of the models listed below, please [make an issue](https://github.com/vllm-project/vllm/issues/new/choose) and we'll do our best to fix it!

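With the new table row and registry entries, the checkpoint loads through vLLM's usual offline API. A hedged sketch, not part of this commit: the parallelism value is illustrative and depends on your hardware, and the checkpoint needs trust_remote_code for its custom configuration.

from vllm import LLM, SamplingParams

llm = LLM(
    model="meituan-longcat/LongCat-Flash-Chat",
    trust_remote_code=True,     # custom LongCat config code on the Hub
    tensor_parallel_size=8,     # illustrative; size to your GPUs
)
outputs = llm.generate(["Hello, LongCat!"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)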
tests/kernels/moe/test_flashinfer.py

Lines changed: 2 additions & 2 deletions
@@ -138,7 +138,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
     td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)

     score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
-    topk_weights, topk_ids = FusedMoE.select_experts(
+    topk_weights, topk_ids, _ = FusedMoE.select_experts(
         hidden_states=td.hidden_states,
         router_logits=score,
         use_grouped_topk=False,
@@ -206,7 +206,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
     td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False)

     score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
-    topk_weights, topk_ids = FusedMoE.select_experts(
+    topk_weights, topk_ids, _ = FusedMoE.select_experts(
         hidden_states=td.hidden_states,
         router_logits=score,
         use_grouped_topk=False,

tests/models/registry.py

Lines changed: 6 additions & 0 deletions
@@ -273,6 +273,8 @@ def check_available_online(
                                          is_available_online=False),
     "Llama4ForCausalLM": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct",  # noqa: E501
                                          is_available_online=False),
+    "LongcatFlashForCausalLM": _HfExamplesInfo
+    ("meituan-longcat/LongCat-Flash-Chat", trust_remote_code=True),
     "MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
     "Mamba2ForCausalLM": _HfExamplesInfo("mistralai/Mamba-Codestral-7B-v0.1",
                                          min_transformers_version="4.55.3",
@@ -639,6 +641,10 @@ def check_available_online(
                                 speculative_model="zai-org/GLM-4.5",
                                 min_transformers_version="4.54",
                                 is_available_online=False),
+    "LongCatFlashMTPModel": _HfExamplesInfo(
+        "meituan-longcat/LongCat-Flash-Chat",
+        trust_remote_code=True,
+        speculative_model="meituan-longcat/LongCat-Flash-Chat"),
     "MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
                                     trust_remote_code=True,
                                     speculative_model="XiaomiMiMo/MiMo-7B-RL"),

tests/models/utils.py

Lines changed: 9 additions & 3 deletions
@@ -428,9 +428,8 @@ def dummy_hf_overrides(
         num_hidden_layers = (3 if model_arch
                              == "Gemma3nForConditionalGeneration" else 1)

-        text_config.update({
+        update_dict = {
             "num_layers": num_layers,
-            "num_hidden_layers": num_hidden_layers,
             "num_experts": num_experts,
             "num_experts_per_tok": 2,
             "num_local_experts": num_experts,
@@ -440,7 +439,14 @@ def dummy_hf_overrides(
             "n_routed_experts": num_experts,
             # For Gemma-3n
             "num_kv_shared_layers": 1,
-        })
+        }
+
+        # Update num_hidden_layers for non-Longcat architectures
+        if model_arch != "LongcatFlashForCausalLM" \
+                and model_arch != "LongCatFlashMTPModel":
+            update_dict["num_hidden_layers"] = num_hidden_layers
+
+        text_config.update(update_dict)

         if hasattr(hf_config, "vision_config"):
             hf_config.vision_config.update({

tests/test_routing_simulator.py

Lines changed: 1 addition & 1 deletion
@@ -96,7 +96,7 @@ def test_routing_strategy_integration(monkeypatch, device):
         envs.environment_variables[env_name] = lambda s=strategy: s

         # Test the select_experts method
-        topk_weights, topk_ids = FusedMoE.select_experts(
+        topk_weights, topk_ids, _ = FusedMoE.select_experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
             top_k=top_k,

vllm/config/model.py

Lines changed: 5 additions & 1 deletion
@@ -1131,7 +1131,8 @@ def is_deepseek_mla(self) -> bool:
         if not hasattr(self.hf_text_config, "model_type"):
             return False
         elif self.hf_text_config.model_type in \
-            ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp', 'kimi_k2'):
+            ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp',
+             'kimi_k2', 'longcat_flash'):
             return self.hf_text_config.kv_lora_rank is not None
         elif self.hf_text_config.model_type == 'eagle':
             # if the model is an EAGLE module, check for the
@@ -1257,6 +1258,9 @@ def get_layers_start_end_indices(
               or self.hf_config.model_type == "qwen3_next_mtp"):
             total_num_hidden_layers = getattr(self.hf_text_config,
                                               "num_nextn_predict_layers", 0)
+        elif (self.hf_config.model_type == "longcat_flash_mtp"):
+            total_num_hidden_layers = getattr(self.hf_text_config,
+                                              "num_nextn_predict_layers", 1)
         else:
             total_num_hidden_layers = getattr(self.hf_text_config,
                                               "num_hidden_layers", 0)

vllm/config/speculative.py

Lines changed: 19 additions & 2 deletions
@@ -31,7 +31,8 @@

 SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
                             "mlp_speculator", "draft_model", "deepseek_mtp",
-                            "ernie_mtp", "qwen3_next_mtp", "mimo_mtp"]
+                            "ernie_mtp", "qwen3_next_mtp", "mimo_mtp",
+                            "longcat_flash_mtp"]


 @config
@@ -186,6 +187,13 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
                 "n_predict": n_predict,
                 "architectures": ["Qwen3NextMTP"]
             })
+        if hf_config.model_type == "longcat_flash":
+            hf_config.model_type = "longcat_flash_mtp"
+            n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
+            hf_config.update({
+                "n_predict": n_predict,
+                "architectures": ["LongCatFlashMTPModel"]
+            })

         return hf_config

@@ -332,6 +340,15 @@ def __post_init__(self):
                         "one layer. Might need some code changes " \
                         "to support multiple layers."
                     )
+            elif (self.draft_model_config.hf_config.model_type
+                  in ("longcat_flash_mtp")):
+                self.method = "longcat_flash_mtp"
+                if self.num_speculative_tokens > 1:
+                    logger.warning(
+                        "LongCat MTP models only have " \
+                        "one layer. Might need some code changes " \
+                        "to support multiple layers."
+                    )
             else:
                 self.method = "draft_model"
                 raise NotImplementedError(
@@ -548,7 +565,7 @@ def num_lookahead_slots(self) -> int:

     def use_eagle(self) -> bool:
         return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
-                               "qwen3_next_mtp")
+                               "qwen3_next_mtp", "longcat_flash_mtp")

     def __repr__(self) -> str:
         method = self.method
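The override above lets the target checkpoint serve as its own MTP draft: a draft config reporting model_type "longcat_flash" is rewritten to "longcat_flash_mtp" with the LongCatFlashMTPModel architecture, and n_predict is taken from num_nextn_predict_layers (default 1). A hedged usage sketch mirroring the registry test entry; the exact engine arguments may differ between vLLM versions:

from vllm import LLM

llm = LLM(
    model="meituan-longcat/LongCat-Flash-Chat",
    trust_remote_code=True,
    speculative_config={
        # The MTP draft weights live in the same checkpoint as the target model.
        "model": "meituan-longcat/LongCat-Flash-Chat",
        "num_speculative_tokens": 1,  # values > 1 trigger the warning added above
    },
)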

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 89 additions & 0 deletions
@@ -664,6 +664,76 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
     )


+@triton.jit
+def compute_identity_kernel(
+        top_k: int,
+        hidden_states_ptr: tl.tensor,
+        expert_scales_ptr: tl.tensor,
+        num_tokens: int,
+        output_ptr: tl.tensor,
+        hidden_dim: int,
+        scales_stride: int,
+        BLOCK_SIZE: tl.constexpr,
+) -> None:
+    pid = tl.program_id(0)
+
+    batch_id = pid // (hidden_dim // BLOCK_SIZE)
+    dim_offset = pid % (hidden_dim // BLOCK_SIZE) * BLOCK_SIZE
+
+    if batch_id >= num_tokens or dim_offset >= hidden_dim:
+        return
+
+    h = tl.load(hidden_states_ptr + batch_id * hidden_dim + dim_offset +
+                tl.arange(0, BLOCK_SIZE),
+                mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim)
+
+    result = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for i in range(top_k):
+        scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i)
+        result += h * scale
+
+    tl.store(output_ptr + batch_id * hidden_dim + dim_offset +
+             tl.arange(0, BLOCK_SIZE),
+             result,
+             mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim)
+
+
+def zero_experts_compute_triton(expert_indices: torch.Tensor,
+                                expert_scales: torch.Tensor, num_experts: int,
+                                zero_expert_type: str,
+                                hidden_states: torch.Tensor) -> torch.Tensor:
+    N = expert_indices.numel()
+    top_k = expert_indices.size(-1)
+    grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']), )
+
+    if zero_expert_type == "identity":
+        zero_expert_mask = expert_indices < num_experts
+        zero_expert_scales = expert_scales.clone()
+        zero_expert_scales[zero_expert_mask] = 0.0
+
+    normal_expert_mask = expert_indices >= num_experts
+    expert_indices[normal_expert_mask] = 0
+    expert_scales[normal_expert_mask] = 0.0
+
+    output = torch.zeros_like(hidden_states).to(hidden_states.device)
+    hidden_dim = hidden_states.size(-1)
+    num_tokens = hidden_states.size(0)
+
+    grid = lambda meta: (num_tokens * (hidden_dim // meta['BLOCK_SIZE']), )
+    compute_identity_kernel[grid](
+        top_k,
+        hidden_states,
+        zero_expert_scales,
+        num_tokens,
+        output,
+        hidden_dim,
+        zero_expert_scales.stride(0),
+        BLOCK_SIZE=256,
+    )
+
+    return output
+
+
 # Adapted from: https://github.com/sgl-project/sglang/pull/2628
 def get_config_file_name(E: int,
                          N: int,
@@ -940,6 +1010,25 @@ def fused_topk(
     return topk_weights, topk_ids, token_expert_indices


+def fused_topk_bias(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    e_score_correction_bias: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    n_routed_experts = gating_output.shape[-1]
+    scores = gating_output.softmax(dim=-1)
+    scores_for_choice = scores.view(
+        -1, n_routed_experts) + e_score_correction_bias.unsqueeze(0)
+    topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1,
+                              sorted=False)[1]
+    topk_weights = scores.gather(1, topk_indices)
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights.to(torch.float32), topk_indices.to(torch.int32)
+
+
 # This is used by the Deepseek-V2 and Deepseek-V3 model
 @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
 def grouped_topk(