
Commit 3ddc4c1

njhill authored and epwalsh committed
[Misc] Misc code cleanup/simplification (vllm-project#23304)
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent 5b37d7f commit 3ddc4c1

File tree

4 files changed (+52 / -58 lines)

vllm/v1/sample/sampler.py

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ def forward(
         logits = self.apply_bad_words(logits, sampling_metadata)

         # Apply logits processors which can impact greedy sampling
-        for processor in (sampling_metadata.logitsprocs.non_argmax_invariant):
+        for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
             logits = processor.apply(logits)

         # Apply penalties (e.g., min_tokens, freq_penalties).
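
The hunk itself only drops a redundant pair of parentheses around the iterable. For context, the loop it touches folds a sequence of logits processors over the logits tensor; a minimal self-contained sketch of that pattern (the LogitsProcessor protocol and the TemperatureScaler below are illustrative stand-ins, not vLLM's actual classes):

    from typing import Protocol

    import torch


    class LogitsProcessor(Protocol):
        """Anything with an apply(logits) -> logits method."""

        def apply(self, logits: torch.Tensor) -> torch.Tensor:
            ...


    class TemperatureScaler:
        def __init__(self, temperature: float) -> None:
            self.temperature = temperature

        def apply(self, logits: torch.Tensor) -> torch.Tensor:
            # Scaling logits reshapes the softmax distribution but not the
            # argmax, so a processor like this is "argmax-invariant".
            return logits / self.temperature


    def run_processors(logits: torch.Tensor,
                       processors: list[LogitsProcessor]) -> torch.Tensor:
        # Same shape as the diff: each processor transforms the logits in turn.
        for processor in processors:
            logits = processor.apply(logits)
        return logits


    logits = torch.randn(4, 32000)  # [num_reqs, vocab_size]
    logits = run_processors(logits, [TemperatureScaler(0.7)])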

vllm/v1/worker/gpu_input_batch.py

Lines changed: 5 additions & 4 deletions
@@ -442,10 +442,11 @@ def remove_request(self, req_id: str) -> Optional[int]:
         # LoRA
         lora_id = self.request_lora_mapping[req_index]
         if lora_id != 0:
-            self.lora_id_to_request_ids[lora_id].discard(req_id)
-            if len(self.lora_id_to_request_ids[lora_id]) == 0:
-                self.lora_id_to_request_ids.pop(lora_id)
-                self.lora_id_to_lora_request.pop(lora_id)
+            lora_req_ids = self.lora_id_to_request_ids[lora_id]
+            lora_req_ids.discard(req_id)
+            if not lora_req_ids:
+                del self.lora_id_to_request_ids[lora_id]
+                del self.lora_id_to_lora_request[lora_id]
         self.request_lora_mapping[req_index] = 0

         self.has_allowed_token_ids.discard(req_id)
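
This rewrite bundles three small idioms: bind the inner dict value to a local to avoid repeated lookups, test emptiness with truthiness (if not s) rather than len(s) == 0, and use del d[k] instead of d.pop(k) when the popped value is unused. A standalone sketch of the same bookkeeping, with hypothetical names:

    # Reverse index from LoRA adapter id to the set of request ids using it.
    lora_id_to_request_ids: dict[int, set[str]] = {1: {"req-a", "req-b"}}
    lora_id_to_lora_request: dict[int, str] = {1: "adapter-1"}


    def remove_request(req_id: str, lora_id: int) -> None:
        if lora_id != 0:  # 0 means "no LoRA adapter"
            lora_req_ids = lora_id_to_request_ids[lora_id]  # one lookup, reused
            lora_req_ids.discard(req_id)
            if not lora_req_ids:  # empty set is falsy
                # Last user of this adapter is gone; drop both index entries.
                del lora_id_to_request_ids[lora_id]
                del lora_id_to_lora_request[lora_id]


    remove_request("req-a", 1)
    remove_request("req-b", 1)
    assert 1 not in lora_id_to_request_ids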

vllm/v1/worker/gpu_model_runner.py

Lines changed: 45 additions & 52 deletions
@@ -358,6 +358,7 @@ def _init_model_kwargs(self, num_tokens: int):
         if num_pooling_reqs == 0:
             return model_kwargs

+        # This does nontrivial work.
         pooling_params = self.input_batch.pooling_metadata.pooling_params

         assert num_pooling_reqs == num_reqs
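
The new comment flags that pooling_metadata.pooling_params is not a cheap attribute read. When a value is computed behind a @property, reading it once into a local avoids repeating that work; a sketch of the general idiom (the class below is hypothetical, not vLLM's pooling metadata):

    class Metadata:
        def __init__(self, raw: list[dict]) -> None:
            self._raw = raw

        @property
        def params(self) -> list[dict]:
            # Looks like a plain field access at the call site, but rebuilds
            # the list on every read: "this does nontrivial work".
            return [dict(item, validated=True) for item in self._raw]


    md = Metadata([{"task": "embed"}] * 3)

    # Wasteful: three property reads, three rebuilds.
    # first, second, third = md.params[0], md.params[1], md.params[2]

    # Better: read once, reuse the local.
    params = md.params
    first, second, third = params[0], params[1], params[2]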
@@ -465,7 +466,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         for req_id in unscheduled_req_ids:
             self.input_batch.remove_request(req_id)

-        req_ids_to_add: list[str] = []
+        reqs_to_add: list[CachedRequestState] = []
         # Add new requests to the cached states.
         for new_req_data in scheduler_output.scheduled_new_reqs:
             req_id = new_req_data.req_id
@@ -480,14 +481,14 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 generator = None

             if pooling_params:
-                assert (task := pooling_params.task) is not None, (
-                    "You did not set `task` in the API")
+                task = pooling_params.task
+                assert task is not None, "You did not set `task` in the API"

                 model = cast(VllmModelForPooling, self.get_model())
                 to_update = model.pooler.get_pooling_updates(task)
                 to_update.apply(pooling_params)

-            self.requests[req_id] = CachedRequestState(
+            req_state = CachedRequestState(
                 req_id=req_id,
                 prompt_token_ids=new_req_data.prompt_token_ids,
                 mm_kwargs=new_req_data.mm_kwargs,
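
One practical reason to prefer the new form over the walrus-in-assert: python -O strips assert statements entirely, so a binding made inside one silently disappears under -O and later uses of the name raise NameError. The replacement keeps the binding unconditional; a minimal demonstration (get_task is a hypothetical stand-in):

    from typing import Optional


    def get_task() -> Optional[str]:
        return "embed"


    # Fragile: under `python -O` the whole assert, including the := binding,
    # is compiled away, and `task` is never defined.
    assert (task := get_task()) is not None, "You did not set `task` in the API"

    # Robust: the assignment survives -O; only the check is stripped.
    task = get_task()
    assert task is not None, "You did not set `task` in the API"

    print(task.upper())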
@@ -501,36 +502,34 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 lora_request=new_req_data.lora_request,
             )

+            self.requests[req_id] = req_state
+
             # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
             if self.uses_mrope:
                 image_grid_thw = []
                 video_grid_thw = []
                 second_per_grid_ts = []
                 audio_feature_lengths = []
                 use_audio_in_video = False
-                for mm_item in self.requests[req_id].mm_kwargs:
+                for mm_item in req_state.mm_kwargs:
                     mm_input = mm_item.get_data()
-                    if mm_input.get("image_grid_thw") is not None:
-                        image_grid_thw.append(
-                            mm_input["image_grid_thw"].tolist())
-                    if mm_input.get("video_grid_thw") is not None:
-                        video_grid_thw.append(
-                            mm_input["video_grid_thw"].tolist())
-                    if mm_input.get("second_per_grid_ts") is not None:
-                        second_per_grid_ts.append(
-                            mm_input["second_per_grid_ts"])
-                    if mm_input.get("audio_feature_lengths") is not None:
-                        audio_feature_lengths.append(
-                            mm_input["audio_feature_lengths"])
+                    if (t := mm_input.get("image_grid_thw")) is not None:
+                        image_grid_thw.append(t.tolist())
+                    if (t := mm_input.get("video_grid_thw")) is not None:
+                        video_grid_thw.append(t.tolist())
+                    if (t := mm_input.get("second_per_grid_ts")) is not None:
+                        second_per_grid_ts.append(t)
+                    if (t :=
+                            mm_input.get("audio_feature_lengths")) is not None:
+                        audio_feature_lengths.append(t)
                     if mm_input.get("use_audio_in_video") is True:
                         use_audio_in_video = True

                 hf_config = self.model_config.hf_config

-                self.requests[req_id].mrope_positions, \
-                    self.requests[req_id].mrope_position_delta = \
+                req_state.mrope_positions, req_state.mrope_position_delta = \
                     MRotaryEmbedding.get_input_positions_tensor(
-                        self.requests[req_id].prompt_token_ids,
+                        req_state.prompt_token_ids,
                         hf_config=hf_config,
                         image_grid_thw=image_grid_thw,
                         video_grid_thw=video_grid_thw,
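
Each of the old `if d.get(k) is not None: use(d[k])` pairs did two dict lookups per key; the walrus form does one and keeps the check and the use adjacent. A minimal sketch with a fake multimodal-input dict:

    mm_input = {
        "image_grid_thw": [[1, 24, 24]],
        "second_per_grid_ts": 0.5,
        # "video_grid_thw" and "audio_feature_lengths" absent for this item
    }

    image_grid_thw: list = []
    second_per_grid_ts: list = []

    # Before: two lookups per key.
    if mm_input.get("image_grid_thw") is not None:
        image_grid_thw.append(mm_input["image_grid_thw"])

    # After: one lookup, value bound by the walrus operator.
    if (t := mm_input.get("second_per_grid_ts")) is not None:
        second_per_grid_ts.append(t)

    assert image_grid_thw and second_per_grid_ts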
@@ -539,7 +538,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 use_audio_in_video=use_audio_in_video,
             )

-            req_ids_to_add.append(req_id)
+            reqs_to_add.append(req_state)

         # Update the states of the running/resumed requests.
         is_last_rank = get_pp_group().is_last_rank
@@ -587,7 +586,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 # The request is not in the persistent batch.
                 # The request was either preempted and resumed later, or was not
                 # scheduled in the previous step and needs to be added again.
-                req_ids_to_add.append(req_id)
+                reqs_to_add.append(req_state)
                 continue

             # Update the persistent batch.
@@ -624,9 +623,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:

         # Add the new or resumed requests to the persistent batch.
         # The smaller empty indices are filled first.
-        for req_id in req_ids_to_add:
-            req_state = self.requests[req_id]
-            self.input_batch.add_request(req_state)
+        for request in reqs_to_add:
+            self.input_batch.add_request(request)

         # Condense the batched states if there are gaps left by removed requests
         self.input_batch.condense()
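
Carrying the CachedRequestState objects in reqs_to_add, instead of ids that are re-resolved through self.requests, drops one dict lookup per request on this rebuild path. The same trade in miniature, with hypothetical names:

    requests = {"req-1": object(), "req-2": object()}

    # Before: remember ids, pay a second lookup per id later.
    ids_to_add = list(requests)
    states = [requests[req_id] for req_id in ids_to_add]

    # After: remember the objects themselves; no re-lookup.
    states = list(requests.values())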
@@ -639,38 +637,32 @@ def _extract_mm_kwargs(
         self,
         scheduler_output: "SchedulerOutput",
     ) -> BatchedTensorInputs:
-        if self.is_multimodal_raw_input_supported:  # noqa: SIM102
-            if scheduler_output:
-                mm_kwargs = list[MultiModalKwargsItem]()
-                for req in scheduler_output.scheduled_new_reqs:
-                    req_mm_kwargs = req.mm_kwargs
-                    if not isinstance(req_mm_kwargs, list):
-                        req_mm_kwargs = list(req_mm_kwargs)
-                    mm_kwargs.extend(req_mm_kwargs)
-
-                # Input all modalities at once
-                mm_kwargs_combined: BatchedTensorInputs = {}
-                for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
-                        mm_kwargs,
-                        device=self.device,
-                        pin_memory=self.pin_memory,
-                ):
-                    mm_kwargs_combined.update(mm_kwargs_group)
-
-                return mm_kwargs_combined
+        if not self.is_multimodal_raw_input_supported or not scheduler_output:  # noqa: SIM102
+            return {}

-        return {}
+        mm_kwargs = list[MultiModalKwargsItem]()
+        for req in scheduler_output.scheduled_new_reqs:
+            mm_kwargs.extend(req.mm_kwargs)

-    def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs:
-        if self.is_multimodal_raw_input_supported:
-            mm_budget = self.mm_budget
-            assert mm_budget is not None
+        # Input all modalities at once
+        mm_kwargs_combined: BatchedTensorInputs = {}
+        for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
+                mm_kwargs,
+                device=self.device,
+                pin_memory=self.pin_memory,
+        ):
+            mm_kwargs_combined.update(mm_kwargs_group)

-            dummy_modality = mm_budget.get_modality_with_max_tokens()
+        return mm_kwargs_combined

-            return self._get_mm_dummy_batch(dummy_modality, num_seqs)
+    def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs:
+        if not self.is_multimodal_raw_input_supported:
+            return {}
+        mm_budget = self.mm_budget
+        assert mm_budget is not None

-        return {}
+        dummy_modality = mm_budget.get_modality_with_max_tokens()
+        return self._get_mm_dummy_batch(dummy_modality, num_seqs)

     def _get_cumsum_and_arange(
         self,
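
Both methods are rewritten from "wrap the body in an if and fall through to return {}" into guard clauses that return early, flattening the happy path by one or two indent levels. The shape of the transformation, independent of vLLM's types:

    def extract_old(supported: bool, items: list[str]) -> dict:
        if supported:  # main logic nested two levels deep
            if items:
                combined = {}
                for item in items:
                    combined[item] = len(item)
                return combined
        return {}


    def extract_new(supported: bool, items: list[str]) -> dict:
        # Guard clause: reject the degenerate cases up front...
        if not supported or not items:
            return {}
        # ...so the main logic reads at top indentation.
        combined = {}
        for item in items:
            combined[item] = len(item)
        return combined


    assert extract_old(True, ["a", "bb"]) == extract_new(True, ["a", "bb"])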
@@ -1612,6 +1604,7 @@ def execute_model(
                 batch_descriptor=batch_descriptor,
         ), self.maybe_get_kv_connector_output(
                 scheduler_output) as kv_connector_output:
+
             model_output = self.model(
                 input_ids=input_ids,
                 positions=positions,

vllm/worker/worker_base.py

Lines changed: 1 addition & 1 deletion
@@ -544,7 +544,7 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
         Arguments are passed to the worker class constructor.
         """
         kwargs = all_kwargs[self.rpc_rank]
-        self.vllm_config = kwargs.get("vllm_config", None)
+        self.vllm_config = kwargs.get("vllm_config")
         assert self.vllm_config is not None, (
             "vllm_config is required to initialize the worker")
         enable_trace_function_call_for_thread(self.vllm_config)
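
dict.get(key) already returns None when the key is missing, so the explicit None default was redundant:

    kwargs = {"rank": 0}

    assert kwargs.get("vllm_config", None) is None  # explicit default
    assert kwargs.get("vllm_config") is None        # identical behavior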
