vllm/model_executor/models/qwen3_vl.py (87 additions & 62 deletions)
@@ -55,6 +55,7 @@
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions
 from vllm.distributed import get_pp_group
+from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
 from vllm.model_executor.layers.linear import (
@@ -124,6 +125,7 @@
 _MAX_FRAMES_PER_VIDEO = 24576


+@support_torch_compile(dynamic_arg_dims={"x": 0})
 class Qwen3_VisionPatchEmbed(nn.Module):
     def __init__(
         self,
@@ -187,6 +189,10 @@ def forward(self, x: torch.Tensor):
         return mlp_output


+@support_torch_compile(
+    dynamic_arg_dims={"x": 0, "cu_seqlens": 0, "rotary_pos_emb": 0, "seqlens": 0},
+    mark_unbacked_dims={"seqlens": 0},
+)
 class Qwen3_VisionBlock(nn.Module):
     def __init__(
         self,
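Note: `support_torch_compile(dynamic_arg_dims=...)` marks dimension 0 of each listed argument as dynamic, so one compiled graph serves any number of vision tokens instead of recompiling per image resolution; `mark_unbacked_dims={"seqlens": 0}` additionally keeps that dimension unbacked so the compiler will not specialize it to small example sizes such as 0 or 1. A minimal sketch of the same pattern on a toy module (the decorator and its keywords are vLLM's; the module, its name, and its shapes are hypothetical):

    import torch
    from torch import nn
    from vllm.compilation.decorators import support_torch_compile

    @support_torch_compile(dynamic_arg_dims={"x": 0})
    class ToyProjector(nn.Module):
        # Dim 0 of `x` (the token count) is dynamic: a single compiled
        # graph serves inputs of any length without recompilation.
        def __init__(self, hidden: int = 64):
            super().__init__()
            self.proj = nn.Linear(hidden, hidden)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.proj(x)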
@@ -246,6 +252,7 @@ def forward(
         return x


+@support_torch_compile(dynamic_arg_dims={"x": 0})
 class Qwen3_VisionPatchMerger(nn.Module):
     def __init__(
         self,
@@ -275,6 +282,7 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.linear_fc1",
             disable_tp=use_data_parallel,
+            return_bias=False,
         )
         self.act_fn = nn.GELU()
         self.linear_fc2 = RowParallelLinear(
@@ -284,6 +292,7 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.linear_fc2",
             disable_tp=use_data_parallel,
+            return_bias=False,
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -292,9 +301,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         else:
             x = self.norm(x).view(-1, self.hidden_size)

-        x_parallel, _ = self.linear_fc1(x)
+        x_parallel = self.linear_fc1(x)
         x_parallel = self.act_fn(x_parallel)
-        out, _ = self.linear_fc2(x_parallel)
+        out = self.linear_fc2(x_parallel)
         return out

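Note: with `return_bias=False` passed to both projection layers above, the parallel linear layers return a plain tensor instead of the default `(output, bias)` tuple, which is why the tuple unpacking in `forward` is dropped here. A before/after sketch of the call site:

    # default: vLLM's parallel linear layers return (output, bias)
    x_parallel, _ = self.linear_fc1(x)

    # with return_bias=False: the layer returns only the output tensor
    x_parallel = self.linear_fc1(x)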
@@ -325,45 +334,52 @@ def __init__(
         self.out_hidden_size = vision_config.out_hidden_size * (
             1 + len(self.deepstack_visual_indexes)
         )

-        self.patch_embed = Qwen3_VisionPatchEmbed(
-            patch_size=self.patch_size,
-            temporal_patch_size=self.temporal_patch_size,
-            in_channels=vision_config.in_channels,
-            hidden_size=self.hidden_size,
-        )
+        # TODO[@lucaskabela]: Investigate fixing this usage
+        # see https://github.com/vllm-project/vllm/issues/27044
+        # DO NOT MOVE THIS IMPORT
+        from vllm.compilation.backends import set_model_tag
+
+        with set_model_tag("Qwen3_VisionPatchEmbed"):
+            self.patch_embed = Qwen3_VisionPatchEmbed(
+                patch_size=self.patch_size,
+                temporal_patch_size=self.temporal_patch_size,
+                in_channels=vision_config.in_channels,
+                hidden_size=self.hidden_size,
+            )

         self.pos_embed = nn.Embedding(self.num_position_embeddings, self.hidden_size)

         norm_layer = partial(nn.LayerNorm, eps=norm_eps)
         head_dim = self.hidden_size // self.num_heads
         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

-        self.merger = Qwen3_VisionPatchMerger(
-            d_model=vision_config.out_hidden_size,
-            context_dim=self.hidden_size,
-            norm_layer=norm_layer,
-            spatial_merge_size=self.spatial_merge_size,
-            quant_config=quant_config,
-            prefix=f"{prefix}.merger",
-            use_data_parallel=use_data_parallel,
-        )
+        with set_model_tag("Qwen3_VisionPatchMerger"):
+            self.merger = Qwen3_VisionPatchMerger(
+                d_model=vision_config.out_hidden_size,
+                context_dim=self.hidden_size,
+                norm_layer=norm_layer,
+                spatial_merge_size=self.spatial_merge_size,
+                quant_config=quant_config,
+                prefix=f"{prefix}.merger",
+                use_data_parallel=use_data_parallel,
+            )

-        self.deepstack_merger_list = nn.ModuleList(
-            [
-                Qwen3_VisionPatchMerger(
-                    d_model=vision_config.out_hidden_size,
-                    context_dim=self.hidden_size,
-                    spatial_merge_size=self.spatial_merge_size,
-                    use_postshuffle_norm=True,
-                    norm_layer=norm_layer,
-                    quant_config=quant_config,
-                    prefix=f"{prefix}.deepstack_merger_list.{layer_idx}",
-                    use_data_parallel=use_data_parallel,
-                )
-                for layer_idx in range(len(self.deepstack_visual_indexes))
-            ]
-        )
+        with set_model_tag("Qwen3_VisionPatchMerger_postshuffle_norm"):
+            self.deepstack_merger_list = nn.ModuleList(
+                [
+                    Qwen3_VisionPatchMerger(
+                        d_model=vision_config.out_hidden_size,
+                        context_dim=self.hidden_size,
+                        spatial_merge_size=self.spatial_merge_size,
+                        use_postshuffle_norm=True,
+                        norm_layer=norm_layer,
+                        quant_config=quant_config,
+                        prefix=f"{prefix}.deepstack_merger_list.{layer_idx}",
+                        use_data_parallel=use_data_parallel,
+                    )
+                    for layer_idx in range(len(self.deepstack_visual_indexes))
+                ]
+            )

         self.attn_backend = get_vit_attn_backend(
             head_size=head_dim,
@@ -388,23 +404,24 @@ def __init__(
             raise RuntimeError(
                 f"Qwen3-VL does not support {self.attn_backend} backend now."
             )
-        self.blocks = nn.ModuleList(
-            [
-                Qwen3_VisionBlock(
-                    dim=self.hidden_size,
-                    num_heads=self.num_heads,
-                    mlp_hidden_dim=vision_config.intermediate_size,
-                    act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
-                    norm_layer=norm_layer,
-                    quant_config=quant_config,
-                    prefix=f"{prefix}.blocks.{layer_idx}",
-                    use_data_parallel=use_data_parallel,
-                    attn_backend=self.attn_backend,
-                    use_upstream_fa=use_upstream_fa,
-                )
-                for layer_idx in range(vision_config.depth)
-            ]
-        )
+        with set_model_tag("Qwen3_VisionBlock"):
+            self.blocks = nn.ModuleList(
+                [
+                    Qwen3_VisionBlock(
+                        dim=self.hidden_size,
+                        num_heads=self.num_heads,
+                        mlp_hidden_dim=vision_config.intermediate_size,
+                        act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
+                        norm_layer=norm_layer,
+                        quant_config=quant_config,
+                        prefix=f"{prefix}.blocks.{layer_idx}",
+                        use_data_parallel=use_data_parallel,
+                        attn_backend=self.attn_backend,
+                        use_upstream_fa=use_upstream_fa,
+                    )
+                    for layer_idx in range(vision_config.depth)
+                ]
+            )

     @property
     def dtype(self) -> torch.dtype:
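Note: each `set_model_tag(...)` context gives the modules constructed inside it a distinct tag, so their torch.compile artifacts are cached and looked up under separate keys rather than colliding across submodules with different graphs; the TODO above tracks removing the local import (issue #27044). A minimal sketch of the pattern (tag names and stand-in modules are hypothetical):

    import torch.nn as nn
    from vllm.compilation.backends import set_model_tag

    # Modules built under different tags get separate compile caches.
    with set_model_tag("vision_merger"):
        merger = nn.Linear(1024, 4096)  # stand-in for a real merger
    with set_model_tag("vision_block"):
        block = nn.Linear(4096, 4096)   # stand-in for a real block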
@@ -1217,6 +1234,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
         multimodal_config = vllm_config.model_config.multimodal_config

         self.config = config
+        self.vllm_config = vllm_config
         self.multimodal_config = multimodal_config
         self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
         if not multimodal_config.get_limit_per_prompt(
@@ -1362,12 +1380,13 @@ def _process_image_input(
             image_embeds = image_input["image_embeds"].type(self.visual.dtype)
         else:
             pixel_values = image_input["pixel_values"].type(self.visual.dtype)
-            if self.use_data_parallel:
-                return run_dp_sharded_mrope_vision_model(
-                    self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
-                )
-            else:
-                image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)
+            with set_forward_context(None, self.vllm_config):
+                if self.use_data_parallel:
+                    return run_dp_sharded_mrope_vision_model(
+                        self.visual, pixel_values, grid_thw_list, rope_type="rope_3d"
+                    )
+                else:
+                    image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)

         # Split concatenated embeddings for each image item.
         # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
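Note: wrapping the encoder call in `set_forward_context(None, self.vllm_config)` provides the active forward context that the now-compiled vision modules expect even though they run outside the decoder's forward pass; `None` stands in for attention metadata, which the ViT's own attention path does not use. A hedged sketch of the calling pattern (variable names are illustrative):

    from vllm.forward_context import set_forward_context

    # Compiled multimodal encoders need an active forward context;
    # no decoder attention metadata applies here, hence None.
    with set_forward_context(None, self.vllm_config):
        image_embeds = self.visual(pixel_values, grid_thw=grid_thw_list)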
@@ -1391,12 +1410,18 @@ def _process_video_input(
             pixel_values_videos = video_input["pixel_values_videos"].type(
                 self.visual.dtype
             )
-            if self.use_data_parallel:
-                return run_dp_sharded_mrope_vision_model(
-                    self.visual, pixel_values_videos, grid_thw_list, rope_type="rope_3d"
-                )
-            else:
-                video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw_list)
+            with set_forward_context(None, self.vllm_config):
+                if self.use_data_parallel:
+                    return run_dp_sharded_mrope_vision_model(
+                        self.visual,
+                        pixel_values_videos,
+                        grid_thw_list,
+                        rope_type="rope_3d",
+                    )
+                else:
+                    video_embeds = self.visual(
+                        pixel_values_videos, grid_thw=grid_thw_list
+                    )

         # Split concatenated embeddings for each video item.
         # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
vllm/model_executor/models/qwen3_vl_moe.py (1 addition & 0 deletions)
@@ -365,6 +365,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         multimodal_config = vllm_config.model_config.multimodal_config

         self.config = config
+        self.vllm_config = vllm_config
         self.multimodal_config = multimodal_config
         self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
