From d5540cd07005976a98427c98837d40837a6e9c16 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Tue, 30 Sep 2025 18:21:32 -0700 Subject: [PATCH 1/4] update Signed-off-by: Roger Wang --- vllm/model_executor/models/qwen3_vl.py | 37 +++++++++++++++++--------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index ce92557d6424..a54440ffd905 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1126,13 +1126,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" - self.visual = Qwen3_VisionTransformer( - config.vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=quant_config, - prefix=maybe_prefix(prefix, "visual"), - use_data_parallel=self.use_data_parallel, - ) + if multimodal_config.get_limit_per_prompt("image"): + self.visual = Qwen3_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + use_data_parallel=self.use_data_parallel, + ) + else: + self.visual = None self.language_model = Qwen3LLMForCausalLM(vllm_config=vllm_config, prefix=maybe_prefix( @@ -1148,11 +1151,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): config.vision_config.deepstack_visual_indexes ) if self.use_deepstack else 0 # register buffer for deepstack - self.deepstack_input_embeds = [ - torch.zeros(vllm_config.scheduler_config.max_num_batched_tokens, - config.text_config.hidden_size) - for _ in range(self.deepstack_num_level) - ] if self.use_deepstack else None + if self.use_deepstack and self.visual is not None: + self.deepstack_input_embeds = [ + torch.zeros( + vllm_config.scheduler_config.max_num_batched_tokens, + config.text_config.hidden_size) + for _ in range(self.deepstack_num_level) + ] + else: + self.deepstack_input_embeds = None self.visual_dim = config.vision_config.out_hidden_size self.multiscale_dim = self.visual_dim * self.deepstack_num_level @@ -1526,7 +1533,11 @@ def compute_logits( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self) + + skip_prefixes = [] + if self.visual is None: + skip_prefixes.extend(["visual."]) + loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes) return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) def get_mm_mapping(self) -> MultiModelKeys: From 45304614844c60ab572eda1accccd2a86ee2e7f6 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Tue, 30 Sep 2025 18:23:30 -0700 Subject: [PATCH 2/4] update Signed-off-by: Roger Wang --- vllm/model_executor/models/qwen3_vl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index a54440ffd905..e82aa717a0d5 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1125,8 +1125,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): self.config = config self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" - - if multimodal_config.get_limit_per_prompt("image"): + if not multimodal_config.get_limit_per_prompt("image") and \ + not multimodal_config.get_limit_per_prompt("video"): self.visual = Qwen3_VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), From e0d4d891eb52d6d1f15e95d255c521c3a49c3df6 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Tue, 30 Sep 2025 18:26:23 -0700 Subject: [PATCH 3/4] Update vllm/model_executor/models/qwen3_vl.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: Roger Wang --- vllm/model_executor/models/qwen3_vl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index e82aa717a0d5..00de89811cc7 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1127,6 +1127,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" if not multimodal_config.get_limit_per_prompt("image") and \ not multimodal_config.get_limit_per_prompt("video"): + self.visual = None + else: self.visual = Qwen3_VisionTransformer( config.vision_config, norm_eps=getattr(config, "rms_norm_eps", 1e-6), @@ -1134,8 +1136,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): prefix=maybe_prefix(prefix, "visual"), use_data_parallel=self.use_data_parallel, ) - else: - self.visual = None self.language_model = Qwen3LLMForCausalLM(vllm_config=vllm_config, prefix=maybe_prefix( From ad04a4eea7cd8b7b54cf7de92f7a95ce2613a326 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Tue, 30 Sep 2025 19:08:02 -0700 Subject: [PATCH 4/4] add to moe Signed-off-by: Roger Wang --- vllm/model_executor/models/qwen3_vl_moe.py | 32 ++++++++++++++-------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/models/qwen3_vl_moe.py b/vllm/model_executor/models/qwen3_vl_moe.py index 02cc5d6d66d1..1ed053eb2e96 100644 --- a/vllm/model_executor/models/qwen3_vl_moe.py +++ b/vllm/model_executor/models/qwen3_vl_moe.py @@ -319,13 +319,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" - self.visual = Qwen3_VisionTransformer( - config.vision_config, - norm_eps=getattr(config, "rms_norm_eps", 1e-6), - quant_config=quant_config, - prefix=maybe_prefix(prefix, "visual"), - use_data_parallel=self.use_data_parallel, - ) + if not multimodal_config.get_limit_per_prompt("image") and \ + not multimodal_config.get_limit_per_prompt("video"): + self.visual = None + else: + self.visual = Qwen3_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=quant_config, + prefix=maybe_prefix(prefix, "visual"), + use_data_parallel=self.use_data_parallel, + ) self.language_model = Qwen3MoeLLMForCausalLM(vllm_config=vllm_config, prefix=maybe_prefix( @@ -341,10 +345,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.vision_config.deepstack_visual_indexes ) if self.use_deepstack else 0 # register buffer for deepstack - self.deepstack_input_embeds = [ - torch.zeros(vllm_config.scheduler_config.max_num_batched_tokens, - config.text_config.hidden_size) - for _ in range(self.deepstack_num_level) - ] if self.use_deepstack else None + if self.use_deepstack and self.visual is not None: + self.deepstack_input_embeds = [ + torch.zeros( + vllm_config.scheduler_config.max_num_batched_tokens, + config.text_config.hidden_size) + for _ in range(self.deepstack_num_level) + ] + else: + self.deepstack_input_embeds = None self.visual_dim = config.vision_config.out_hidden_size self.multiscale_dim = self.visual_dim * self.deepstack_num_level