@@ -39,7 +39,10 @@
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
                                              get_tp_group)
-from vllm.forward_context import set_forward_context
+from vllm.distributed.kv_transfer import (get_kv_transfer_group,
+                                          has_kv_transfer_group)
+from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
+from vllm.forward_context import set_forward_context, get_forward_context
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
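The new imports pull in the guarded accessors for the worker's KV-transfer group. For context, a minimal usage sketch of how a deployment would enable them, assuming upstream vLLM's `KVTransferConfig` (field names taken from upstream; the model and connector choice below are placeholders, not part of this change):

```python
# Sketch only: configuring a v1 KV connector so that has_kv_transfer_group()
# returns True inside this runner. "SharedStorageConnector" is one registered
# upstream example connector; the model name is a placeholder.
from vllm import LLM
from vllm.config import KVTransferConfig

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder
    kv_transfer_config=KVTransferConfig(
        kv_connector="SharedStorageConnector",
        kv_role="kv_both",
    ),
)
# Without a kv_transfer_config, has_kv_transfer_group() is False and the
# hooks added below are no-ops.
```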
@@ -876,7 +879,8 @@ def _process_reqs(
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> tuple[Union[AscendMetadata, AscendMLAMetadata,
                      AscendTorchairMetadata], torch.Tensor, SpecDecodeMetadata,
-               torch.Tensor, int, torch.Tensor, torch.Tensor, np.ndarray]:
+               torch.Tensor, int, torch.Tensor, torch.Tensor, np.ndarray,
+               Optional[dict[str, list[str]]]]:
         # Check input valid
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
@@ -1100,6 +1104,7 @@ def _process_reqs(
         positions = self.positions[:padded_batch_size]
 
         # Run forward pass
+        finished_dumping = None
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
@@ -1125,6 +1130,7 @@ def _process_reqs(
                 assert self.model is not None
                 maybe_converting_weight_acl_format(self.model,
                                                    ACL_FORMAT_FRACTAL_ND)
+                self.maybe_setup_kv_connector(scheduler_output)
 
                 hidden_states = self.model(
                     input_ids=input_ids,
@@ -1133,6 +1139,7 @@ def _process_reqs(
                     inputs_embeds=inputs_embeds,
                     **model_kwargs,
                 )
+            finished_dumping = self.maybe_wait_for_kv_save()
 
         use_spec_decode = len(
             scheduler_output.scheduled_spec_decode_tokens) > 0
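The two hooks bracket the forward pass: `maybe_setup_kv_connector` binds this step's connector metadata and starts async loads before `self.model(...)` runs, and `maybe_wait_for_kv_save` drains pending saves right after it, inside the same `set_forward_context` scope. A partial sketch of the worker-side connector surface these hooks exercise (the scheduler-side abstract methods of `KVConnectorBase_V1` are omitted for brevity, so this stub is illustrative, not instantiable as-is):

```python
# Illustrative stub of the worker-side KVConnectorBase_V1 methods driven by
# this forward pass; a real connector must also implement the scheduler-side
# abstract methods (build_connector_meta, get_num_new_matched_tokens, ...).
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1

class NoopConnector(KVConnectorBase_V1):

    def start_load_kv(self, forward_context, **kwargs) -> None:
        # Reached via maybe_setup_kv_connector(), after
        # bind_connector_metadata(): kick off async KV loads here.
        pass

    def wait_for_layer_load(self, layer_name: str) -> None:
        # Layer-wise barrier the attention layers can hit during forward().
        pass

    def save_kv_layer(self, layer_name, kv_layer, attn_metadata, **kwargs):
        # Per-layer async save hook during forward().
        pass

    def wait_for_save(self):
        # Reached via maybe_wait_for_kv_save() right after forward(); in
        # this fork its return value becomes `finished_dumping`.
        return None
```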
@@ -1163,7 +1170,7 @@ def _process_reqs(
 
         return (attn_metadata, hidden_states, spec_decode_metadata, positions,
                 total_num_scheduled_tokens, logits_indices, aux_hidden_states,
-                num_scheduled_tokens)
+                num_scheduled_tokens, finished_dumping)
 
     def _get_cumsum_and_arange(
         self,
@@ -1400,7 +1407,7 @@ def execute_model(
             return EMPTY_MODEL_RUNNER_OUTPUT
         (attn_metadata, hidden_states, spec_decode_metadata, positions,
          num_scheduled_tokens, logits_indices, aux_hidden_states,
-         num_scheduled_tokens_np) = (self._process_reqs(
+         num_scheduled_tokens_np, finished_dumping) = (self._process_reqs(
             scheduler_output, intermediate_tensors))
 
         with ProfileExecuteDuration().capture_async("post process"):
@@ -1561,6 +1568,7 @@ def execute_model(
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
             pooler_output=[],
+            finished_dumping=finished_dumping,
         )
 
         durations = ProfileExecuteDuration().pop_captured_sync()
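`finished_dumping` rides back to the scheduler process on the runner output. A hypothetical consumer sketch; only the `finished_dumping` attribute and its `Optional[dict[str, list[str]]]` type come from this diff, and the assumption that it maps request ids to completed dump ids is the author's reading, not confirmed API:

```python
# Hypothetical: reading the new field off the runner output.
output = model_runner.execute_model(scheduler_output)  # ModelRunnerOutput
if output.finished_dumping:
    for req_id, dump_ids in output.finished_dumping.items():
        # e.g. the scheduler could now release or reuse the request's blocks
        print(f"request {req_id}: completed KV dumps {dump_ids}")
```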
@@ -2369,3 +2377,24 @@ def select_torchair_padded_batch_size(self, batch_size: int):
             if batch_size <= padded_batch_size < selected_batch_size:
                 selected_batch_size = padded_batch_size
         return selected_batch_size
+
+    @staticmethod
+    def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
+        # Bind the KVConnector metadata for this step before forward().
+        if has_kv_transfer_group():
+            kv_connector = get_kv_transfer_group()
+            assert isinstance(kv_connector, KVConnectorBase_V1)
+            assert scheduler_output.kv_connector_metadata is not None
+            kv_connector.bind_connector_metadata(
+                scheduler_output.kv_connector_metadata)
+
+            # Background KV cache transfers happen here.
+            # These transfers are designed to be async and the requests
+            # involved may be disjoint from the running requests.
+            # Do this here to save a collective_rpc.
+            kv_connector.start_load_kv(get_forward_context())
+
+    @staticmethod
+    def maybe_wait_for_kv_save():
+        if has_kv_transfer_group():
+            return get_kv_transfer_group().wait_for_save()
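Taken together, the per-step call order these helpers introduce (a summary of the code above, not additional API):

```python
# Per engine step, when a KV transfer group is initialized:
#   1. maybe_setup_kv_connector(scheduler_output)
#        -> kv_connector.bind_connector_metadata(...)
#        -> kv_connector.start_load_kv(get_forward_context())
#   2. hidden_states = self.model(...)            # forward pass
#   3. finished_dumping = maybe_wait_for_kv_save()
#        -> kv_connector.wait_for_save()
# Without a transfer group, both helpers return immediately (None).
```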