66 changes: 66 additions & 0 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
@@ -33,6 +33,7 @@
from vllm.sampling_params import SamplingParams
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import RequestStatus

from .utils import create_request, create_scheduler, create_vllm_config

@@ -1023,3 +1024,68 @@ def test_shutdown_cleans_up_resources(dist_init):
assert mock_dereg.call_count == 2
mock_dereg.assert_any_call("desc1")
mock_dereg.assert_any_call("desc2")


@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
def test_aborted_request_removed_from_worker_in_batch(dist_init):
"""
Create and schedule a request so that P adds it to in-batch tracking via
the real scheduler, then simulate an abort (request not in next scheduler
iteration) and verify the worker no longer tracks it as in-batch.
"""
vllm_config = create_vllm_config()

scheduler = create_scheduler(vllm_config)
# KVConnector worker on the prefill (P) side
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
connector.connector_worker = FakeNixlConnectorWorker(vllm_config,
connector.engine_id,
hand_shake_latency=0)

# Create a request that triggers do_remote_decode so that
# the scheduler adds it to reqs_in_batch
req = create_request(request_id=1, do_remote_decode=True, max_tokens=1)
scheduler.add_request(req)

# First scheduling pass - examine build_connector_meta output
sched_out = scheduler.schedule()
kv_meta = sched_out.kv_connector_metadata
assert kv_meta is not None
assert isinstance(kv_meta, NixlConnectorMetadata)
assert req.request_id in kv_meta.reqs_in_batch

#### Model Runner start ####
# Bind scheduler-produced metadata and start worker processing.
connector.bind_connector_metadata(kv_meta)

dummy_ctx = ForwardContext(
no_compile_layers={},
attn_metadata={},
virtual_engine=0,
)
connector.start_load_kv(dummy_ctx)

# Ensure it was tracked by the worker
assert req.request_id in connector.connector_worker._reqs_to_process

#### Model Runner end ####

# Abort the request - triggers request_finished in the connector scheduler
scheduler.finish_requests(req.request_id, RequestStatus.FINISHED_ABORTED)
# Second scheduling pass - build metadata with aborted request
sched_out2 = scheduler.schedule()
kv_meta2 = sched_out2.kv_connector_metadata
assert kv_meta2 is not None
assert isinstance(kv_meta2, NixlConnectorMetadata)
assert req.request_id not in kv_meta2.reqs_in_batch

# Bind empty/abort metadata and run worker step
#### Model Runner start ####
connector.bind_connector_metadata(kv_meta2)
connector.start_load_kv(dummy_ctx)

# After abort, the worker should not keep tracking it as "in-batch"
assert req.request_id not in connector.connector_worker._reqs_to_process
#### Model Runner end ####
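
For orientation, here is a minimal, self-contained sketch of the round trip this test exercises (simplified stand-in classes and a hypothetical request id, not the vLLM API): the scheduler-side connector reports an aborted request in reqs_not_processed, and the worker drops it from its in-batch tracking on the next start_load_kv call.

class SketchMetadata:
    """Stand-in for NixlConnectorMetadata (only the two relevant fields)."""

    def __init__(self) -> None:
        self.reqs_in_batch: set[str] = set()
        self.reqs_not_processed: set[str] = set()


class SketchWorker:
    """Stand-in for the connector worker's in-batch tracking."""

    def __init__(self) -> None:
        self._reqs_to_process: set[str] = set()

    def start_load_kv(self, meta: SketchMetadata) -> None:
        # Track requests the scheduler put in the batch...
        self._reqs_to_process |= meta.reqs_in_batch
        # ...then drop any the scheduler marked as not to be processed.
        self._reqs_to_process -= meta.reqs_not_processed


worker = SketchWorker()

# Pass 1: the request is scheduled, so the worker tracks it as in-batch.
meta1 = SketchMetadata()
meta1.reqs_in_batch.add("req-1")
worker.start_load_kv(meta1)
assert "req-1" in worker._reqs_to_process

# Pass 2: the request was aborted, so the scheduler reports it as not
# processed and the worker stops tracking it.
meta2 = SketchMetadata()
meta2.reqs_not_processed.add("req-1")
worker.start_load_kv(meta2)
assert "req-1" not in worker._reqs_to_process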
18 changes: 16 additions & 2 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -113,6 +113,7 @@ def __init__(self):
self.reqs_to_save: dict[ReqId, ReqMeta] = {}
self.reqs_to_send: dict[ReqId, float] = {}
self.reqs_in_batch: set[ReqId] = set()
self.reqs_not_processed: set[ReqId] = set()

def add_new_req(
self,
@@ -287,6 +288,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
# Reqs to send and their expiration time
self._reqs_need_send: dict[ReqId, float] = {}
self._reqs_in_batch: set[ReqId] = set()
# Reqs to remove from the processed set because they are not to be sent
# after remote prefill, or were aborted.
self._reqs_not_processed: set[ReqId] = set()

def get_num_new_matched_tokens(
self, request: "Request",
@@ -401,11 +405,13 @@ def build_connector_meta(

meta.reqs_to_send = self._reqs_need_send
meta.reqs_in_batch = self._reqs_in_batch
meta.reqs_not_processed = self._reqs_not_processed

# Clear the list once workers start the transfers
self._reqs_need_recv.clear()
self._reqs_need_save.clear()
self._reqs_in_batch = set()
self._reqs_not_processed = set()
self._reqs_need_send = {}

return meta
@@ -439,8 +445,12 @@ def request_finished(
params["do_remote_prefill"] = False
return False, None

if (not params.get("do_remote_decode")
or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
if not params.get("do_remote_decode"):
return False, None
if request.status != RequestStatus.FINISHED_LENGTH_CAPPED:
# Also covers a P/D prefill request whose blocks are freed
# immediately (e.g. on abort). Stop tracking this request.
self._reqs_not_processed.add(request.request_id)
return False, None

# TODO: check whether block_ids can actually ever be 0. If not we could
@@ -1234,6 +1244,10 @@ def start_load_kv(self, metadata: NixlConnectorMetadata):
for req_id in metadata.reqs_in_batch:
self._reqs_to_process.add(req_id)

# Remove all requests that are not to be processed (e.g. aborted).
for req_id in metadata.reqs_not_processed:
self._reqs_to_process.discard(req_id)

# Add to requests that are waiting to be read and track expiration.
for req_id, expiration_time in metadata.reqs_to_send.items():
if req_id in self._reqs_to_process:
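
The ordering inside start_load_kv matters: in-batch ids are added first, not-processed ids are discarded next, and only ids still tracked pick up an expiration entry in the send queue. A minimal sketch of that ordering, with hypothetical request ids and a made-up 30-second expiration window (not the connector's real timeout handling):

import time

# Worker-side state, mirroring _reqs_to_process and _reqs_to_send.
_reqs_to_process: set[str] = set()
_reqs_to_send: dict[str, float] = {}

# Metadata for one scheduler step: one live request, one aborted one.
reqs_in_batch = {"live-req", "aborted-req"}
reqs_not_processed = {"aborted-req"}
reqs_to_send = {"live-req": time.monotonic() + 30.0,
                "aborted-req": time.monotonic() + 30.0}

# Same order as start_load_kv: add, discard, then gate on membership.
for req_id in reqs_in_batch:
    _reqs_to_process.add(req_id)

for req_id in reqs_not_processed:
    _reqs_to_process.discard(req_id)

for req_id, expiration in reqs_to_send.items():
    if req_id in _reqs_to_process:
        _reqs_to_send[req_id] = expiration

# The aborted request never gets an expiration tracked.
assert "aborted-req" not in _reqs_to_send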