5555from vllm .multimodal import MULTIMODAL_REGISTRY
5656from vllm .multimodal .inputs import (ImageItem , ModalityData ,
5757 MultiModalFieldConfig , MultiModalKwargs ,
58- NestedTensors , VideoItem )
58+ VideoItem )
5959from vllm .multimodal .parse import (ImageSize , ModalityDataItems ,
6060 MultiModalDataItems , MultiModalDataParser )
6161from vllm .multimodal .processing import (BaseMultiModalProcessor ,
@@ -1233,7 +1233,7 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
12331233 return modalities
12341234
12351235 def get_multimodal_embeddings (
1236- self , ** kwargs ) -> Optional [List [ Tuple [ NestedTensors , str ] ]]:
1236+ self , ** kwargs ) -> Optional [tuple [ torch . Tensor , ... ]]:
12371237
12381238 modalities = self ._parse_and_validate_multimodal_inputs (** kwargs )
12391239 if not modalities :
@@ -1260,8 +1260,7 @@ def get_multimodal_embeddings(
12601260 def get_input_embeddings (
12611261 self ,
12621262 input_ids : torch .Tensor ,
1263- multimodal_embeddings : Optional [List [Tuple [NestedTensors ,
1264- str ]]] = None ,
1263+ multimodal_embeddings : Optional [tuple [torch .Tensor , ...]] = None ,
12651264 ) -> torch .Tensor :
12661265 inputs_embeds = self .language_model .get_input_embeddings (input_ids )
12671266 if multimodal_embeddings is not None :
@@ -1270,6 +1269,33 @@ def get_input_embeddings(
12701269 [self .config .image_token_id , self .config .video_token_id ])
12711270 return inputs_embeds
12721271
def get_input_embeddings_v0(
    self,
    input_ids: torch.Tensor,
    image_input: Optional[tuple[torch.Tensor, ...]] = None,
    video_input: Optional[tuple[torch.Tensor, ...]] = None,
) -> torch.Tensor:
    """Build input embeddings for the v0 code path.

    Starts from the embeddings returned by ``self.get_input_embeddings``
    for ``input_ids`` (called without multimodal embeddings), then
    overlays image embeddings and video embeddings, in that order, at
    the positions of their respective placeholder tokens.

    Args:
        input_ids: Token ids of the prompt.
        image_input: Parsed image input, or ``None`` when the request has
            no images. It is handed unchanged to
            ``self._process_image_input``.
            NOTE(review): the annotation says tuple of tensors, but the
            caller passes the result of
            ``_parse_and_validate_image_input`` -- confirm the real type.
        video_input: Parsed video input, or ``None`` when the request has
            no videos. It is handed unchanged to
            ``self._process_video_input`` (same typing caveat as above).

    Returns:
        The embeddings after the image and video merges; when both
        inputs are ``None`` this is simply the result of
        ``self.get_input_embeddings(input_ids)``.
    """
    merged = self.get_input_embeddings(input_ids)

    # Images first: fill the image placeholder token positions.
    if image_input is not None:
        merged = merge_multimodal_embeddings(
            input_ids,
            merged,
            self._process_image_input(image_input),
            placeholder_token_id=self.config.image_token_id,
        )

    # Then videos: fill the video placeholder token positions.
    if video_input is not None:
        merged = merge_multimodal_embeddings(
            input_ids,
            merged,
            self._process_video_input(video_input),
            placeholder_token_id=self.config.video_token_id,
        )

    return merged
1298+
12731299 def forward (
12741300 self ,
12751301 input_ids : torch .Tensor ,
@@ -1303,22 +1329,25 @@ def forward(
13031329 if intermediate_tensors is not None :
13041330 inputs_embeds = None
13051331
1306- # NOTE: In v1, inputs_embeds is always generated at model runner, this
1307- # condition is for v0 compatibility.
1332+ # NOTE: In v1, inputs_embeds is always generated by the model runner from
1333+ # `get_multimodal_embeddings` and `get_input_embeddings`; this
1334+ # condition is only for v0 compatibility.
13081335 elif inputs_embeds is None :
1309- multimodal_embeddings = self .get_multimodal_embeddings (** kwargs )
1310-
1311- # We need to check for usage of mrope here in case there is
1312- # multimodal data.
1313- # TODO (ywang96): move this to model runner in V1.
1314- if multimodal_embeddings is not None and uses_mrope (self .config ):
1315- assert positions .ndim == 2 and positions .size (0 ) == 3 , (
1316- "multimodal section rotary embedding requires "
1317- f"(3, seq_len) positions, but got { positions .size ()} " )
1318-
1319- inputs_embeds = self .get_input_embeddings (input_ids ,
1320- multimodal_embeddings )
1321- input_ids = None
1336+ image_input = self ._parse_and_validate_image_input (** kwargs )
1337+ video_input = self ._parse_and_validate_video_input (** kwargs )
1338+
1339+ if image_input is None and video_input is None :
1340+ inputs_embeds = None
1341+ else :
1342+ if uses_mrope (self .config ):
1343+ assert positions .ndim == 2 and positions .size (0 ) == 3 , (
1344+ "multimodal section rotary embedding requires "
1345+ f"(3, seq_len) positions, but got { positions .size ()} " )
1346+ inputs_embeds = self .get_input_embeddings_v0 (
1347+ input_ids ,
1348+ image_input = image_input ,
1349+ video_input = video_input )
1350+ input_ids = None
13221351
13231352 hidden_states = self .language_model .model (
13241353 input_ids = input_ids ,
0 commit comments