From d5c47a5c2620843cb1af0277ff17768f5e20e057 Mon Sep 17 00:00:00 2001
From: flesher0813 <1208954694@qq.com>
Date: Mon, 28 Jul 2025 10:58:23 +0800
Subject: [PATCH] [Feature]: Add support for the vLLM V1 connector

Signed-off-by: flesher0813 <1208954694@qq.com>
---
 vllm_ascend/attention/attention_v1.py | 33 ++++++++++++++++++++++++
 vllm_ascend/worker/model_runner_v1.py | 37 ++++++++++++++++++++++++---
 2 files changed, 66 insertions(+), 4 deletions(-)

diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 7d7f488f47..915feb7a2b 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -24,6 +24,9 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
+from vllm.distributed.kv_transfer import (get_kv_transfer_group,
+                                          has_kv_transfer_group,
+                                          is_v1_kv_transfer_group)
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import direct_register_custom_op
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -444,6 +447,8 @@ def unified_ascend_attention_with_output(
     output: torch.Tensor,
     layer_name: str,
 ) -> None:
+    wait_for_kv_layer_from_connector(layer_name)
+
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
     self = forward_context.no_compile_layers[layer_name]
@@ -456,8 +461,36 @@ def unified_ascend_attention_with_output(
                       attn_metadata,
                       output,
                       trace_flag=False)
+    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
     return
+
+def wait_for_kv_layer_from_connector(layer_name: str):
+    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+        return
+
+    connector = get_kv_transfer_group()
+
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    if attn_metadata is None:
+        return
+    connector.wait_for_layer_load(layer_name)
+
+def maybe_save_kv_layer_to_connector(
+    layer_name: str,
+    kv_cache_layer: List[torch.Tensor],
+):
+    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+        return
+
+    connector = get_kv_transfer_group()
+
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    if attn_metadata is None:
+        return
+    connector.save_kv_layer(layer_name, kv_cache_layer,
+                            attn_metadata)
 
 
 def unified_attention_with_output_fake(
     query: torch.Tensor,
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index eabcdbcc19..f9cca93513 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -39,7 +39,10 @@
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
                                              get_tp_group)
-from vllm.forward_context import set_forward_context
+from vllm.distributed.kv_transfer import (get_kv_transfer_group,
+                                          has_kv_transfer_group)
+from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
+from vllm.forward_context import set_forward_context, get_forward_context
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -876,7 +879,8 @@ def _process_reqs(
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> tuple[Union[AscendMetadata, AscendMLAMetadata,
                      AscendTorchairMetadata], torch.Tensor, SpecDecodeMetadata,
-               torch.Tensor, int, torch.Tensor, torch.Tensor, np.ndarray]:
+               torch.Tensor, int, torch.Tensor, torch.Tensor, np.ndarray,
+               Optional[dict[str, list[str]]]]:
         # Check input valid
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
@@ -1100,6 +1104,7 @@ def _process_reqs(
             positions = self.positions[:padded_batch_size]
 
         # Run forward pass
+        finished_dumping = None
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
@@ -1125,6 +1130,7 @@ def _process_reqs(
                 assert self.model is not None
                 maybe_converting_weight_acl_format(self.model,
                                                    ACL_FORMAT_FRACTAL_ND)
+                self.maybe_setup_kv_connector(scheduler_output)
 
                 hidden_states = self.model(
                     input_ids=input_ids,
@@ -1133,6 +1139,7 @@ def _process_reqs(
                     inputs_embeds=inputs_embeds,
                     **model_kwargs,
                 )
+                finished_dumping = self.maybe_wait_for_kv_save()
 
             use_spec_decode = len(
                 scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -1163,7 +1170,7 @@ def _process_reqs(
 
         return (attn_metadata, hidden_states, spec_decode_metadata, positions,
                 total_num_scheduled_tokens, logits_indices, aux_hidden_states,
-                num_scheduled_tokens)
+                num_scheduled_tokens, finished_dumping)
 
     def _get_cumsum_and_arange(
         self,
@@ -1400,7 +1407,7 @@ def execute_model(
                 return EMPTY_MODEL_RUNNER_OUTPUT
             (attn_metadata, hidden_states, spec_decode_metadata, positions,
              num_scheduled_tokens, logits_indices, aux_hidden_states,
-             num_scheduled_tokens_np) = (self._process_reqs(
+             num_scheduled_tokens_np, finished_dumping) = (self._process_reqs(
                  scheduler_output, intermediate_tensors))
 
         with ProfileExecuteDuration().capture_async("post process"):
@@ -1561,6 +1568,7 @@ def execute_model(
                 logprobs=logprobs_lists,
                 prompt_logprobs_dict=prompt_logprobs_dict,
                 pooler_output=[],
+                finished_dumping=finished_dumping
             )
 
         durations = ProfileExecuteDuration().pop_captured_sync()
@@ -2369,3 +2377,24 @@ def select_torchair_padded_batch_size(self, batch_size: int):
             if batch_size <= padded_batch_size < selected_batch_size:
                 selected_batch_size = padded_batch_size
         return selected_batch_size
+
+    @staticmethod
+    def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
+        # Update the KVConnector with the KVConnector metadata before forward().
+        if has_kv_transfer_group():
+            kv_connector = get_kv_transfer_group()
+            assert isinstance(kv_connector, KVConnectorBase_V1)
+            assert scheduler_output.kv_connector_metadata is not None
+            kv_connector.bind_connector_metadata(
+                scheduler_output.kv_connector_metadata)
+
+            # Background KV cache transfers happen here.
+            # These transfers are designed to be async and the requests
+            # involved may be disjoint from the running requests.
+            # Do this here to save a collective_rpc.
+            kv_connector.start_load_kv(get_forward_context())
+
+    @staticmethod
+    def maybe_wait_for_kv_save():
+        if has_kv_transfer_group():
+            return get_kv_transfer_group().wait_for_save()
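
For reference, the sketch below is not part of the patch; it is a minimal summary of the per-step call order that the hooks above wire up, using only the connector calls introduced in the diff (bind_connector_metadata, start_load_kv, wait_for_layer_load, save_kv_layer, wait_for_save). The wrapper function and its arguments are hypothetical names used purely for illustration.

# Illustrative sketch only (not applied by this patch); assumes vLLM with a
# V1 KV connector configured, as in the hooks above.
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group)


def run_one_step_with_connector(model, scheduler_output, forward_context,
                                **model_inputs):
    finished_dumping = None
    if has_kv_transfer_group():
        connector = get_kv_transfer_group()
        # Bind this step's connector metadata and kick off async KV loads
        # before the forward pass (mirrors maybe_setup_kv_connector above).
        connector.bind_connector_metadata(
            scheduler_output.kv_connector_metadata)
        connector.start_load_kv(forward_context)

    # During forward(), each attention layer calls
    # wait_for_kv_layer_from_connector() before attention and
    # maybe_save_kv_layer_to_connector() afterwards (see attention_v1.py).
    hidden_states = model(**model_inputs)

    if has_kv_transfer_group():
        # Wait until this step's KV saves are handed off to the connector
        # (mirrors maybe_wait_for_kv_save above).
        finished_dumping = get_kv_transfer_group().wait_for_save()
    return hidden_states, finished_dumping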