     AutoWeightsLoader,
     PPMissingLayer,
     WeightsMapper,
-    flatten_bn,
     make_empty_intermediate_tensors_factory,
     maybe_prefix,
 )
@@ -347,54 +346,37 @@ def _get_prompt_updates(
 
     def _get_mm_fields_config(
         self,
-        hf_inputs,
-        hf_processor_mm_kwargs,
-        num_image_patches: torch.Tensor = None,
-    ):
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
         # HF Processors always return a mask but vLLM doesn't need it
         hf_inputs.pop("attention_mask", None)
+        num_image_patches = hf_inputs.get("num_image_patches")
         mm_fields = {
             key: MultiModalFieldConfig.flat_from_sizes("image", num_image_patches)
             for key in hf_inputs
         }
         mm_fields["image_embeds"] = MultiModalFieldConfig.flat_from_sizes(
             "image", num_image_patches
         )
+
+        # Keep these as batched, as they always have batch size as first dim
+        mm_fields["image_grid_thw"] = MultiModalFieldConfig.batched("image")
+        mm_fields["video_grid_thw"] = MultiModalFieldConfig.batched("image")
         mm_fields["num_image_patches"] = MultiModalFieldConfig.batched("image")
         return mm_fields
 
-    def _apply_hf_processor_text_mm(
+    def _get_hf_mm_data(
         self,
-        prompt_text: str,
         mm_items: MultiModalDataItems,
-        hf_processor_mm_kwargs: Mapping[str, object],
-        tokenization_kwargs: Mapping[str, object],
-    ) -> tuple[list[int], BatchFeature, bool]:
+    ) -> tuple[Mapping[str, object], Mapping[str, object]]:
         """
-        Apply the HF processor on the prompt text and multi-modal data
-        together.
-
-        In addition, return whether prompt replacements have been applied.
+        In contrast to the base class, this method always adds
+        `return_mm_token_type_ids` to the processor data
         """
-        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
+        processor_data, passthrough_data = super()._get_hf_mm_data(mm_items)
         processor_data["return_mm_token_type_ids"] = True
-
-        processed_data = self._call_hf_processor(
-            prompt=prompt_text,
-            mm_data=processor_data,
-            mm_kwargs=hf_processor_mm_kwargs,
-            tok_kwargs=tokenization_kwargs,
-        )
-        processed_data.update(passthrough_data)
-
-        (prompt_ids,) = processed_data.pop("input_ids").tolist()
-        mm_token_type_ids = (
-            processed_data.pop("mm_token_type_ids")
-            if "mm_token_type_ids" in processed_data
-            else processed_data.pop("token_type_ids")
-        )  # for gemma3 only
-
-        return prompt_ids, processed_data, mm_token_type_ids
+        return processor_data, passthrough_data
 
     def apply(
         self,
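
# Editor's note: a standalone toy sketch (not part of the diff above) of what the
# flat_from_sizes vs. batched field configs in the new _get_mm_fields_config
# express. Per-patch tensors are flattened across images and can be split back
# using `num_image_patches`, while grid-shaped fields keep one row per image.
# All shapes and values below are made up for illustration.
import torch

num_image_patches = torch.tensor([3, 1, 2])   # patches contributed by each image
pixel_values = torch.randn(6, 3, 336, 336)    # 3 + 1 + 2 patches, patch-major

per_image = torch.split(pixel_values, num_image_patches.tolist())
assert [t.shape[0] for t in per_image] == [3, 1, 2]

# A "batched" field such as image_grid_thw already has one entry per image,
# so it is indexed directly instead of being sliced by patch counts.
image_grid_thw = torch.tensor([[1, 24, 24], [1, 12, 12], [1, 24, 12]])
assert image_grid_thw[1].tolist() == [1, 12, 12]
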
@@ -421,18 +403,28 @@ def apply(
             # into string
             prompt = hf_processor.decode(prompt)
 
-        (prompt_ids, processed_data, mm_token_type_ids) = (
-            self._apply_hf_processor_text_mm(
-                prompt_text=prompt,
-                mm_items=mm_items,
-                hf_processor_mm_kwargs=hf_processor_mm_kwargs,
-                tokenization_kwargs=tokenization_kwargs,
-            )
+        # Bypass cached processor and always apply to the full set of mm inputs
+        # NOTE: we can't just set caching=False because base class method
+        # transforms outputs to `MultiModalKwargs` which is not going to
+        # work for Transformers. We have a lot of logic tied to
+        # `mm_tokens_per_modality` below
+        prompt_ids, processed_data, _ = self._apply_hf_processor_text_mm(
+            prompt_text=prompt,
+            mm_items=mm_items,
+            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+            tokenization_kwargs=tokenization_kwargs,
         )
 
-        # HF processor will return `mm_token_type_ids` from which
-        # we can infer mm_placeholders. Until then hardcode to make code run
-        # Below tested on Llava. Prompts and `mm_token_type_ids` are always bs=1
+        # For gemma3 we check `token_type_ids` as the key
+        token_type_key = (
+            "mm_token_type_ids"
+            if "mm_token_type_ids" in processed_data
+            else "token_type_ids"
+        )
+        mm_token_type_ids = processed_data.pop(token_type_key)
+
+        # We can infer vLLM style placeholder from token type ids, if we split
+        # it for each input `mm_data`.
         mm_positions = torch.where(mm_token_type_ids == 1)[1]
         images = mm_items.get_items("image", ImageProcessorItems)
         multimodal_config = self.info.ctx.model_config.multimodal_config
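
# Editor's note: a toy illustration (assumed values, not the code in the hunk
# above) of how per-image placeholder ranges can be inferred from a
# batch-size-1 mm_token_type_ids mask once the image positions are split
# per image.
import torch

mm_token_type_ids = torch.tensor([[0, 0, 1, 1, 1, 0, 1, 1, 0]])  # bs=1
tokens_per_image = [3, 2]          # e.g. derived from the processor's outputs

mm_positions = torch.where(mm_token_type_ids == 1)[1]   # tensor([2, 3, 4, 6, 7])
chunks = torch.split(mm_positions, tokens_per_image)
ranges = [(int(chunk[0]), len(chunk)) for chunk in chunks]  # (offset, length)
assert ranges == [(2, 3), (6, 2)]
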
@@ -462,17 +454,12 @@ def apply(
         ]
         mm_placeholders = {"image": ranges}
 
-        num_image_patches = (
-            torch.tensor(mm_tokens_per_modality["num_image_patches"])
-            if "num_image_patches" in mm_tokens_per_modality
-            else None
+        processed_data["num_image_patches"] = torch.tensor(
+            mm_tokens_per_modality["num_image_patches"]
         )
-        processed_data["num_image_patches"] = num_image_patches
         mm_kwargs = MultiModalKwargsItems.from_hf_inputs(
             processed_data,
-            self._get_mm_fields_config(
-                processed_data, hf_processor_mm_kwargs, num_image_patches
-            ),
+            self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
         )
 
         # Use overrides if provided; fallback to data-dependent hashing.
@@ -531,8 +518,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.ignore_unexpected_suffixes.append(".bias")
 
         # Set correct attn and init on "meta" to delay allocating GPU tensors
-        # TODO: @raushan, use the public `model.set_attn_implementation()`
-        # method once its checks are fixed in Transformers.
         self.text_config._attn_implementation = "vllm"
         with init_on_device_without_buffers("meta"):
             self.model: PreTrainedModel = AutoModel.from_config(
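
# Editor's note: a minimal sketch of the "init on meta device" idea used in the
# hunk above. vLLM wraps this in its own init_on_device_without_buffers helper
# so that buffers are still materialized; the plain torch.device("meta") context
# shown here only illustrates the delayed-allocation behaviour.
import torch
from torch import nn

with torch.device("meta"):
    layer = nn.Linear(4096, 4096)    # parameters are created as meta tensors

assert layer.weight.is_meta          # nothing was actually allocated on CPU/GPU
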
@@ -844,17 +829,6 @@ def compute_logits(
         return logits
 
 
-def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor:
-    """Flatten until a list of tensors can be concatenated then do concat"""
-
-    def _can_concat(x: list[torch.Tensor]):
-        return len(set(map(lambda _x: _x.shape[1:], x))) == 1
-
-    if _can_concat(x):
-        return torch.concat(x)
-    return flatten_and_concat(flatten_bn(x))
-
-
 @MULTIMODAL_REGISTRY.register_processor(
     MultiModalProcessor,
     info=MultiModalProcessingInfo,
@@ -935,9 +909,6 @@ def get_multimodal_embeddings(self, **kwargs):
         vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)
 
         if isinstance(vision_embeddings, torch.Tensor):
-            if isinstance(num_image_patches, list):
-                num_image_patches = torch.cat(num_image_patches)
-
             if vision_embeddings.ndim == 2:
                 vision_embeddings = vision_embeddings.unsqueeze(0)
 