Commit d5c47a5 (1 parent: b5b7e0e)

[Feature]: Add support for the vLLM V1 connector

Signed-off-by: flesher0813 <1208954694@qq.com>

File tree: 2 files changed (+66 additions, -4 deletions)

vllm_ascend/attention/attention_v1.py
vllm_ascend/worker/model_runner_v1.py
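The connector hooks added in these two files only take effect when the engine is started with a KV transfer group, i.e. with a kv_transfer_config. A minimal sketch of such a launch, assuming vLLM's KVTransferConfig API and its example SharedStorageConnector (the model name and storage path are placeholders, not part of this commit):

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

# Assumed configuration: connector name, role and extra options depend on the
# connector implementations shipped with the installed vLLM version.
kv_config = KVTransferConfig(
    kv_connector="SharedStorageConnector",
    kv_role="kv_both",
    kv_connector_extra_config={"shared_storage_path": "local_storage"},
)

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct", kv_transfer_config=kv_config)
print(llm.generate(["Hello, world"], SamplingParams(max_tokens=8)))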

vllm_ascend/attention/attention_v1.py (33 additions, 0 deletions)

@@ -24,6 +24,9 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
+from vllm.distributed.kv_transfer import (get_kv_transfer_group,
+                                          has_kv_transfer_group,
+                                          is_v1_kv_transfer_group)
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import direct_register_custom_op
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -444,6 +447,8 @@ def unified_ascend_attention_with_output(
     output: torch.Tensor,
     layer_name: str,
 ) -> None:
+    wait_for_kv_layer_from_connector(layer_name)
+
     forward_context: ForwardContext = get_forward_context()
     attn_metadata = forward_context.attn_metadata
     self = forward_context.no_compile_layers[layer_name]
@@ -456,8 +461,36 @@ def unified_ascend_attention_with_output(
                       attn_metadata,
                       output,
                       trace_flag=False)
+    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
     return

+def wait_for_kv_layer_from_connector(layer_name: str):
+    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+        return
+
+    connector = get_kv_transfer_group()
+
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    if attn_metadata is None:
+        return
+    connector.wait_for_layer_load(layer_name)
+
+def maybe_save_kv_layer_to_connector(
+    layer_name: str,
+    kv_cache_layer: List[torch.Tensor],
+):
+    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+        return
+
+    connector = get_kv_transfer_group()
+
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    if attn_metadata is None:
+        return
+    connector.save_kv_layer(layer_name, kv_cache_layer,
+                            attn_metadata)

 def unified_attention_with_output_fake(
     query: torch.Tensor,
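The two helpers above gate on has_kv_transfer_group() and is_v1_kv_transfer_group() and then bracket each layer's attention: the layer waits for its KV to be loaded before computing, and hands its kv_cache to the connector for an asynchronous save afterwards. A minimal sketch of that per-layer call order with a stub connector (the stub class and the layer name are illustrative only, not vLLM APIs):

from typing import List

import torch


class StubConnector:
    """Stand-in for a V1 KV connector; real connectors implement
    wait_for_layer_load() and save_kv_layer() with actual transfer logic."""

    def wait_for_layer_load(self, layer_name: str) -> None:
        print(f"KV load for {layer_name} is complete (or there was nothing to load)")

    def save_kv_layer(self, layer_name: str,
                      kv_cache_layer: List[torch.Tensor], attn_metadata) -> None:
        print(f"queued async KV save for {layer_name}")


def attention_layer_step(connector, layer_name: str,
                         kv_cache: List[torch.Tensor], attn_metadata) -> None:
    # Mirrors unified_ascend_attention_with_output after this commit.
    connector.wait_for_layer_load(layer_name)   # wait_for_kv_layer_from_connector
    # ... attention forward runs here and fills kv_cache ...
    connector.save_kv_layer(layer_name, kv_cache, attn_metadata)  # maybe_save_kv_layer_to_connector


attention_layer_step(StubConnector(), "model.layers.0.self_attn.attn",
                     kv_cache=[torch.empty(0)], attn_metadata=None)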

vllm_ascend/worker/model_runner_v1.py (33 additions, 4 deletions)

@@ -39,7 +39,10 @@
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
                                              get_tp_group)
-from vllm.forward_context import set_forward_context
+from vllm.distributed.kv_transfer import (get_kv_transfer_group,
+                                          has_kv_transfer_group)
+from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
+from vllm.forward_context import set_forward_context, get_forward_context
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -876,7 +879,8 @@ def _process_reqs(
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> tuple[Union[AscendMetadata, AscendMLAMetadata,
                      AscendTorchairMetadata], torch.Tensor, SpecDecodeMetadata,
-               torch.Tensor, int, torch.Tensor, torch.Tensor, np.ndarray]:
+               torch.Tensor, int, torch.Tensor, torch.Tensor, np.ndarray,
+               Optional[dict[str, list[str]]]]:
         # Check input valid
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
@@ -1100,6 +1104,7 @@ def _process_reqs(
         positions = self.positions[:padded_batch_size]

         # Run forward pass
+        finished_dumping = None
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
@@ -1125,6 +1130,7 @@ def _process_reqs(
             assert self.model is not None
             maybe_converting_weight_acl_format(self.model,
                                                ACL_FORMAT_FRACTAL_ND)
+            self.maybe_setup_kv_connector(scheduler_output)

             hidden_states = self.model(
                 input_ids=input_ids,
@@ -1133,6 +1139,7 @@ def _process_reqs(
                 inputs_embeds=inputs_embeds,
                 **model_kwargs,
             )
+            finished_dumping = self.maybe_wait_for_kv_save()

             use_spec_decode = len(
                 scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -1163,7 +1170,7 @@ def _process_reqs(

         return (attn_metadata, hidden_states, spec_decode_metadata, positions,
                 total_num_scheduled_tokens, logits_indices, aux_hidden_states,
-                num_scheduled_tokens)
+                num_scheduled_tokens, finished_dumping)

     def _get_cumsum_and_arange(
         self,
@@ -1400,7 +1407,7 @@ def execute_model(
             return EMPTY_MODEL_RUNNER_OUTPUT
         (attn_metadata, hidden_states, spec_decode_metadata, positions,
          num_scheduled_tokens, logits_indices, aux_hidden_states,
-         num_scheduled_tokens_np) = (self._process_reqs(
+         num_scheduled_tokens_np, finished_dumping) = (self._process_reqs(
              scheduler_output, intermediate_tensors))

         with ProfileExecuteDuration().capture_async("post process"):
@@ -1561,6 +1568,7 @@ def execute_model(
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
             pooler_output=[],
+            finished_dumping=finished_dumping
         )

         durations = ProfileExecuteDuration().pop_captured_sync()
@@ -2369,3 +2377,24 @@ def select_torchair_padded_batch_size(self, batch_size: int):
         if batch_size <= padded_batch_size < selected_batch_size:
             selected_batch_size = padded_batch_size
         return selected_batch_size
+
+    @staticmethod
+    def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
+        # Update KVConnector with the KVConnector metadata forward().
+        if has_kv_transfer_group():
+            kv_connector = get_kv_transfer_group()
+            assert isinstance(kv_connector, KVConnectorBase_V1)
+            assert scheduler_output.kv_connector_metadata is not None
+            kv_connector.bind_connector_metadata(
+                scheduler_output.kv_connector_metadata)
+
+            # Background KV cache transfers happen here.
+            # These transfers are designed to be async and the requests
+            # involved may be disjoint from the running requests.
+            # Do this here to save a collective_rpc.
+            kv_connector.start_load_kv(get_forward_context())
+
+    @staticmethod
+    def maybe_wait_for_kv_save():
+        if has_kv_transfer_group():
+            return get_kv_transfer_group().wait_for_save()
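Taken together, the model-runner changes bind the scheduler's connector metadata and start asynchronous KV loads right before the forward pass, then wait for outstanding saves right after it; whatever wait_for_save() returns is threaded through _process_reqs as finished_dumping and attached to the ModelRunnerOutput. A condensed sketch of that per-step ordering with a stub connector (the stub class and the SimpleNamespace stand-in are illustrative, not vLLM APIs):

from types import SimpleNamespace


class StubV1Connector:
    def bind_connector_metadata(self, metadata) -> None:
        pass

    def start_load_kv(self, forward_context) -> None:
        pass  # background KV loads would kick off here

    def wait_for_save(self):
        # In this commit the return value is surfaced as finished_dumping,
        # e.g. a mapping of finished request ids; None means nothing to report.
        return None


def model_runner_step(connector, scheduler_output, forward_context, run_model):
    # maybe_setup_kv_connector: bind per-step metadata, start async KV loads.
    connector.bind_connector_metadata(scheduler_output.kv_connector_metadata)
    connector.start_load_kv(forward_context)

    hidden_states = run_model()  # per-layer hooks in attention_v1.py fire in here

    # maybe_wait_for_kv_save: block on queued saves; the result becomes
    # ModelRunnerOutput.finished_dumping via _process_reqs.
    finished_dumping = connector.wait_for_save()
    return hidden_states, finished_dumping


hidden, finished = model_runner_step(
    StubV1Connector(),
    SimpleNamespace(kv_connector_metadata={}),
    forward_context=None,
    run_model=lambda: "hidden_states")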
