@@ -358,6 +358,7 @@ def _init_model_kwargs(self, num_tokens: int):
358358 if num_pooling_reqs == 0 :
359359 return model_kwargs
360360
361+ # Accessing pooling_metadata.pooling_params does nontrivial work,
361+ # so only do it after the num_pooling_reqs == 0 early return above.
361362 pooling_params = self .input_batch .pooling_metadata .pooling_params
362363
363364 assert num_pooling_reqs == num_reqs
@@ -465,7 +466,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
465466 for req_id in unscheduled_req_ids :
466467 self .input_batch .remove_request (req_id )
467468
468- req_ids_to_add : list [str ] = []
469+ reqs_to_add : list [CachedRequestState ] = []
469470 # Add new requests to the cached states.
470471 for new_req_data in scheduler_output .scheduled_new_reqs :
471472 req_id = new_req_data .req_id
@@ -480,14 +481,14 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
480481 generator = None
481482
482483 if pooling_params :
483- assert ( task : = pooling_params .task ) is not None , (
484- "You did not set `task` in the API" )
484+ task = pooling_params .task
485+ assert task is not None , "You did not set `task` in the API"
485486
486487 model = cast (VllmModelForPooling , self .get_model ())
487488 to_update = model .pooler .get_pooling_updates (task )
488489 to_update .apply (pooling_params )
489490
490- self . requests [ req_id ] = CachedRequestState (
491+ req_state = CachedRequestState (
491492 req_id = req_id ,
492493 prompt_token_ids = new_req_data .prompt_token_ids ,
493494 mm_kwargs = new_req_data .mm_kwargs ,
@@ -501,36 +502,34 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
501502 lora_request = new_req_data .lora_request ,
502503 )
503504
505+ self .requests [req_id ] = req_state
506+
504507 # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
505508 if self .uses_mrope :
506509 image_grid_thw = []
507510 video_grid_thw = []
508511 second_per_grid_ts = []
509512 audio_feature_lengths = []
510513 use_audio_in_video = False
511- for mm_item in self . requests [ req_id ] .mm_kwargs :
514+ for mm_item in req_state .mm_kwargs :
512515 mm_input = mm_item .get_data ()
513- if mm_input .get ("image_grid_thw" ) is not None :
514- image_grid_thw .append (
515- mm_input ["image_grid_thw" ].tolist ())
516- if mm_input .get ("video_grid_thw" ) is not None :
517- video_grid_thw .append (
518- mm_input ["video_grid_thw" ].tolist ())
519- if mm_input .get ("second_per_grid_ts" ) is not None :
520- second_per_grid_ts .append (
521- mm_input ["second_per_grid_ts" ])
522- if mm_input .get ("audio_feature_lengths" ) is not None :
523- audio_feature_lengths .append (
524- mm_input ["audio_feature_lengths" ])
516+ if (t := mm_input .get ("image_grid_thw" )) is not None :
517+ image_grid_thw .append (t .tolist ())
518+ if (t := mm_input .get ("video_grid_thw" )) is not None :
519+ video_grid_thw .append (t .tolist ())
520+ if (t := mm_input .get ("second_per_grid_ts" )) is not None :
521+ second_per_grid_ts .append (t )
522+ if (t :=
523+ mm_input .get ("audio_feature_lengths" )) is not None :
524+ audio_feature_lengths .append (t )
525525 if mm_input .get ("use_audio_in_video" ) is True :
526526 use_audio_in_video = True
527527
528528 hf_config = self .model_config .hf_config
529529
530- self .requests [req_id ].mrope_positions , \
531- self .requests [req_id ].mrope_position_delta = \
530+ req_state .mrope_positions , req_state .mrope_position_delta = \
532531 MRotaryEmbedding .get_input_positions_tensor (
533- self . requests [ req_id ] .prompt_token_ids ,
532+ req_state .prompt_token_ids ,
534533 hf_config = hf_config ,
535534 image_grid_thw = image_grid_thw ,
536535 video_grid_thw = video_grid_thw ,
@@ -539,7 +538,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
539538 use_audio_in_video = use_audio_in_video ,
540539 )
541540
542- req_ids_to_add .append (req_id )
541+ reqs_to_add .append (req_state )
543542
544543 # Update the states of the running/resumed requests.
545544 is_last_rank = get_pp_group ().is_last_rank
@@ -587,7 +586,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
587586 # The request is not in the persistent batch.
588587 # The request was either preempted and resumed later, or was not
589588 # scheduled in the previous step and needs to be added again.
590- req_ids_to_add .append (req_id )
589+ reqs_to_add .append (req_state )
591590 continue
592591
593592 # Update the persistent batch.
@@ -624,9 +623,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
624623
625624 # Add the new or resumed requests to the persistent batch.
626625 # The smaller empty indices are filled first.
627- for req_id in req_ids_to_add :
628- req_state = self .requests [req_id ]
629- self .input_batch .add_request (req_state )
626+ for request in reqs_to_add :
627+ self .input_batch .add_request (request )
630628
631629 # Condense the batched states if there are gaps left by removed requests
632630 self .input_batch .condense ()
@@ -639,38 +637,32 @@ def _extract_mm_kwargs(
639637 self ,
640638 scheduler_output : "SchedulerOutput" ,
641639 ) -> BatchedTensorInputs :
642- if self .is_multimodal_raw_input_supported : # noqa: SIM102
643- if scheduler_output :
644- mm_kwargs = list [MultiModalKwargsItem ]()
645- for req in scheduler_output .scheduled_new_reqs :
646- req_mm_kwargs = req .mm_kwargs
647- if not isinstance (req_mm_kwargs , list ):
648- req_mm_kwargs = list (req_mm_kwargs )
649- mm_kwargs .extend (req_mm_kwargs )
650-
651- # Input all modalities at once
652- mm_kwargs_combined : BatchedTensorInputs = {}
653- for _ , _ , mm_kwargs_group in group_mm_kwargs_by_modality (
654- mm_kwargs ,
655- device = self .device ,
656- pin_memory = self .pin_memory ,
657- ):
658- mm_kwargs_combined .update (mm_kwargs_group )
659-
660- return mm_kwargs_combined
640+ if not self .is_multimodal_raw_input_supported or not scheduler_output :
641+ return {}
661642
662- return {}
643+ mm_kwargs = list [MultiModalKwargsItem ]()
644+ for req in scheduler_output .scheduled_new_reqs :
645+ mm_kwargs .extend (req .mm_kwargs )
663646
664- def _dummy_mm_kwargs (self , num_seqs : int ) -> BatchedTensorInputs :
665- if self .is_multimodal_raw_input_supported :
666- mm_budget = self .mm_budget
667- assert mm_budget is not None
647+ # Input all modalities at once
648+ mm_kwargs_combined : BatchedTensorInputs = {}
649+ for _ , _ , mm_kwargs_group in group_mm_kwargs_by_modality (
650+ mm_kwargs ,
651+ device = self .device ,
652+ pin_memory = self .pin_memory ,
653+ ):
654+ mm_kwargs_combined .update (mm_kwargs_group )
668655
669- dummy_modality = mm_budget . get_modality_with_max_tokens ()
656+ return mm_kwargs_combined
670657
671- return self ._get_mm_dummy_batch (dummy_modality , num_seqs )
658+ def _dummy_mm_kwargs (self , num_seqs : int ) -> BatchedTensorInputs :
659+ if not self .is_multimodal_raw_input_supported :
660+ return {}
661+ mm_budget = self .mm_budget
662+ assert mm_budget is not None
672663
673- return {}
664+ dummy_modality = mm_budget .get_modality_with_max_tokens ()
665+ return self ._get_mm_dummy_batch (dummy_modality , num_seqs )
674666
675667 def _get_cumsum_and_arange (
676668 self ,
@@ -1612,6 +1604,7 @@ def execute_model(
16121604 batch_descriptor = batch_descriptor ,
16131605 ), self .maybe_get_kv_connector_output (
16141606 scheduler_output ) as kv_connector_output :
1607+
16151608 model_output = self .model (
16161609 input_ids = input_ids ,
16171610 positions = positions ,
0 commit comments