1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- from typing import Any , Dict
16+ import logging
17+ from typing import Any , Dict , Tuple
1718
1819import torch
1920from transformers import AutoConfig
2223from vllm .utils import get_distributed_init_method , get_ip , get_open_port
2324from vllm .worker .worker import Worker
2425
26+ logger = logging .getLogger (__name__ )
27+
2528
2629def load_vision_model (model_id : str ) -> torch .nn .Module :
2730 """
@@ -44,13 +47,24 @@ def load_vision_model(model_id: str) -> torch.nn.Module:
4447 return worker .model_runner .model
4548
4649
def get_vision_embeddings_info(
    model_id: str, num_patches: int
) -> Tuple[Tuple[int, int, int], torch.dtype]:
    """Calculate vision embeddings size and dtype from the model config.

    Args:
        model_id: HuggingFace model identifier passed to
            ``AutoConfig.from_pretrained`` (with ``trust_remote_code=True``).
        num_patches: Number of image patches; must be positive.

    Returns:
        A tuple of ``((batch_size, num_patches, hidden_dim), dtype)`` where
        ``batch_size`` is fixed at 1 and ``hidden_dim`` falls back to 4096
        when the config does not declare ``hidden_size``.

    Raises:
        ValueError: If ``num_patches`` is not positive, or if the model
            config lacks the required ``torch_dtype`` attribute.
    """
    config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
    # Validate with an explicit raise rather than `assert`: asserts are
    # stripped under `python -O`, which would silently drop this check.
    # This also matches the ValueError style used for `torch_dtype` below.
    if num_patches <= 0:
        raise ValueError("Number of patches must be positive")
    if not hasattr(config, "torch_dtype"):
        raise ValueError("Model config missing required 'torch_dtype' attribute")
    if not hasattr(config, "hidden_size"):
        logger.warning(
            "Model config missing required 'hidden_size' attribute, using 4096"
        )
        hidden_size = 4096
    else:
        hidden_size = config.hidden_size
    return (1, num_patches, hidden_size), config.torch_dtype
5468
5569
5670def construct_mm_data (
@@ -60,8 +74,8 @@ def construct_mm_data(
6074 if "Qwen2" in model :
6175 return {
6276 "image" : {
63- "image_embeds" : image_embeds .squeeze (0 ),
64- "image_grid_thw" : torch .tensor (encode_output .image_grid_thw ),
77+ "image_embeds" : image_embeds .squeeze (0 ). to ( torch . float16 ) ,
78+ "image_grid_thw" : torch .tensor (encode_output .image_grid_thw ). squeeze ( 0 ) ,
6579 }
6680 }
6781 elif "MiniCPM-V" in model :
0 commit comments