7 changes: 5 additions & 2 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -29,6 +29,7 @@
     NixlConnectorMetadata,
     NixlConnectorWorker,
     NixlKVConnectorStats,
+    ReqState,
 )
 from vllm.distributed.kv_transfer.kv_transfer_state import (
     ensure_kv_transfer_shutdown,
@@ -1245,7 +1246,8 @@ def test_aborted_request_removed_from_worker_in_batch(dist_init):
     kv_meta = sched_out.kv_connector_metadata
     assert kv_meta is not None
     assert isinstance(kv_meta, NixlConnectorMetadata)
-    assert req.request_id in kv_meta.reqs_in_batch
+    assert req.request_id in kv_meta.reqs_to_send
+    assert kv_meta.reqs_to_send[req.request_id] == ReqState.SCHEDULED

     #### Model Runner start ####
     # Bind scheduler-produced metadata and start worker processing.
@@ -1270,7 +1272,8 @@ def test_aborted_request_removed_from_worker_in_batch(dist_init):
     kv_meta2 = sched_out2.kv_connector_metadata
     assert kv_meta2 is not None
     assert isinstance(kv_meta2, NixlConnectorMetadata)
-    assert req.request_id not in kv_meta2.reqs_in_batch
+    assert req.request_id in kv_meta2.reqs_to_send
+    assert kv_meta2.reqs_to_send[req.request_id] == ReqState.ABORTED

     # Bind empty/abort metadata and run worker step
     #### Model Runner start ####
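Reviewer note: the updated test pins down the new contract — a request scheduled for remote decode first surfaces in `reqs_to_send` as `SCHEDULED`, and when the same request is aborted on a later scheduler step it surfaces again as `ABORTED` under the same request id. A minimal sketch of just that metadata shape, with the scheduler/model-runner steps elided (`step1`/`step2` and the request id are illustrative):

```python
# Minimal sketch of the contract the test asserts; only the metadata
# shape is shown, the scheduler and worker plumbing is elided.
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
    NixlConnectorMetadata,
    ReqState,
)

step1 = NixlConnectorMetadata()
step1.reqs_to_send["req-0"] = ReqState.SCHEDULED  # request enters the batch

step2 = NixlConnectorMetadata()
step2.reqs_to_send["req-0"] = ReqState.ABORTED  # aborted on the next step

assert step1.reqs_to_send["req-0"] == ReqState.SCHEDULED
assert step2.reqs_to_send["req-0"] == ReqState.ABORTED
```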
70 changes: 31 additions & 39 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import contextlib
 import copy
+import enum
 import logging
 import math
 import os
@@ -113,13 +114,17 @@ class ReqMeta:
     tp_size: int


+class ReqState(enum.Enum):
+    SCHEDULED = 1
+    FINISHED = 2
+    ABORTED = 3
+
+
 class NixlConnectorMetadata(KVConnectorMetadata):
     def __init__(self):
         self.reqs_to_recv: dict[ReqId, ReqMeta] = {}
         self.reqs_to_save: dict[ReqId, ReqMeta] = {}
-        self.reqs_to_send: dict[ReqId, float] = {}
-        self.reqs_in_batch: set[ReqId] = set()
-        self.reqs_not_processed: set[ReqId] = set()
+        self.reqs_to_send: dict[ReqId, ReqState] = {}

     def add_new_req(
         self,
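Reviewer note: the three scheduler-to-worker channels (`reqs_to_send: dict[ReqId, float]`, `reqs_in_batch: set[ReqId]`, `reqs_not_processed: set[ReqId]`) collapse into a single `dict[ReqId, ReqState]`, and the expiration timestamp no longer travels in the metadata at all — it is computed on the worker side in `start_load_kv` (see below). A standalone sketch of the mapping, with the names redeclared locally for illustration rather than imported:

```python
# Standalone sketch (local redeclarations, not the vLLM classes) of how
# one state dict subsumes the three previous structures.
import enum

ReqId = str  # alias, as in nixl_connector.py

class ReqState(enum.Enum):
    SCHEDULED = 1  # was: req_id in reqs_in_batch
    FINISHED = 2   # was: reqs_to_send[req_id] = <expiry timestamp>
    ABORTED = 3    # was: req_id in reqs_not_processed

reqs_to_send: dict[ReqId, ReqState] = {}
reqs_to_send["req-0"] = ReqState.SCHEDULED  # scheduled for remote decode
reqs_to_send["req-0"] = ReqState.FINISHED   # prefill done, waiting on D
```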
@@ -306,12 +311,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # the scheduler. Used to make metadata passed to Worker.
         self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
         self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}
-        # Reqs to send and their expiration time
-        self._reqs_need_send: dict[ReqId, float] = {}
-        self._reqs_in_batch: set[ReqId] = set()
-        # Reqs to remove from processed set because they're not to send after
-        # remote prefill or aborted.
-        self._reqs_not_processed: set[ReqId] = set()
+        # Reqs to send state updates
+        self._reqs_need_send: dict[ReqId, ReqState] = {}

     def get_num_new_matched_tokens(
         self, request: "Request", num_computed_tokens: int
@@ -364,7 +365,7 @@ def update_state_after_alloc(
             return

         if params.get("do_remote_decode"):
-            self._reqs_in_batch.add(request.request_id)
+            self._reqs_need_send[request.request_id] = ReqState.SCHEDULED
         if self.use_host_buffer and params.get("do_remote_decode"):
             # NOTE: when accelerator is not directly supported by Nixl,
             # prefilled blocks need to be saved to host memory before transfer.
@@ -437,14 +438,10 @@ def build_connector_meta(
             )

         meta.reqs_to_send = self._reqs_need_send
-        meta.reqs_in_batch = self._reqs_in_batch
-        meta.reqs_not_processed = self._reqs_not_processed

         # Clear the list once workers start the transfers
         self._reqs_need_recv.clear()
         self._reqs_need_save.clear()
-        self._reqs_in_batch = set()
-        self._reqs_not_processed = set()
         self._reqs_need_send = {}

         return meta
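Reviewer note: `meta.reqs_to_send = self._reqs_need_send` hands the dict to the worker by reference, which is why the scheduler rebinds a fresh dict (`self._reqs_need_send = {}`) rather than calling `.clear()` — clearing in place would empty the metadata that was just handed off. A tiny stdlib-only illustration of that distinction:

```python
# Why rebinding matters: clearing in place would also empty the handed-off
# metadata, since both names refer to the same dict object.
state = {"req-0": "FINISHED"}
meta_reqs = state  # handed off by reference
state = {}         # rebind: meta_reqs still holds the old contents
assert meta_reqs == {"req-0": "FINISHED"}

state2 = {"req-0": "FINISHED"}
meta_reqs2 = state2
state2.clear()     # in place: meta_reqs2 is now empty too
assert meta_reqs2 == {}
```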
@@ -487,24 +484,15 @@ def request_finished(
         if request.status != RequestStatus.FINISHED_LENGTH_CAPPED:
             # Also include the case of a P/D Prefill request with immediate
             # block free (eg abort). Stop tracking this request.
-            self._reqs_not_processed.add(request.request_id)
+            self._reqs_need_send[request.request_id] = ReqState.ABORTED
             return False, None

         # TODO: check whether block_ids actually ever be 0. If not we could
         # remove the conditional below
         delay_free_blocks = len(block_ids) > 0

         if delay_free_blocks:
-            # Prefill request on remote. It will be read from D upon completion
-            logger.debug(
-                "NIXLConnector request_finished(%s) waiting for %d seconds "
-                "for remote decode to fetch blocks",
-                request.request_id,
-                envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT,
-            )
-            self._reqs_need_send[request.request_id] = (
-                time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
-            )
+            self._reqs_need_send[request.request_id] = ReqState.FINISHED

         return delay_free_blocks, dict(
             do_remote_prefill=True,
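With the enum, `request_finished` reduces to two transitions: a request that finishes for any reason other than hitting the prefill length cap (e.g. an abort) is reported as `ABORTED`, while a completed remote prefill that still holds blocks is reported as `FINISHED` so the worker can arm the fetch timeout itself. A condensed, hedged sketch of that control flow — `status_is_length_capped` and the boolean return are simplified stand-ins for the real `Request`/status objects:

```python
# Condensed sketch of the scheduler-side transitions; the arguments are
# simplified stand-ins for the Request/params objects in the real method.
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ReqState

def request_finished_sketch(
    req_id: str,
    status_is_length_capped: bool,
    block_ids: list[int],
    reqs_need_send: dict[str, ReqState],
) -> bool:
    if not status_is_length_capped:
        # Covers aborts and other early frees: tell the worker to stop tracking.
        reqs_need_send[req_id] = ReqState.ABORTED
        return False  # blocks can be freed immediately
    if len(block_ids) > 0:
        # Remote prefill done; D still has to read the blocks, so delay
        # freeing them and let the worker arm the abort timeout.
        reqs_need_send[req_id] = ReqState.FINISHED
        return True
    return False
```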
@@ -1436,7 +1424,7 @@ def get_finished(self) -> tuple[set[str], set[str]]:
             self.sync_recved_kv_to_device(req_id, meta)

         # Handle timeout to avoid stranding blocks on remote.
-        now = time.perf_counter()
+        now = time.monotonic()
         while self._reqs_to_send:
             req_id, expires = next(iter(self._reqs_to_send.items()))
             # Sorted dict, oldest requests are put first so we can exit early.
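The clock switch matters because the deadlines are now armed in `start_load_kv` with `time.monotonic()` (see below), and the sweep in `get_finished` must read the same clock for the comparison to mean anything; a monotonic clock is also the right choice for timeouts since it cannot jump on wall-clock adjustments. The sweep relies on dict insertion order: deadlines are appended in arming order, so the oldest entry is always first and the scan can stop at the first unexpired one. A stdlib-only sketch of the pattern:

```python
import time

# Sketch of the early-exit expiry sweep: dicts preserve insertion order
# and deadlines are armed in increasing order, so the oldest entry is
# always first and the scan stops at the first one still in the future.
reqs_to_send: dict[str, float] = {
    "req-0": time.monotonic() - 1.0,    # already expired
    "req-1": time.monotonic() + 120.0,  # still waiting on remote decode
}

now = time.monotonic()
while reqs_to_send:
    req_id, expires = next(iter(reqs_to_send.items()))
    if now < expires:
        break  # every later entry expires even later
    del reqs_to_send[req_id]  # the real code aborts the transfer here
print(sorted(reqs_to_send))  # ['req-1']
```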
@@ -1575,19 +1563,23 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
         # which blocks are read from D. As P can now more easily lag behind D
         # while processing the next batch, we make sure to only set an
         # expiration for requests that have not been read from D yet.
-        for req_id in metadata.reqs_in_batch:
-            self._reqs_to_process.add(req_id)
-
-        # Remove all requests that are not to be processed (eg aborted).
-        for req_id in metadata.reqs_not_processed:
-            self._reqs_to_process.discard(req_id)
-            # We should never get an abort after setting an expiry timer
-            assert req_id not in self._reqs_to_send
-
-        # Add to requests that are waiting to be read and track expiration.
-        for req_id, expiration_time in metadata.reqs_to_send.items():
-            if req_id in self._reqs_to_process:
-                self._reqs_to_send[req_id] = expiration_time
+        for req_id, req_state in metadata.reqs_to_send.items():
+            if req_state == ReqState.SCHEDULED:
+                self._reqs_to_process.add(req_id)
+            elif req_state == ReqState.ABORTED:
+                # Drop requests that are not to be processed (eg aborted).
+                self._reqs_to_process.discard(req_id)
+                # We should never get an abort after setting an expiry timer
+                assert req_id not in self._reqs_to_send
+            elif req_state == ReqState.FINISHED and req_id in self._reqs_to_process:
+                # Request is waiting to be read by D; track its expiration.
+                abort_timeout = envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT
+                logger.debug(
+                    "req %s: waiting %d seconds for remote decode to fetch blocks",
+                    req_id,
+                    abort_timeout,
+                )
+                self._reqs_to_send[req_id] = time.monotonic() + abort_timeout

     def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
         logger.debug(
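On the worker side, `start_load_kv` now replays the per-request state updates instead of consuming three separate collections. The subtle part is the guard on `FINISHED`: a timeout is armed only if the request is still in `_reqs_to_process`, i.e. has not already been fully read by D, and an `ABORTED` update must never arrive after a timer was armed (hence the assert). A minimal simulation with plain containers — the timeout constant is a stand-in for `envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT`, and `apply_update` is an illustrative name, not the worker API:

```python
import time
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import ReqState

ABORT_TIMEOUT = 120.0  # stand-in for envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT

reqs_to_process: set[str] = set()    # seen by the worker, not yet read by D
reqs_to_send: dict[str, float] = {}  # armed expiry deadlines

def apply_update(req_id: str, state: ReqState) -> None:
    """Replay one scheduler-side state update, as start_load_kv does."""
    if state == ReqState.SCHEDULED:
        reqs_to_process.add(req_id)
    elif state == ReqState.ABORTED:
        reqs_to_process.discard(req_id)
        # An abort must precede any armed timer.
        assert req_id not in reqs_to_send
    elif state == ReqState.FINISHED and req_id in reqs_to_process:
        reqs_to_send[req_id] = time.monotonic() + ABORT_TIMEOUT

apply_update("req-0", ReqState.SCHEDULED)
apply_update("req-0", ReqState.FINISHED)  # arms the fetch timeout
apply_update("req-1", ReqState.SCHEDULED)
apply_update("req-1", ReqState.ABORTED)   # never armed, safely discarded
```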