@@ -1125,14 +1125,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
11251125 self .config = config
11261126 self .multimodal_config = multimodal_config
11271127 self .use_data_parallel = multimodal_config .mm_encoder_tp_mode == "data"
1128-
1129- self .visual = Qwen3_VisionTransformer (
1130- config .vision_config ,
1131- norm_eps = getattr (config , "rms_norm_eps" , 1e-6 ),
1132- quant_config = quant_config ,
1133- prefix = maybe_prefix (prefix , "visual" ),
1134- use_data_parallel = self .use_data_parallel ,
1135- )
1128+ if not multimodal_config .get_limit_per_prompt ("image" ) and \
1129+ not multimodal_config .get_limit_per_prompt ("video" ):
1130+ self .visual = None
1131+ else :
1132+ self .visual = Qwen3_VisionTransformer (
1133+ config .vision_config ,
1134+ norm_eps = getattr (config , "rms_norm_eps" , 1e-6 ),
1135+ quant_config = quant_config ,
1136+ prefix = maybe_prefix (prefix , "visual" ),
1137+ use_data_parallel = self .use_data_parallel ,
1138+ )
11361139
11371140 self .language_model = Qwen3LLMForCausalLM (vllm_config = vllm_config ,
11381141 prefix = maybe_prefix (
@@ -1148,11 +1151,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
11481151 config .vision_config .deepstack_visual_indexes
11491152 ) if self .use_deepstack else 0
11501153 # register buffer for deepstack
1151- self .deepstack_input_embeds = [
1152- torch .zeros (vllm_config .scheduler_config .max_num_batched_tokens ,
1153- config .text_config .hidden_size )
1154- for _ in range (self .deepstack_num_level )
1155- ] if self .use_deepstack else None
1154+ if self .use_deepstack and self .visual is not None :
1155+ self .deepstack_input_embeds = [
1156+ torch .zeros (
1157+ vllm_config .scheduler_config .max_num_batched_tokens ,
1158+ config .text_config .hidden_size )
1159+ for _ in range (self .deepstack_num_level )
1160+ ]
1161+ else :
1162+ self .deepstack_input_embeds = None
11561163 self .visual_dim = config .vision_config .out_hidden_size
11571164 self .multiscale_dim = self .visual_dim * self .deepstack_num_level
11581165
@@ -1526,7 +1533,11 @@ def compute_logits(
15261533
15271534 def load_weights (self , weights : Iterable [tuple [str ,
15281535 torch .Tensor ]]) -> set [str ]:
1529- loader = AutoWeightsLoader (self )
1536+
1537+ skip_prefixes = []
1538+ if self .visual is None :
1539+ skip_prefixes .extend (["visual." ])
1540+ loader = AutoWeightsLoader (self , skip_prefixes = skip_prefixes )
15301541 return loader .load_weights (weights , mapper = self .hf_to_vllm_mapper )
15311542
15321543 def get_mm_mapping (self ) -> MultiModelKeys :
0 commit comments