11# SPDX-License-Identifier: Apache-2.0
22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33from collections .abc import Iterable , Mapping , Sequence
4- from typing import Any , Literal , Optional , TypedDict , Union , cast
4+ from typing import Annotated , Any , Literal , Optional , Union , cast
55
66import numpy as np
77import torch
4141# yapf: enable
4242from vllm .multimodal .profiling import BaseDummyInputsBuilder
4343from vllm .sequence import IntermediateTensors
44+ from vllm .utils .tensor_schema import TensorSchema , TensorShape
4445
4546from .interfaces import (MultiModalEmbeddings , SupportsMultiModal ,
4647 SupportsTranscription )
5455TOKENS_PER_AUDIO = 188
5556
5657
57- class Gemma3nImagePixelInputs (TypedDict ):
58- pixel_values : torch .Tensor
59- """Shape: `(batch_size * num_images, num_channels, height, width)`"""
58+ class Gemma3nImagePixelInputs (TensorSchema ):
59+ """
60+ Dimensions:
61+ - bn: Batch size * number of images
62+ - c: Number of channels (3)
63+ - h: Height of each patch
64+ - w: Width of each patch
65+ """
66+ type : Literal ["pixel_values" ] = "pixel_values"
67+ pixel_values : Annotated [torch .Tensor , TensorShape ("bn" , 3 , "h" , "w" )]
6068
6169
62- class Gemma3nAudioInputs (TypedDict ):
63- input_features : Union [torch .Tensor , list [torch .Tensor ]]
64- input_features_padded : torch .Tensor
65- """Shape: `(batch_size * num_audio, seq_length, num_features)`"""
66- input_features_mask : torch .Tensor
67- """Shape: `(batch_size * num_audio, seq_length)`"""
70+ class Gemma3nAudioInputs (TensorSchema ):
71+ """
72+ Dimensions:
73+ - bn: Batch size * number of audios
74+ - s: seq_length
75+ - f: num_features
76+ """
77+ type : Literal ["audio" ] = "audio"
78+ input_features_padded : Annotated [torch .Tensor , TensorShape ("bn" , "s" , "f" )]
79+ input_features_mask : Annotated [torch .Tensor , TensorShape ("bn" , "s" )]
6880
6981
7082Gemma3nImageInputs = Gemma3nImagePixelInputs
@@ -212,9 +224,9 @@ def _get_mm_fields_config(
212224
213225 return dict (
214226 pixel_values = MultiModalFieldConfig .batched ("image" ),
215- input_features = MultiModalFieldConfig .batched ("audio" ),
216227 input_features_padded = MultiModalFieldConfig .batched ("audio" ),
217- input_features_mask = MultiModalFieldConfig .batched ("audio" ))
228+ input_features_mask = MultiModalFieldConfig .batched ("audio" ),
229+ )
218230
219231 def _get_prompt_updates (
220232 self ,
@@ -422,6 +434,7 @@ def forward(
422434 dummy_inputs = Gemma3nDummyInputsBuilder )
423435class Gemma3nForConditionalGeneration (nn .Module , SupportsMultiModal ,
424436 SupportsTranscription ):
437+ merge_by_field_config = True
425438 supported_languages = ISO639_1_SUPPORTED_LANGS
426439
427440 packed_modules_mapping = {
@@ -482,14 +495,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
482495 device = self .language_model .model .embed_tokens .weight .device ,
483496 dtype = self .language_model .model .embed_tokens .weight .dtype )
484497
485- @property
486- def dtype (self ):
487- return next (self .parameters ()).dtype
488-
489- def _validate_pixel_values (self , data : torch .Tensor ) -> torch .Tensor :
490- # TODO check if there are any
491- return data
492-
493498 def _parse_and_validate_image_input (
494499 self , ** kwargs : object ) -> Optional [Gemma3nImageInputs ]:
495500 pixel_values = kwargs .pop ("pixel_values" , None )
@@ -499,34 +504,22 @@ def _parse_and_validate_image_input(
499504 if pixel_values is None :
500505 return None
501506
502- if not isinstance (pixel_values , (torch .Tensor , list )):
503- raise ValueError ("Incorrect type of pixel values. "
504- f"Got type: { type (pixel_values )} " )
505-
506- pixel_values = flatten_bn (pixel_values , concat = True )
507- pixel_values = pixel_values .contiguous ()
508-
509- return Gemma3nImagePixelInputs (
510- pixel_values = self ._validate_pixel_values (pixel_values ), )
507+ return Gemma3nImagePixelInputs (pixel_values = pixel_values )
511508
512509 def _parse_and_validate_audio_input (
513510 self , ** kwargs : object ) -> Optional [Gemma3nAudioInputs ]:
514- input_features = kwargs .pop ("input_features" , None )
515- if input_features is None :
511+
512+ input_features_padded = kwargs .pop ("input_features_padded" , None )
513+ if input_features_padded is None :
516514 return None
517515
518516 input_features_mask = kwargs .pop ("input_features_mask" , None )
519517 if input_features_mask is None :
520518 return None
521519
522- input_features_padded = kwargs .pop ("input_features_padded" , None )
523- if input_features_padded is None :
524- return None
525-
526520 return Gemma3nAudioInputs (
527- input_features = input_features ,
528- input_features_mask = input_features_mask ,
529521 input_features_padded = input_features_padded ,
522+ input_features_mask = input_features_mask ,
530523 )
531524
532525 def _parse_and_validate_multimodal_inputs (self , ** kwargs : object ) -> dict :
@@ -539,7 +532,7 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
539532 ) and "image" not in mm_input_by_modality :
540533 mm_input_by_modality [
541534 "image" ] = self ._parse_and_validate_image_input (** kwargs )
542- if input_key == "input_features " \
535+ if input_key == "input_features_padded " \
543536 and "audio" not in mm_input_by_modality :
544537 mm_input_by_modality [
545538 "audio" ] = self ._parse_and_validate_audio_input (** kwargs )
0 commit comments