33 changes: 33 additions & 0 deletions vllm_ascend/attention/attention_v1.py
@@ -24,6 +24,9 @@
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import direct_register_custom_op
from vllm.v1.core.sched.output import SchedulerOutput
@@ -444,6 +447,8 @@ def unified_ascend_attention_with_output(
output: torch.Tensor,
layer_name: str,
) -> None:
wait_for_kv_layer_from_connector(layer_name)

forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
self = forward_context.no_compile_layers[layer_name]
@@ -456,8 +461,36 @@ def unified_ascend_attention_with_output(
attn_metadata,
output,
trace_flag=False)
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return

def wait_for_kv_layer_from_connector(layer_name: str):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return

connector = get_kv_transfer_group()

forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
connector.wait_for_layer_load(layer_name)

def maybe_save_kv_layer_to_connector(
layer_name: str,
kv_cache_layer: List[torch.Tensor],
):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return

connector = get_kv_transfer_group()

forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if attn_metadata is None:
return
connector.save_kv_layer(layer_name, kv_cache_layer,
attn_metadata)

def unified_attention_with_output_fake(
query: torch.Tensor,
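Note for reviewers: the two helpers added above bracket every attention layer's forward. wait_for_kv_layer_from_connector blocks until the connector has finished loading that layer's KV cache, and maybe_save_kv_layer_to_connector hands the freshly written KV cache back to the connector for (possibly asynchronous) saving. A minimal standalone sketch of that per-layer ordering, assuming the KVConnectorBase_V1 method names used above; the DemoConnector class and layer names are hypothetical and not part of this change:

# Sketch only: mirrors the call order in unified_ascend_attention_with_output.
class DemoConnector:
    def wait_for_layer_load(self, layer_name: str) -> None:
        # Block until this layer's KV blocks have been written into the
        # local cache tensors by the background load started earlier.
        print(f"KV load finished for {layer_name}")

    def save_kv_layer(self, layer_name: str, kv_cache_layer, attn_metadata) -> None:
        # Queue a (possibly async) transfer of this layer's freshly written KV cache.
        print(f"KV save queued for {layer_name}")


def attention_layer_forward(connector, layer_name, kv_cache, attn_metadata):
    connector.wait_for_layer_load(layer_name)   # before the attention kernel
    # ... attention computation writes into kv_cache here (elided) ...
    connector.save_kv_layer(layer_name, kv_cache, attn_metadata)  # after it


connector = DemoConnector()
for name in ("model.layers.0.attn", "model.layers.1.attn"):
    attention_layer_forward(connector, name, kv_cache=[], attn_metadata=None)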
37 changes: 33 additions & 4 deletions vllm_ascend/worker/model_runner_v1.py
@@ -39,7 +39,10 @@
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
get_tp_group)
from vllm.forward_context import set_forward_context
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.forward_context import set_forward_context, get_forward_context
from vllm.inputs import INPUT_REGISTRY
from vllm.logger import logger
from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -876,7 +879,8 @@ def _process_reqs(
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> tuple[Union[AscendMetadata, AscendMLAMetadata,
AscendTorchairMetadata], torch.Tensor, SpecDecodeMetadata,
torch.Tensor, int, torch.Tensor, torch.Tensor, np.ndarray]:
torch.Tensor, int, torch.Tensor, torch.Tensor, np.ndarray,
Optional[dict[str, list[str]]]]:
# Check input valid
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0
@@ -1100,6 +1104,7 @@ def _process_reqs(
positions = self.positions[:padded_batch_size]

# Run forward pass
finished_dumping = None
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
@@ -1125,6 +1130,7 @@ def _process_reqs(
assert self.model is not None
maybe_converting_weight_acl_format(self.model,
ACL_FORMAT_FRACTAL_ND)
self.maybe_setup_kv_connector(scheduler_output)

hidden_states = self.model(
input_ids=input_ids,
@@ -1133,6 +1139,7 @@ def _process_reqs(
inputs_embeds=inputs_embeds,
**model_kwargs,
)
finished_dumping = self.maybe_wait_for_kv_save()

use_spec_decode = len(
scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -1163,7 +1170,7 @@ def _process_reqs(

return (attn_metadata, hidden_states, spec_decode_metadata, positions,
total_num_scheduled_tokens, logits_indices, aux_hidden_states,
num_scheduled_tokens)
num_scheduled_tokens, finished_dumping)

def _get_cumsum_and_arange(
self,
@@ -1400,7 +1407,7 @@ def execute_model(
return EMPTY_MODEL_RUNNER_OUTPUT
(attn_metadata, hidden_states, spec_decode_metadata, positions,
num_scheduled_tokens, logits_indices, aux_hidden_states,
num_scheduled_tokens_np) = (self._process_reqs(
num_scheduled_tokens_np, finished_dumping) = (self._process_reqs(
scheduler_output, intermediate_tensors))

with ProfileExecuteDuration().capture_async("post process"):
@@ -1561,6 +1568,7 @@ def execute_model(
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
finished_dumping=finished_dumping
)

durations = ProfileExecuteDuration().pop_captured_sync()
@@ -2369,3 +2377,24 @@ def select_torchair_padded_batch_size(self, batch_size: int):
if batch_size <= padded_batch_size < selected_batch_size:
selected_batch_size = padded_batch_size
return selected_batch_size

@staticmethod
def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
# Update the KVConnector with this step's KVConnector metadata before forward().
if has_kv_transfer_group():
kv_connector = get_kv_transfer_group()
assert isinstance(kv_connector, KVConnectorBase_V1)
assert scheduler_output.kv_connector_metadata is not None
kv_connector.bind_connector_metadata(
scheduler_output.kv_connector_metadata)

# Background KV cache transfers happen here.
# These transfers are designed to be async and the requests
# involved may be disjoint from the running requests.
# Do this here to save a collective_rpc.
kv_connector.start_load_kv(get_forward_context())

@staticmethod
def maybe_wait_for_kv_save():
if has_kv_transfer_group():
return get_kv_transfer_group().wait_for_save()
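
Taken together with the attention hooks above, the runner-side flow per scheduler step is: bind the step's connector metadata, start the background KV loads, run the model (the per-layer hooks fire inside the forward pass), then wait for saves and collect which requests finished dumping; that result is threaded through _process_reqs into ModelRunnerOutput.finished_dumping. A condensed sketch of that sequence, with a hypothetical DemoConnector standing in for the real KV-transfer group (in this PR, wait_for_save() is expected to return an Optional[dict[str, list[str]]]):

# Sketch only: per-step orchestration added to _process_reqs / execute_model.
from typing import Optional


class DemoConnector:
    def bind_connector_metadata(self, metadata) -> None:
        self.metadata = metadata                      # this step's transfer plan

    def start_load_kv(self, forward_context) -> None:
        print("background KV loads started")          # async; may cover requests not in this batch

    def wait_for_save(self) -> Optional[dict[str, list[str]]]:
        print("KV saves flushed")
        return {"req-0": ["remote-block-0"]}          # requests that finished dumping


def run_one_step(connector, kv_connector_metadata, forward_context, run_model):
    # maybe_setup_kv_connector: done just before the forward pass.
    connector.bind_connector_metadata(kv_connector_metadata)
    connector.start_load_kv(forward_context)
    run_model()                                       # per-layer wait/save hooks fire in here
    # maybe_wait_for_kv_save: result becomes ModelRunnerOutput.finished_dumping.
    return connector.wait_for_save()


finished_dumping = run_one_step(DemoConnector(), {}, None, lambda: print("model forward"))
print(finished_dumping)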