Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions vllm_ascend/distributed/llmdatadist_c_mgr_connector.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import contextlib
import copy
import json
import math
import os
Expand All @@ -17,6 +18,7 @@
import zmq
from llm_datadist import (BlocksCacheKey, CacheDesc, LLMConfig, LLMDataDist,
LLMException, LLMRole)
from vllm import envs
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
Expand Down Expand Up @@ -184,6 +186,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: Optional[str]):
self.port = dp_rank_local * tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT if dp_rank_local is not None else tp_size + envs_ascend.VLLM_ASCEND_LLMDD_RPC_PORT

self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
self._reqs_need_send: dict[str, float] = {}

def get_num_new_matched_tokens(
self, request: "Request",
Expand Down Expand Up @@ -248,7 +251,12 @@ def build_connector_meta(
meta.add_new_req(request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params)

meta.reqs_to_send = copy.deepcopy(self._reqs_need_send)

# Clear the list once workers start the transfers
self._reqs_need_recv.clear()
self._reqs_need_send.clear()

return meta

Expand All @@ -275,6 +283,9 @@ def request_finished(
if delay_free_blocks:
logger.info("Delaying free of %d blocks for request %s",
len(computed_block_ids), request.request_id)
# Prefill request on remote. It will be read from D upon completion
self._reqs_need_send[request.request_id] = time.perf_counter(
) + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
return delay_free_blocks, dict(
do_remote_prefill=True,
do_remote_decode=False,
Expand Down Expand Up @@ -341,6 +352,7 @@ def __init__(self, vllm_config: VllmConfig):
os.environ["HCCL_DETERMINISTIC"] = "true"
self.done_receiving_counts: defaultdict[str,
set[int]] = defaultdict(set)
self.reqs_to_send: dict[str, float] = {}

def listen_for_agent_metadata_req(self, event: threading.Event):
assert self.local_agent_metadata is not None
Expand Down Expand Up @@ -379,7 +391,9 @@ def listen_for_agent_metadata_req(self, event: threading.Event):
logger.debug(
f"LLMDataDistCMgrConnectorWorker: Receiving request {finished_req_id} finished"
)
self.finished_reqs.add(finished_req_id)
if finished_req_id in self.reqs_to_send:
self.finished_reqs.add(finished_req_id)
del self.reqs_to_send[finished_req_id]
sock.send_multipart(
(identity, b"", b"receiving decode finished"))
else:
Expand Down Expand Up @@ -582,6 +596,7 @@ def handle_exception(future):

for future in futures:
future.add_done_callback(handle_exception)
self.reqs_to_send.update(metadata.reqs_to_send)

def add_remote_agent(self, metadata: LLMDataDistCMgrAgentMetadata) -> int:
assert self.local_agent_metadata is not None
Expand Down Expand Up @@ -839,8 +854,19 @@ def get_finished(
self, finished_req_ids: set[str]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
"""Get the finished recving and sending requuests."""
import copy
now = time.perf_counter()
with self.thread_lock:
while self.reqs_to_send:
req_id, expires = next(iter(self.reqs_to_send.items()))
if now < expires:
break
logger.warning(
"Some requests in prefill node fail to receive KV Cache transfer done signal. "
"If a greater mean TTFT is acceptable, you can 'export VLLM_NIXL_ABORT_REQUEST_TIMEOUT=600' (10 minutes) to relax the timeout condition. "
)
if req_id in self.reqs_to_send:
self.finished_reqs.add(req_id)
del self.reqs_to_send[req_id]
req_ids_to_ret = copy.deepcopy(self.finished_reqs)
self.finished_reqs.clear()
if self.llm_datadist_role == LLMRole.PROMPT:
Expand Down Expand Up @@ -871,4 +897,4 @@ def zmq_ctx(socket_type: Any,
yield socket
finally:
if ctx is not None:
ctx.destroy(linger=0)
ctx.destroy(linger=0)
3 changes: 2 additions & 1 deletion vllm_ascend/distributed/mooncake_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import torch
import zmq
from mooncake.engine import TransferEngine # type: ignore
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
Expand Down Expand Up @@ -100,7 +101,7 @@ def _retrieve_expired_requests(self):
while self.delayed_free_requests:
request_id, delay_start_time = self.delayed_free_requests[0]
if (current_time - delay_start_time
> envs_ascend.VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT):
> envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT):
self.delayed_free_requests.popleft()
expired_requests.add(request_id)
logger.info("Force freed request: %s", request_id)
Expand Down
7 changes: 1 addition & 6 deletions vllm_ascend/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,6 @@
# caused by the initialization of the Mooncake connector.
"PHYSICAL_DEVICES":
lambda: os.getenv("PHYSICAL_DEVICES", None),
# Timeout (in seconds) for delayed KVCache block release. In the prefill
# node, if a request is marked for delayed KV block release and the blocks
# are not freed within this timeout, they will be forcibly released.
"VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT":
lambda: int(os.getenv("VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT", 250)),
}

# end-env-vars-definition
Expand All @@ -177,4 +172,4 @@ def __getattr__(name: str):


def __dir__():
return list(env_variables.keys())
return list(env_variables.keys())
Loading