@@ -30,6 +30,7 @@
     IsHybrid,
     MultiModalEmbeddings,
     SupportsMultiModal,
+    SupportsMultiModalPruning,
 )
 from vllm.model_executor.models.internvl import (
     calculate_internvl_targets,
@@ -44,6 +45,10 @@
     maybe_prefix,
 )
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.evs import (
+    compute_retained_tokens_count,
+    compute_retention_mask,
+)
 from vllm.multimodal.inputs import (
     MultiModalDataDict,
     MultiModalFieldConfig,
@@ -62,13 +67,20 @@
     PromptReplacement,
     PromptUpdate,
     PromptUpdateDetails,
+    _seq2tokens,
 )
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.radio import RadioConfig
-from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizer import (
+    AnyTokenizer,
+    cached_tokenizer_from_config,
+    encode_tokens,
+)
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
+from .utils import _merge_multimodal_embeddings
+
 # Configure PIL to handle large images without warnings
 # This prevents DecompressionBombWarning for legitimate large images
 Image.MAX_IMAGE_PIXELS = None  # Disable the limit entirely
@@ -382,6 +394,7 @@ def __init__(
         max_dynamic_patch: Optional[int] = None,
         dynamic_image_size: Optional[bool] = None,
         video_token: Optional[str] = None,
+        video_pruning_rate: Optional[float] = None,
     ) -> None:
         super().__init__(
             config=config,
@@ -392,6 +405,7 @@ def __init__(
         )
         # add extra video token for video processing
         self.video_token = video_token
+        self.video_pruning_rate = video_pruning_rate
 
     @property
     def supports_video(self) -> bool:
@@ -446,12 +460,38 @@ def _preprocess_video(
             ),
         }
 
+        image_size: int = self.config.force_image_size
+        patch_size: int = self.config.patch_size
+        downsample_ratio = self.config.downsample_ratio
+        tokens_in_single_frame = int(
+            (image_size * image_size // patch_size**2) * (downsample_ratio**2)
+        )
+
        for pixel_values in pixel_values_lst_video:
-            num_patches = pixel_values.shape[0]
+            num_frames = pixel_values.shape[0]
+
+            if (
+                self.video_pruning_rate is not None
+                and self.video_pruning_rate > 0.0
+            ):
+                # Start of EVS-specific code
+                num_tokens = compute_retained_tokens_count(
+                    tokens_per_frame=tokens_in_single_frame,
+                    num_frames=num_frames,
+                    q=self.video_pruning_rate,
+                )
+
+                # We only need placeholders that won't actually be replaced;
+                # the total number of tokens just has to be correct, so we
+                # assign all tokens to the first frame.
+                tokens_per_frame = [num_tokens] + [0] * (num_frames - 1)
+
+                # End of EVS-specific code
+            else:
+                tokens_per_frame = [tokens_in_single_frame] * num_frames
+
+            video_repl = self.get_video_repl(tokens_per_frame, self.video_token)
 
-            video_repl = self.get_video_repl(
-                self.num_image_token, num_patches, self.video_token
-            )
             text = [t.replace("<video>", video_repl.full, 1) for t in text]
         return text, video_inputs
 
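As a quick sanity check of the per-frame token arithmetic introduced in this hunk, here is a standalone sketch; the config values are illustrative assumptions, not this model's actual defaults:

# Standalone sketch of the tokens_in_single_frame computation above.
# image_size, patch_size, and downsample_ratio are assumed values.
image_size = 512
patch_size = 16
downsample_ratio = 0.5

tokens_in_single_frame = int(
    (image_size * image_size // patch_size**2) * (downsample_ratio**2)
)
# (512 * 512 // 16**2) * 0.5**2 = 1024 * 0.25 -> 256 tokens per frame
assert tokens_in_single_frame == 256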
@@ -501,20 +541,40 @@ def get_image_repl(
 
         return PromptUpdateDetails.select_text(repl_full, IMG_CONTEXT)
 
+    @classmethod
     def get_video_repl(
-        self,
-        feature_size: int,
-        num_patches: Optional[int] = None,
+        cls,
+        tokens_per_frame: list[int],
         video_context_token: str = IMG_CONTEXT,
     ) -> PromptUpdateDetails[str]:
-        repl_features = video_context_token * self.num_image_token
-        repl_features_with_sep = IMG_START + repl_features + IMG_END
-        # num_patches is equal to num_frames
+        """
+        Build the prompt replacement for a video.
+        The returned replacement is not actually used to replace the
+        placeholder tokens - it only ensures that the correct number of
+        tokens is allocated.
+        The actual replacement is done in get_multimodal_embeddings of
+        NemotronH_Nano_VL_V2 (specifically in
+        _process_video_input -> _create_final_video_embeddings).
+        There, the final embeddings combine text embeddings for the
+        indicator tokens with video embeddings for the video tokens.
+        This single function handles all cases - non-EVS, EVS dummy, and
+        EVS real - differentiated via the tokens_per_frame parameter:
+        - non-EVS: the same constant value for every frame
+        - EVS dummy: how tokens are distributed between frames doesn't
+          matter, as long as the total is correct
+        - EVS real (called from get_real_video_repl_for_evs): a different value per frame
+        Args:
+            tokens_per_frame (list[int]): number of tokens per frame
+            video_context_token (str): the token to use for the video context
+        """
         repl_full = "".join(
-            [f"Frame{i + 1}: {repl_features_with_sep}" for i in range(num_patches)]
+            [
+                f"Frame{i + 1}: {IMG_START}{video_context_token * num_tokens}{IMG_END}"
+                for i, num_tokens in enumerate(tokens_per_frame)
+            ]
        )
 
-        return PromptUpdateDetails.select_text(repl_full, video_context_token)
+        return PromptUpdateDetails.from_seq(repl_full)
 
 
class BaseNanoNemotronVLProcessingInfo(BaseProcessingInfo):
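For intuition, with tokens_per_frame = [2, 0] (an EVS-style placeholder distribution) and the <img>/</img>/<IMG_CONTEXT> indicator strings used in this file, the ""-join above yields a replacement string of roughly this shape (modulo the exact whitespace and literal values of IMG_START/IMG_END defined elsewhere in the file):

Frame1: <img><IMG_CONTEXT><IMG_CONTEXT></img>Frame2: <img></img>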
@@ -605,6 +665,9 @@ def get_supported_mm_limits(self):
     def get_video_token(self) -> Optional[str]:
         return IMG_CONTEXT
 
+    def get_video_pruning_rate(self) -> Optional[float]:
+        return self.ctx.get_mm_config().video_pruning_rate
+
     def get_num_frames_with_most_features(
         self,
         seq_len: int,
@@ -628,6 +691,7 @@ def get_hf_processor(self, **kwargs: object) -> NanoNemotronVLProcessor:
             config=self.get_hf_config(),
             tokenizer=self.get_tokenizer(),
             video_token=self.get_video_token(),
+            video_pruning_rate=self.get_video_pruning_rate(),
             **kwargs,
         )
 
@@ -805,8 +869,26 @@ def get_video_replacement_internvl(item_idx: int):
             if num_patches is not None:
                 assert isinstance(num_patches, int)
 
+            video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate
+            if video_pruning_rate is not None and video_pruning_rate > 0.0:
+                # Start of EVS-specific code
+                num_tokens = compute_retained_tokens_count(
+                    tokens_per_frame=feature_size,
+                    num_frames=num_patches,
+                    q=video_pruning_rate,
+                )
+                # We only need placeholders that won't actually be replaced;
+                # the total number of tokens just has to be correct, so we
+                # assign all tokens to the first frame.
+                tokens_per_frame = [num_tokens] + [0] * (num_patches - 1)
+
+                # End of EVS-specific code
+            else:
+                tokens_per_frame = [feature_size] * num_patches
+
             return hf_processor.get_video_repl(
-                feature_size, num_patches, video_context_token=hf_processor.video_token
+                tokens_per_frame,
+                video_context_token=hf_processor.video_token,
             )
 
         if self.info.supports_video:
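The exact rounding behavior of compute_retained_tokens_count lives in vllm.multimodal.evs; as a mental model only (a hypothetical stand-in, assuming q is the fraction of video tokens to prune), it behaves roughly like:

def retained_tokens_count_sketch(
    tokens_per_frame: int, num_frames: int, q: float
) -> int:
    # Hypothetical approximation: retain a (1 - q) fraction of all
    # video tokens; the real helper may round differently.
    return int(tokens_per_frame * num_frames * (1.0 - q))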
@@ -901,7 +983,9 @@ def get_dummy_mm_data(
     info=NanoNemotronVLProcessingInfo,
     dummy_inputs=NanoNemotronVLDummyInputsBuilder,
 )
-class NemotronH_Nano_VL_V2(nn.Module, HasInnerState, IsHybrid, SupportsMultiModal):
+class NemotronH_Nano_VL_V2(
+    nn.Module, HasInnerState, IsHybrid, SupportsMultiModal, SupportsMultiModalPruning
+):
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
         if modality.startswith("image"):
@@ -913,7 +997,7 @@ def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
-
+        multimodal_config = vllm_config.model_config.multimodal_config
         image_size = config.force_image_size
         patch_size = config.patch_size
         self.patch_size = patch_size
@@ -924,7 +1008,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.downsample_ratio = config.downsample_ratio
         self.ps_version = config.ps_version
         self.image_tag_type = config.image_tag_type
-
+        self.video_pruning_rate = multimodal_config.video_pruning_rate
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
             hf_config=config.text_config,
@@ -957,6 +1041,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.img_context_token_id = None
         self.video_context_token_id = None
         self.config = config
+        self.model_config = vllm_config.model_config
 
     def pixel_shuffle(self, x, scale_factor=0.5):
         n, w, h, c = x.size()
@@ -1049,7 +1134,7 @@ def _parse_and_validate_image_input(
 
     def _process_image_input(
         self, image_input: NanoNemotronVLImageInputs
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, ...]:
         if image_input["type"] == "image_embeds":
             return image_input["data"]
 
@@ -1071,6 +1156,109 @@ def _process_image_input(
         ]
         return image_embeds.split(image_feature_sizes)
 
+    def _process_video_input(
+        self, video_input: NanoNemotronVLVideoPixelInputs
+    ) -> tuple[torch.Tensor, ...]:
+        """Process the video input and create the final embeddings, combining
+        video content with indicator tokens."""
+        # Get video embeddings using the same processing as images
+        video_embeddings = self._process_image_input(video_input)
+
+        final_video_embeddings: tuple[torch.Tensor, ...] = ()
+
+        image_rows = image_cols = self.config.force_image_size
+        downsample_ratio = self.config.downsample_ratio
+        patch_size = self.config.patch_size
+        rows = int(image_rows * downsample_ratio // patch_size)
+        cols = int(image_cols * downsample_ratio // patch_size)
+        video_pruning_rate = self.video_pruning_rate
+
+        # Calculate the video feature dimensions (the number of frames and
+        # their feature size, i.e. tokens per frame).
+        # TODO: Maybe this can be optimized to avoid the loop?
+        for i, single_video_embeddings in enumerate(video_embeddings):
+            num_frames = video_input["num_patches"][i].item()
+            assert single_video_embeddings.shape[0] % num_frames == 0
+
+            if video_pruning_rate is not None and video_pruning_rate > 0.0:
+                # Start of EVS-specific code
+                retention_mask = compute_retention_mask(
+                    single_video_embeddings,
+                    video_size_thw=(num_frames, rows, cols),
+                    spatial_merge_size=1,
+                    q=video_pruning_rate,
+                )
+
+                # Apply the retention mask
+                single_video_embeddings = single_video_embeddings[retention_mask]
+
+                # Calculate the actual number of retained tokens per frame
+                retention_mask_thw = retention_mask.reshape(num_frames, rows, cols)
+                num_tokens_per_frame = (
+                    retention_mask_thw.sum(dim=(1, 2)).long().tolist()
+                )
+                # End of EVS-specific code
+            else:
+                feature_size = single_video_embeddings.shape[0] // num_frames
+                num_tokens_per_frame = [feature_size] * num_frames
+
+            final_video_embeddings += (
+                self._create_final_video_embeddings(
+                    single_video_embeddings,
+                    num_tokens_per_frame,
+                ),
+            )
+
+        return final_video_embeddings
+
+    def _create_final_video_embeddings(
+        self,
+        video_embeddings: torch.Tensor,
+        num_tokens_per_frame: list[int],
+    ) -> torch.Tensor:
+        """Create final embeddings that combine video embeddings with
+        the text embeddings of indicator tokens.
+
+        These final embeddings contain:
+        - actual video embeddings in the positions corresponding to video content
+        - text embeddings for the indicator tokens (<img>, </img>, and the
+          frame-separation text) in their respective positions
+
+        These embeddings will replace the placeholder embeddings to create
+        the input_embeds for the LLM.
+        """
+        device = video_embeddings.device
+
+        # Generate the video replacement text and convert it to token IDs
+        video_repl_text = NanoNemotronVLProcessor.get_video_repl(
+            num_tokens_per_frame,
+            IMG_CONTEXT,
+        ).full
+
+        tokenizer = cached_tokenizer_from_config(self.model_config)
+        repl_token_ids = torch.tensor(
+            _seq2tokens(tokenizer, video_repl_text), device=device
+        )
+
+        # Get the embedding token IDs for the image context token
+        embed_token_ids = torch.tensor(
+            encode_tokens(tokenizer, IMG_CONTEXT), device=device
+        )
+
+        # Create a mask for the video embedding positions
+        is_video_embed = torch.isin(repl_token_ids, embed_token_ids)
+
+        # Create the final video embeddings, merging text embeddings for
+        # indicator tokens with video embeddings
+        text_embeddings = self.get_language_model().get_input_embeddings(repl_token_ids)
+        final_video_embeddings = _merge_multimodal_embeddings(
+            inputs_embeds=text_embeddings,
+            multimodal_embeddings=video_embeddings,
+            is_multimodal=is_video_embed,
+        )
+
+        return final_video_embeddings
+
     def _parse_and_validate_video_input(
         self, **kwargs: object
     ) -> Optional[NanoNemotronVLVideoPixelInputs]:
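The merge at the end of _create_final_video_embeddings can be pictured as a masked scatter. Below is a self-contained sketch of that idea (merge_sketch is a hypothetical stand-in, not the actual _merge_multimodal_embeddings helper from .utils):

import torch

def merge_sketch(
    inputs_embeds: torch.Tensor,          # (seq_len, hidden) text embeddings
    multimodal_embeddings: torch.Tensor,  # (num_video_tokens, hidden)
    is_multimodal: torch.Tensor,          # (seq_len,) bool mask
) -> torch.Tensor:
    # Fill the masked positions, in order, with the video embeddings;
    # is_multimodal.sum() must equal multimodal_embeddings.shape[0].
    out = inputs_embeds.clone()
    out[is_multimodal] = multimodal_embeddings
    return out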
@@ -1152,7 +1340,7 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
                 multimodal_embeddings += vision_embeddings
             if modality == "videos":
                 video_input = modalities["videos"]
-                video_embeddings = self._process_image_input(video_input)
+                video_embeddings = self._process_video_input(video_input)
                 multimodal_embeddings += video_embeddings
 
         return multimodal_embeddings
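Finally, the per-frame token counting used in _process_video_input reduces to a reshape plus a spatial sum. A standalone illustration with a made-up mask (2 frames of 2x2 tokens each):

import torch

retention_mask = torch.tensor([1, 0, 1, 1, 1, 1, 0, 0], dtype=torch.bool)
num_frames, rows, cols = 2, 2, 2
num_tokens_per_frame = (
    retention_mask.reshape(num_frames, rows, cols).sum(dim=(1, 2)).long().tolist()
)
assert num_tokens_per_frame == [3, 2]  # frame 1 keeps 3 tokens, frame 2 keeps 2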