diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index bb77c4f2b62a..53539397cd11 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -29,6 +29,7 @@ NixlConnectorMetadata, NixlConnectorWorker, NixlKVConnectorStats, + ReqState, ) from vllm.distributed.kv_transfer.kv_transfer_state import ( ensure_kv_transfer_shutdown, @@ -1245,7 +1246,8 @@ def test_aborted_request_removed_from_worker_in_batch(dist_init): kv_meta = sched_out.kv_connector_metadata assert kv_meta is not None assert isinstance(kv_meta, NixlConnectorMetadata) - assert req.request_id in kv_meta.reqs_in_batch + assert req.request_id in kv_meta.reqs_to_send + assert kv_meta.reqs_to_send[req.request_id] == ReqState.SCHEDULED #### Model Runner start #### # Bind scheduler-produced metadata and start worker processing. @@ -1270,7 +1272,8 @@ def test_aborted_request_removed_from_worker_in_batch(dist_init): kv_meta2 = sched_out2.kv_connector_metadata assert kv_meta2 is not None assert isinstance(kv_meta2, NixlConnectorMetadata) - assert req.request_id not in kv_meta2.reqs_in_batch + assert req.request_id in kv_meta2.reqs_to_send + assert kv_meta2.reqs_to_send[req.request_id] == ReqState.ABORTED # Bind empty/abort metadata and run worker step #### Model Runner start #### diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index ae7144cf7847..b60fd46b3d32 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import contextlib import copy +import enum import logging import math import os @@ -113,13 +114,17 @@ class ReqMeta: tp_size: int +class ReqState(enum.Enum): + SCHEDULED = 1 + FINISHED = 2 + ABORTED = 3 + + class NixlConnectorMetadata(KVConnectorMetadata): def __init__(self): self.reqs_to_recv: dict[ReqId, ReqMeta] = {} self.reqs_to_save: dict[ReqId, ReqMeta] = {} - self.reqs_to_send: dict[ReqId, float] = {} - self.reqs_in_batch: set[ReqId] = set() - self.reqs_not_processed: set[ReqId] = set() + self.reqs_to_send: dict[ReqId, ReqState] = {} def add_new_req( self, @@ -306,12 +311,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str): # the scheduler. Used to make metadata passed to Worker. self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {} - # Reqs to send and their expiration time - self._reqs_need_send: dict[ReqId, float] = {} - self._reqs_in_batch: set[ReqId] = set() - # Reqs to remove from processed set because they're not to send after - # remote prefill or aborted. - self._reqs_not_processed: set[ReqId] = set() + # Reqs to send state updates + self._reqs_need_send: dict[ReqId, ReqState] = {} def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int @@ -364,7 +365,7 @@ def update_state_after_alloc( return if params.get("do_remote_decode"): - self._reqs_in_batch.add(request.request_id) + self._reqs_need_send[request.request_id] = ReqState.SCHEDULED if self.use_host_buffer and params.get("do_remote_decode"): # NOTE: when accelerator is not directly supported by Nixl, # prefilled blocks need to be saved to host memory before transfer. @@ -437,14 +438,10 @@ def build_connector_meta( ) meta.reqs_to_send = self._reqs_need_send - meta.reqs_in_batch = self._reqs_in_batch - meta.reqs_not_processed = self._reqs_not_processed # Clear the list once workers start the transfers self._reqs_need_recv.clear() self._reqs_need_save.clear() - self._reqs_in_batch = set() - self._reqs_not_processed = set() self._reqs_need_send = {} return meta @@ -487,24 +484,15 @@ def request_finished( if request.status != RequestStatus.FINISHED_LENGTH_CAPPED: # Also include the case of a P/D Prefill request with immediate # block free (eg abort). Stop tracking this request. - self._reqs_not_processed.add(request.request_id) + self._reqs_need_send[request.request_id] = ReqState.ABORTED return False, None # TODO: check whether block_ids actually ever be 0. If not we could # remove the conditional below delay_free_blocks = len(block_ids) > 0 - if delay_free_blocks: # Prefill request on remote. It will be read from D upon completion - logger.debug( - "NIXLConnector request_finished(%s) waiting for %d seconds " - "for remote decode to fetch blocks", - request.request_id, - envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT, - ) - self._reqs_need_send[request.request_id] = ( - time.perf_counter() + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT - ) + self._reqs_need_send[request.request_id] = ReqState.FINISHED return delay_free_blocks, dict( do_remote_prefill=True, @@ -1436,7 +1424,7 @@ def get_finished(self) -> tuple[set[str], set[str]]: self.sync_recved_kv_to_device(req_id, meta) # Handle timeout to avoid stranding blocks on remote. - now = time.perf_counter() + now = time.monotonic() while self._reqs_to_send: req_id, expires = next(iter(self._reqs_to_send.items())) # Sorted dict, oldest requests are put first so we can exit early. @@ -1575,19 +1563,23 @@ def start_load_kv(self, metadata: NixlConnectorMetadata): # which blocks are read from D. As P can now more easily lag behind D # while processing the next batch, we make sure to only set an # expiration for requests that have not been read from D yet. - for req_id in metadata.reqs_in_batch: - self._reqs_to_process.add(req_id) - - # Remove all requests that are not to be processed (eg aborted). - for req_id in metadata.reqs_not_processed: - self._reqs_to_process.discard(req_id) - # We should never get an abort after setting an expiry timer - assert req_id not in self._reqs_to_send - - # Add to requests that are waiting to be read and track expiration. - for req_id, expiration_time in metadata.reqs_to_send.items(): - if req_id in self._reqs_to_process: - self._reqs_to_send[req_id] = expiration_time + for req_id, req_state in metadata.reqs_to_send.items(): + if req_state == ReqState.SCHEDULED: + self._reqs_to_process.add(req_id) + elif req_state == ReqState.ABORTED: + # Remove all requests that are not to be processed (eg aborted). + self._reqs_to_process.discard(req_id) + # We should never get an abort after setting an expiry timer + assert req_id not in self._reqs_to_send + elif req_state == ReqState.FINISHED and req_id in self._reqs_to_process: + # Add to requests that are waiting to be read and track expiration. + abort_timeout = envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT + logger.debug( + "req %s : waiting %d seconds for remote decode to fetch blocks", + req_id, + envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT, + ) + self._reqs_to_send[req_id] = time.monotonic() + abort_timeout def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): logger.debug(