Merged
vllm_ascend/distributed/kvpool/kv_transfer.py (4 changes: 0 additions & 4 deletions)
@@ -117,7 +117,6 @@ def _handle_request(self, req_meta: dict[str, Any]):
             addr_list.append(addr)
             size_list.append(size)
         if self.dcp_size > 1:
-            torch.npu.current_stream().synchronize()
             self.m_store.put(key_list, addr_list, size_list)
         else:
             key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
@@ -126,7 +125,6 @@ def _handle_request(self, req_meta: dict[str, Any]):
             size_list_tp = size_list[self.tp_rank %
                                      self.put_step::self.put_step]
             if key_list_tp:
-                torch.npu.current_stream().synchronize()
                 self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
         if is_last_chunk:
             self.set_finished_request(req_id)
@@ -205,7 +203,6 @@ def _handle_request(  # type: ignore[override]
             addr_list.append(addr)
             size_list.append(size)
         if self.dcp_size > 1:
-            torch.npu.current_stream().synchronize()
             self.m_store.put(key_list, addr_list, size_list)
         else:
             key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
@@ -214,7 +211,6 @@ def _handle_request(  # type: ignore[override]
             size_list_tp = size_list[self.tp_rank %
                                      self.put_step::self.put_step]
             if key_list_tp:
-                torch.npu.current_stream().synchronize()
                 self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
         if req_meta.layer_id == self.final_layer_id and req_meta.is_last_chunk:
            self.set_finished_request(req_meta.req_id)
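
For illustration, the slicing in the hunks above stripes the key/address/size lists across tensor-parallel ranks so that each rank puts a disjoint subset. Below is a minimal, runnable sketch of that pattern; DummyStore and put_sharded are hypothetical stand-ins, since the real m_store API is not shown in this diff:

# Sketch of the TP-striped put pattern from the diff above.
# DummyStore is a hypothetical stub; the real m_store is not shown here.

class DummyStore:
    def put(self, keys: list, addrs: list, sizes: list) -> None:
        print(f"put keys={keys} addrs={addrs} sizes={sizes}")

def put_sharded(m_store, tp_rank: int, put_step: int,
                key_list: list, addr_list: list, size_list: list) -> None:
    # Each rank takes every put_step-th entry, offset by its rank, so the
    # ranks collectively cover the whole list without overlap.
    key_list_tp = key_list[tp_rank % put_step::put_step]
    addr_list_tp = addr_list[tp_rank % put_step::put_step]
    size_list_tp = size_list[tp_rank % put_step::put_step]
    if key_list_tp:
        # This is where the PR removed torch.npu.current_stream().synchronize();
        # per the review comment below, the save is now ordered after sampling,
        # which already synchronizes the device.
        m_store.put(key_list_tp, addr_list_tp, size_list_tp)

# With put_step=2: rank 0 puts ["k0", "k2"], rank 1 puts ["k1", "k3"].
for rank in range(2):
    put_sharded(DummyStore(), rank, 2,
                ["k0", "k1", "k2", "k3"], [0, 1, 2, 3], [16, 16, 16, 16])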
vllm_ascend/worker/model_runner_v1.py (3 changes: 1 addition & 2 deletions)
@@ -2339,7 +2339,6 @@ def execute_model(
                 attn_metadata, self.with_prefill, maybe_padded_num_tokens,
                 input_ids, positions, intermediate_tensors, inputs_embeds)

-            self.maybe_wait_for_kv_save()
             finished_sending, finished_recving = self.get_finished_kv_transfer(
                 scheduler_output)

@@ -2603,7 +2602,7 @@ def propose_draft_token_ids(sampled_token_ids):
             # ngram and other speculative decoding methods use the sampled
             # tokens on the CPU, so they are run after bookkeeping.
             propose_draft_token_ids(valid_sampled_token_ids)
-
+        self.maybe_wait_for_kv_save()
Contributor (critical):

Moving self.maybe_wait_for_kv_save() here from execute_model is a critical fix for a race condition. Previously, the KV cache save could be triggered before the model's forward pass had completed, potentially saving corrupted data. Placed here, after the sampling operations that implicitly synchronize the device, the KV cache is guaranteed to be fully populated and stable before the save begins.

A minor suggestion for future improvement: the name maybe_wait_for_kv_save is misleading, since the method appears to trigger an asynchronous save rather than wait for one. Renaming it to something like trigger_kv_save_if_needed would improve code clarity.

Contributor:
LGTM

         if has_kv_transfer_group():
             get_kv_transfer_group().clear_connector_metadata()
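
To make the ordering argument from the review concrete, here is a minimal, self-contained sketch of the sequence the fix enforces. All names (StepSketch, forward, step, save_kv) are hypothetical stand-ins, not the vllm-ascend API; on a real NPU the forward kernels launch asynchronously and the .tolist() in sampling acts as an implicit stream synchronization:

import torch

class StepSketch:
    # Hypothetical stand-in for the model runner; not the real API.
    def __init__(self) -> None:
        self.kv_cache = torch.zeros(4, 8)  # stand-in for the KV blocks

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The forward pass writes into the KV cache; on a real device
        # these writes are queued asynchronously on the current stream.
        self.kv_cache.copy_(x @ torch.ones(8, 8))
        return self.kv_cache

    def step(self, x: torch.Tensor) -> list:
        logits = self.forward(x)
        # Sampling moves token ids to the CPU; on an NPU this blocks
        # until every queued kernel, including the KV writes, finishes.
        sampled = logits.argmax(dim=-1).tolist()
        # Only now is it safe to hand the KV blocks to the transfer
        # connector, which is why the save call moved below sampling.
        self.save_kv()
        return sampled

    def save_kv(self) -> None:
        # Stand-in for maybe_wait_for_kv_save().
        print("saving KV cache:", self.kv_cache.shape)

runner = StepSketch()
print(runner.step(torch.randn(4, 8)))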
