1 change: 1 addition & 0 deletions .yapfignore
@@ -1 +1,2 @@
collect_env.py
vllm/model_executor/layers/fla/ops/*.py
1 change: 1 addition & 0 deletions docs/models/supported_models.md
@@ -402,6 +402,7 @@ th {
| `Qwen2MoeForCausalLM` | Qwen2MoE | `Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3ForCausalLM` | Qwen3 | `Qwen/Qwen3-8B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3MoeForCausalLM` | Qwen3MoE | `Qwen/Qwen3-30B-A3B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen3NextForCausalLM` | Qwen3NextMoE | `Qwen/Qwen3-Next-80B-A3B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `SeedOssForCausalLM` | SeedOss | `ByteDance-Seed/Seed-OSS-36B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `StableLmForCausalLM` | StableLM | `stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc. | | | ✅︎ |
| `Starcoder2ForCausalLM` | Starcoder2 | `bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc. | | ✅︎ | ✅︎ |
1 change: 1 addition & 0 deletions pyproject.toml
@@ -228,6 +228,7 @@ fo = "fo"
ba = "ba"

[tool.typos.type.py.extend-words]
ba = "ba"

[tool.typos.type.cpp]
extend-glob = ["*.cu"]
6 changes: 5 additions & 1 deletion tests/models/registry.py
@@ -323,6 +323,8 @@ def check_available_online(
"Qwen2MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen1.5-MoE-A2.7B-Chat"),
"Qwen3ForCausalLM": _HfExamplesInfo("Qwen/Qwen3-8B"),
"Qwen3MoeForCausalLM": _HfExamplesInfo("Qwen/Qwen3-30B-A3B"),
"Qwen3NextForCausalLM": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
min_transformers_version="4.56.2"),
"RWForCausalLM": _HfExamplesInfo("tiiuae/falcon-40b"),
"SeedOssForCausalLM": _HfExamplesInfo("ByteDance-Seed/Seed-OSS-36B-Instruct", # noqa: E501
trust_remote_code=True,
@@ -637,7 +639,9 @@ def check_available_online(
is_available_online=False),
"MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
trust_remote_code=True,
speculative_model="XiaomiMiMo/MiMo-7B-RL")
speculative_model="XiaomiMiMo/MiMo-7B-RL"),
"Qwen3NextMTP": _HfExamplesInfo("Qwen/Qwen3-Next-80B-A3B-Instruct",
min_transformers_version="4.56.2"),
}

_TRANSFORMERS_BACKEND_MODELS = {
56 changes: 46 additions & 10 deletions vllm/config/__init__.py
@@ -1504,7 +1504,8 @@ def get_layers_start_end_indices(
if (self.hf_text_config.model_type == "deepseek_mtp"
or self.hf_config.model_type == "mimo_mtp"
or self.hf_config.model_type == "glm4_moe_mtp"
or self.hf_config.model_type == "ernie_mtp"):
or self.hf_config.model_type == "ernie_mtp"
or self.hf_config.model_type == "qwen3_next_mtp"):
total_num_hidden_layers = getattr(self.hf_text_config,
"num_nextn_predict_layers", 0)
else:
@@ -1567,15 +1568,28 @@ def get_num_layers_by_block_type(
if attn_type_list:
return sum(t == 1 for t in attn_type_list[start:end])

if layers_block_type_value is None and attn_type_list is None:
# Hybrid model Qwen3Next
layer_types_value = getattr(self.hf_config, "layer_types", None)
if layer_types_value is not None:
if getattr(block_type, "value", block_type) == "attention":
return sum(t == "full_attention"
for t in layer_types_value[start:end])
elif getattr(block_type, "value",
block_type) == "linear_attention":
return sum(t == "linear_attention"
for t in layer_types_value[start:end])
else:
return sum(t == getattr(block_type, "value", block_type)
for t in layer_types_value[start:end])

if (layers_block_type_value is None and attn_type_list is None
and layer_types_value is None):
raise ValueError(
"The model is an hybrid without a"
"layers_block_type or an attn_type_list in the hf_config,"
"cannot determine the num of "
"layers_block_type or an attn_type_list, or a layer_types "
"in the hf_config, cannot determine the num of "
f"{block_type.value} layers")

return sum(t == 1 for t in attn_type_list[start:end])

def get_mamba_chunk_size(self) -> Optional[int]:
"""
Returns the mamba chunk size if it exists
@@ -1862,7 +1876,7 @@ def __post_init__(self):

SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
"mlp_speculator", "draft_model", "deepseek_mtp",
"ernie_mtp"]
"ernie_mtp", "qwen3_next_mtp"]


@config
@@ -2003,7 +2017,15 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
"n_predict": n_predict,
"architectures": ["ErnieMTPModel"]
})
return hf_config

if hf_config.model_type == "qwen3_next":
hf_config.model_type = "qwen3_next_mtp"
if hf_config.model_type == "qwen3_next_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"n_predict": n_predict,
"architectures": ["Qwen3NextMTP"]
})

return hf_config

@@ -2024,9 +2046,13 @@ def __post_init__(self):
(self.target_model_config.hf_text_config.model_type \
== "deepseek_v3" or
self.target_model_config.hf_text_config.model_type in
("mimo","ernie4_5_moe")):
("mimo","ernie4_5_moe", "qwen3_next")):
# use the draft model from the same model:
self.model = self.target_model_config.model
# Align the quantization of draft model for cases such as
# --quantization fp8 with a bf16 checkpoint.
if not self.quantization:
self.quantization = self.target_model_config.quantization
elif self.method in ("ngram", "[ngram]"):
self.model = "ngram"
else:
@@ -2136,6 +2162,15 @@ def __post_init__(self):
"one layer. Might need some code changes " \
"to support multiple layers."
)
elif (self.draft_model_config.hf_config.model_type ==
"qwen3_next_mtp"):
self.method = "qwen3_next_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Qwen3Next MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
Comment on lines +2165 to +2173
Contributor

high

This elif block for qwen3_next_mtp is nearly identical to the preceding blocks for ernie_mtp and deepseek_mtp. This code duplication makes the code harder to maintain and increases the risk of bugs if one block is updated but the others are not.

Consider refactoring these blocks to reduce duplication. For example:

MTP_MODELS = {
    "deepseek_mtp": ("deepseek_mtp", "Deepseek MTP"),
    "mimo_mtp": ("deepseek_mtp", "Deepseek MTP"),
    "glm4_moe_mtp": ("deepseek_mtp", "Deepseek MTP"),
    "ernie_mtp": ("ernie_mtp", "Ernie MTP"),
    "qwen3_next_mtp": ("qwen3_next_mtp", "Qwen3Next MTP"),
}
model_type = self.draft_model_config.hf_config.model_type
if model_type in MTP_MODELS:
    method, model_name = MTP_MODELS[model_type]
    self.method = method
    if self.num_speculative_tokens > 1:
        logger.warning(
            f"All {model_name} models only have "
            "one layer. Might need some code changes "
            "to support multiple layers."
        )

This would replace lines 2221-2250 and make the code more scalable and maintainable for future MTP models.

else:
self.method = "draft_model"
raise NotImplementedError(
@@ -2351,7 +2386,8 @@ def num_lookahead_slots(self) -> int:
return self.num_speculative_tokens

def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp")
return self.method in ("eagle", "eagle3", "deepseek_mtp", "ernie_mtp",
"qwen3_next_mtp")

def __repr__(self) -> str:
method = self.method
1 change: 1 addition & 0 deletions vllm/config/compilation.py
@@ -341,6 +341,7 @@ class CompilationConfig:
"vllm.short_conv",
"vllm.linear_attention",
"vllm.plamo2_mamba_mixer",
"vllm.gdn_attention",
]

def compute_hash(self) -> str:
5 changes: 3 additions & 2 deletions vllm/model_executor/layers/fla/ops/chunk_delta_h.py
@@ -14,7 +14,7 @@
from vllm.triton_utils import tl, triton

from .index import prepare_chunk_indices, prepare_chunk_offsets
from .op import exp, safe_exp
from .op import exp
from .utils import is_nvidia_hopper, use_cuda_graph

NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
@@ -175,12 +175,13 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
boundary_check=(0, 1))

if USE_G:
m_t = (i_t * BT + tl.arange(0, BT)) < T
last_idx = min((i_t + 1) * BT, T) - 1
b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
p_g = tl.make_block_ptr(g + bos * H + i_h, (T, ), (H, ),
(i_t * BT, ), (BT, ), (0, ))
b_g = tl.load(p_g, boundary_check=(0, ))
b_v_new = b_v_new * safe_exp(b_g_last - b_g)[:, None]
b_v_new = b_v_new * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None]
b_g_last = exp(b_g_last)
b_h1 = b_h1 * b_g_last
if K > 64:
9 changes: 5 additions & 4 deletions vllm/model_executor/layers/fla/ops/chunk_o.py
@@ -16,7 +16,7 @@
from vllm.triton_utils import tl, triton

from .index import prepare_chunk_indices
from .op import exp, safe_exp
from .op import exp
from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper

BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
@@ -112,10 +112,11 @@ def chunk_fwd_kernel_o(
p_g = tl.make_block_ptr(g, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, ))
b_g = tl.load(p_g, boundary_check=(0, ))
b_o = b_o * exp(b_g)[:, None]
b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :])
b_A = b_A * exp(b_g[:, None] - b_g[None, :])

o_i = tl.arange(0, BT)
m_A = o_i[:, None] >= o_i[None, :]
o_t = i_t * BT + tl.arange(0, BT)
m_t = o_t < T
m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
b_A = tl.where(m_A, b_A, 0)

p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
10 changes: 6 additions & 4 deletions vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py
@@ -14,7 +14,7 @@
from vllm.triton_utils import tl, triton

from .index import prepare_chunk_indices
from .op import safe_exp
from .op import exp


@triton.heuristics({
@@ -56,7 +56,8 @@ def chunk_scaled_dot_kkt_fwd_kernel(
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
o_t = tl.arange(0, BT)
o_t = i_t * BT + tl.arange(0, BT)
m_t = o_t < T

p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T, ), (H, ),
(i_t * BT, ), (BT, ), (0, ))
@@ -76,9 +77,10 @@ def chunk_scaled_dot_kkt_fwd_kernel(
(i_t * BT, ), (BT, ), (0, ))
b_g = tl.load(p_g, boundary_check=(0, ))
b_g_diff = b_g[:, None] - b_g[None, :]
b_A = b_A * safe_exp(b_g_diff)
b_A = b_A * exp(b_g_diff)

b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
b_A = tl.where(m_A, b_A, 0)
p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1),
(i_t * BT, 0), (BT, BT), (1, 0))
tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/fla/ops/fused_recurrent.py
@@ -116,8 +116,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
b_g = tl.load(p_g).to(tl.float32)

if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
b_q = b_q * scale
# [BK, BV]
b_h *= exp(b_g)
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/fla/ops/l2norm.py
@@ -78,7 +78,7 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
row_idx = xoffset + tl.arange(0, MBLOCK)[:, None]
xmask = row_idx < M
rindex = tl.arange(0, N)[None, :]
xs = tl.load(X + (rindex + N * row_idx), None).to(tl.float32)
xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32)
square = tl.broadcast_to(xs * xs, [MBLOCK, N])
square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None]
rsqrt = tl.rsqrt(square_sum + eps)
5 changes: 0 additions & 5 deletions vllm/model_executor/layers/fla/ops/op.py
@@ -28,11 +28,6 @@ def div_normal(x, y):
log2 = tl.log2


@triton.jit
def safe_exp(x):
return exp(tl.where(x <= 0, x, float('-inf')))


if not hasattr(tl, 'gather'):

@triton.jit
37 changes: 37 additions & 0 deletions vllm/model_executor/layers/mamba/mamba_utils.py
@@ -70,6 +70,15 @@ def short_conv_state_dtype(
model_dtype)
return (conv_state_dtype, )

@classmethod
def gated_delta_net_state_dtype(
cls,
model_dtype: Union[ModelDType, torch.dtype],
mamba_cache_dtype: MambaDType,
) -> tuple[torch.dtype, torch.dtype]:
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
return (state_dtype, state_dtype)
Comment on lines +79 to +80
Member

I don't know if you need/want it, but there is also the mamba_ssm_cache_dtype parameter, which gives the option to set the dtype of the temporal state separately from that of the conv state. Some models need this, but if you want to keep them the same, that's also fine.
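
A minimal sketch of what that could look like (hypothetical extra mamba_ssm_cache_dtype keyword, not part of this PR), keeping the current behaviour when the new argument is left at "auto":

@classmethod
def gated_delta_net_state_dtype(
    cls,
    model_dtype: Union[ModelDType, torch.dtype],
    mamba_cache_dtype: MambaDType,
    mamba_ssm_cache_dtype: MambaDType = "auto",
) -> tuple[torch.dtype, torch.dtype]:
    # Conv state dtype resolved exactly as in the PR.
    conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype,
                                                model_dtype)
    # Hypothetical knob: "auto" falls back to the conv state dtype,
    # so the default still returns the same dtype twice.
    if mamba_ssm_cache_dtype == "auto":
        temporal_state_dtype = conv_state_dtype
    else:
        temporal_state_dtype = get_kv_cache_torch_dtype(
            mamba_ssm_cache_dtype, model_dtype)
    return (conv_state_dtype, temporal_state_dtype)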



class MambaStateShapeCalculator:

@@ -163,3 +172,31 @@ def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):

# for n_groups == 1, this is exactly tp_size - n_groups
return tp_size - ngroups

@classmethod
def gated_delta_net_state_shape(
cls,
tp_world_size: int,
num_k_heads: int,
num_v_heads: int,
head_k_dim: int,
head_v_dim: int,
conv_kernel_size: int,
num_spec: int = 0,
use_v1: bool = True,
):
conv_dim = (head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads)
conv_state_shape = (
divide(conv_dim, tp_world_size),
conv_kernel_size - 1 + num_spec,
)

# In V0, the conv_state shape was swapped during allocation in
# MambaCacheManager, but in V1 it needs to be determined here at the
# calculation level
if use_v1:
conv_state_shape = conv_state_shape[1], conv_state_shape[0]

temporal_state_shape = (divide(num_v_heads,
tp_world_size), head_k_dim, head_v_dim)
return conv_state_shape, temporal_state_shape