 import torch.nn as nn
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
-from vllm.config import CompilationLevel, VllmConfig
+from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
@@ -58,12 +58,13 @@
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import GenerationTask, SupportedTask
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        LazyLoader, cdiv)
+                        LazyLoader, cdiv, is_pin_memory_available)
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
                              ModelRunnerOutput)
 from vllm.v1.pool.metadata import PoolingMetadata
+from vllm.v1.sample.logits_processor import build_logitsprocs
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
@@ -156,6 +157,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.cache_config = vllm_config.cache_config
         self.lora_config = vllm_config.lora_config
         self.parallel_config = vllm_config.parallel_config
+        self.compilation_config = vllm_config.compilation_config
+        self.pin_memory = is_pin_memory_available()
         self.scheduler_config = vllm_config.scheduler_config
         self.speculative_config = vllm_config.speculative_config
         self.block_size = vllm_config.cache_config.block_size
@@ -335,9 +338,11 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                              == CompilationLevel.PIECEWISE
                              and not self.model_config.enforce_eager and
                              not ascend_config.torchair_graph_config.enabled)
-        self.aclgraph_batch_sizes = list(
-            reversed(
-                self.vllm_config.compilation_config.cudagraph_capture_sizes))
+        self.aclgraph_batch_sizes = []
+        if self.compilation_config.cudagraph_capture_sizes and \
+                self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
+            self.aclgraph_batch_sizes = list(
+                reversed(self.compilation_config.cudagraph_capture_sizes))
 
         self.new_kv_cache_bytes = -1
         self.torchair_compiled_model = None  # type: ignore
@@ -405,12 +410,6 @@ def check_batch_sizes_consistency(self) -> None:
         )
 
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
-        """Update the cached states and the persistent batch with the scheduler
-        output.
-
-        The SamplingMetadata is updated and copied to the NPU if there is a
-        new/resumed/paused/finished request in the batch.
-        """
         # Remove finished requests from the cached states.
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
@@ -421,11 +420,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         # then resubmitted with the same ID. In this case, we treat them as two
         # distinct requests - clearing the cached states for the first request
         # and handling the second as a new request.
-        removed_req_indices: List[int] = []
         for req_id in scheduler_output.finished_req_ids:
-            req_index = self.input_batch.remove_request(req_id)
-            if req_index is not None:
-                removed_req_indices.append(req_index)
+            self.input_batch.remove_request(req_id)
 
         # Free the cached encoder outputs.
         for req_id, input_id in scheduler_output.free_encoder_input_ids:
@@ -448,16 +444,15 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         # have low request overlap (e.g., alternating between two distinct
         # sets of requests), this optimization becomes very inefficient.
         for req_id in unscheduled_req_ids:
-            req_index = self.input_batch.remove_request(req_id)
-            assert req_index is not None
-            removed_req_indices.append(req_index)
+            self.input_batch.remove_request(req_id)
 
-        req_ids_to_add: List[str] = []
+        req_ids_to_add: list[str] = []
         # Add new requests to the cached states.
         for new_req_data in scheduler_output.scheduled_new_reqs:
             req_id = new_req_data.req_id
             sampling_params = new_req_data.sampling_params
             pooling_params = new_req_data.pooling_params
+
             if sampling_params and \
                     sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                 generator = torch.Generator(device=self.device)
@@ -468,7 +463,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             if pooling_params:
                 assert (task := pooling_params.task) is not None, (
                     "You did not set `task` in the API")
-                model = cast(VllmModelForPooling, self.model)
+
+                model = cast(VllmModelForPooling, self.get_model())
                 to_update = model.pooler.get_pooling_updates(task)
                 to_update.apply(pooling_params)
 
@@ -478,7 +474,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 mm_kwargs=new_req_data.mm_kwargs,
                 mm_positions=new_req_data.mm_positions,
                 sampling_params=sampling_params,
-                pooling_params=new_req_data.pooling_params,
+                pooling_params=pooling_params,
                 generator=generator,
                 block_ids=new_req_data.block_ids,
                 num_computed_tokens=new_req_data.num_computed_tokens,
@@ -493,9 +489,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 second_per_grid_ts = []
                 audio_feature_lengths = []
                 use_audio_in_video = False
-
-                for item in self.requests[req_id].mm_kwargs:
-                    mm_input = item.require_data()
+                for mm_item in self.requests[req_id].mm_kwargs:
+                    mm_input = mm_item.get_data()
                     if mm_input.get("image_grid_thw") is not None:
                         image_grid_thw.append(
                             mm_input["image_grid_thw"].tolist())
@@ -528,19 +523,24 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             req_ids_to_add.append(req_id)
 
         # Update the states of the running/resumed requests.
-        req_data = scheduler_output.scheduled_cached_reqs
         is_last_rank = get_pp_group().is_last_rank
+        req_data = scheduler_output.scheduled_cached_reqs
         for i, req_id in enumerate(req_data.req_ids):
             req_state = self.requests[req_id]
             num_computed_tokens = req_data.num_computed_tokens[i]
             new_block_ids = req_data.new_block_ids[i]
             resumed_from_preemption = req_data.resumed_from_preemption[i]
 
+            # Update the cached states.
             req_state.num_computed_tokens = num_computed_tokens
+
             if not is_last_rank:
+                # When using PP, the scheduler sends the sampled tokens back,
+                # because there's no direct communication between the first-
+                # stage worker and the last-stage worker.
                 new_token_ids = req_data.new_token_ids[i]
                 # Add the sampled token(s) from the previous step (if any).
-                # This doesn't include "unverified" tokens like spec decode tokens.
+                # This doesn't include "unverified" tokens like spec tokens.
                 num_new_tokens = (num_computed_tokens + len(new_token_ids) -
                                   req_state.num_tokens)
                 if num_new_tokens == 1:
@@ -549,11 +549,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 elif num_new_tokens > 0:
                     req_state.output_token_ids.extend(
                         new_token_ids[-num_new_tokens:])
+
             # Update the block IDs.
             if not resumed_from_preemption:
                 # Append the new blocks to the existing block IDs.
-                for block_ids, new_ids in zip(  # type: ignore[call-overload]
-                        req_state.block_ids, new_block_ids):
+                for block_ids, new_ids in zip(req_state.block_ids,
+                                              new_block_ids):
                     block_ids.extend(new_ids)
             else:
                 # The request is resumed from preemption.
@@ -571,9 +572,10 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             # Update the persistent batch.
             self.input_batch.num_computed_tokens_cpu[req_index] = (
                 num_computed_tokens)
-
             self.input_batch.block_table.append_row(new_block_ids, req_index)
 
+            # For the last rank, we don't need to update the token_ids_cpu
+            # because the sampled tokens are already cached.
             if not is_last_rank:
                 # Add new_token_ids to token_ids_cpu.
                 start_token_index = num_computed_tokens
@@ -583,9 +585,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                     start_token_index:end_token_index] = new_token_ids
                 self.input_batch.num_tokens_no_spec[
                     req_index] = end_token_index
+                self.input_batch.num_tokens[req_index] = end_token_index
+
             # Add spec_token_ids to token_ids_cpu.
-            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
-                req_id, ())
+            spec_token_ids = (
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
             if spec_token_ids:
                 num_spec_tokens = len(spec_token_ids)
                 start_index = self.input_batch.num_tokens_no_spec[req_index]
@@ -595,39 +599,17 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 # NOTE(woosuk): `num_tokens` here may include spec tokens.
                 self.input_batch.num_tokens[req_index] += num_spec_tokens
 
-        # Check if the batch has changed. If not, we can skip copying the
-        # sampling metadata from CPU to GPU.
-        batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
-
         # Add the new or resumed requests to the persistent batch.
         # The smaller empty indices are filled first.
-        removed_req_indices.sort(reverse=True)
         for req_id in req_ids_to_add:
             req_state = self.requests[req_id]
-            if removed_req_indices:
-                # Fill the empty index.
-                req_index = removed_req_indices.pop()
-            else:
-                # Append to the end.
-                req_index = None
-            self.input_batch.add_request(req_state, req_index)
-            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
-                req_id, ())
-            if spec_token_ids:
-                req_index = self.input_batch.num_reqs - 1
-                start_index = len(req_state.prompt_token_ids) + len(
-                    req_state.output_token_ids)
-                end_token_index = start_index + len(spec_token_ids)
-                self.input_batch.token_ids_cpu[
-                    req_index, start_index:end_token_index] = spec_token_ids
-                self.input_batch.num_tokens[req_index] = end_token_index
+            self.input_batch.add_request(req_state)
 
-        # Condense the batched states if there are empty indices.
-        if removed_req_indices:
-            self.input_batch.condense(removed_req_indices)
+        # Condense the batched states if there are gaps left by removed requests
+        self.input_batch.condense()
 
-        if batch_changed:
-            self.input_batch.refresh_sampling_metadata()
+        # Refresh batch metadata with any pending updates.
+        self.input_batch.refresh_metadata()
 
     def _get_forward_metadata_across_dp(
             self, num_tokens: int, with_prefill: bool,
@@ -1063,11 +1045,6 @@ def _process_reqs(
                 num_input_tokens)
         num_input_tokens += num_pad
 
-        modified_batch = self.attn_metadata_builder.reorder_batch(
-            self.input_batch, scheduler_output)
-        if modified_batch:
-            self.input_batch.refresh_sampling_metadata()
-
         # OPTIMIZATION: Start copying the block table first.
         # This way, we can overlap the copy with the following CPU operations.
         self.input_batch.block_table.commit_block_table(num_reqs)
@@ -2199,10 +2176,15 @@ def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
             max_model_len=self.model_config.max_model_len,
             max_num_batched_tokens=self.max_num_tokens,
             device=self.device,
-            pin_memory=True,
+            pin_memory=self.pin_memory,
             vocab_size=self.model_config.get_vocab_size(),
             block_sizes=[self.block_size],
             is_spec_decode=bool(self.vllm_config.speculative_config),
+            logitsprocs=build_logitsprocs(
+                self.vllm_config, self.device, self.pin_memory,
+                self.is_pooling_model,
+                self.vllm_config.model_config.logits_processors),
+            is_pooling_model=self.is_pooling_model,
         )
 
         kv_cache_sizes = {}
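
For reviewers, a minimal standalone sketch of the capture-size guard added in __init__ above. Only the guard condition and the reversed() ordering come from the diff; the CUDAGraphMode stub, the helper name resolve_aclgraph_batch_sizes, and the example sizes are illustrative assumptions.

from enum import Enum
from typing import Optional


class CUDAGraphMode(Enum):
    # Stub standing in for vllm.config.CUDAGraphMode; only the members
    # needed for this illustration.
    NONE = 0
    PIECEWISE = 1


def resolve_aclgraph_batch_sizes(capture_sizes: Optional[list],
                                 cudagraph_mode: CUDAGraphMode) -> list:
    # Mirrors the guarded initialization in the diff: build the descending
    # list of capture sizes only when graph capture is enabled; otherwise
    # return an empty list so no ACL graphs are captured.
    if capture_sizes and cudagraph_mode != CUDAGraphMode.NONE:
        return list(reversed(capture_sizes))
    return []


assert resolve_aclgraph_batch_sizes([1, 2, 4, 8],
                                    CUDAGraphMode.PIECEWISE) == [8, 4, 2, 1]
assert resolve_aclgraph_batch_sizes([1, 2, 4, 8], CUDAGraphMode.NONE) == []
assert resolve_aclgraph_batch_sizes(None, CUDAGraphMode.PIECEWISE) == []

The empty-list default means eager-only configurations never enter the ACL graph capture path.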