Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,17 @@ def get_num_new_matched_tokens(
rounded_num_prompt_tokens = round_down(
len(request.prompt_token_ids), self.block_size)
count = max(rounded_num_prompt_tokens - num_computed_tokens, 0)
return count, count > 0
if count > 0:
return count, True

# NOTE: if count is 0 here, we have less than block_size
# tokens to pull after subtracting the local prefix cache hit.
# The remote only sends fully computed blocks, so there is
# nothing to transfer but we still need to notify the
# prefill worker so that the remote blocks are freed.
if all(p in params for p in ("remote_engine_id", "remote_host",
"remote_port")):
self._reqs_need_recv[request.request_id] = (request, [])

# No remote prefill for this request.
return 0, False
Expand All @@ -224,10 +234,6 @@ def update_state_after_alloc(self, request: "Request",
num_external_tokens, params)

if params is not None and params.get("do_remote_prefill"):
# NOTE(rob): if prompt < block_size, no remote blocks
# since the remote only sends fully computed blocks, so
# skip recving for this request. num_external_tokens
# should be 0 if there are no remote blocks.
if params.get("remote_block_ids"):
if all(p in params for p in ("remote_engine_id", "remote_host",
"remote_port")):
Expand Down
29 changes: 18 additions & 11 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,32 +345,38 @@ def schedule(self) -> SchedulerOutput:
skipped_waiting_requests.appendleft(request)
continue

num_external_computed_tokens = 0
load_kv_async = False

# Get already-cached tokens.
if num_prealloc_computed_tokens == 0:
new_computed_blocks, num_native_computed_tokens = \
self.kv_cache_manager.get_computed_blocks(
request)

# Get externally-cached tokens if using a KVConnector.
if self.connector is not None:
num_external_computed_tokens, load_kv_async = (
self.connector.get_num_new_matched_tokens(
request, num_native_computed_tokens))

# Total computed tokens (local + external).
num_computed_tokens = (num_native_computed_tokens +
num_external_computed_tokens)
else:
# P/D: skip checking prefix cache if loaded from remote kvs.
new_computed_blocks = KVCacheBlocks.create_empty()
num_native_computed_tokens = 0

# Get externally-cached tokens if using a KVConnector.
num_external_computed_tokens, load_kv_async = (
(0, False) if self.connector is None else
self.connector.get_num_new_matched_tokens(
request, num_native_computed_tokens))

# Total computed tokens (local + external).
num_computed_tokens = (num_native_computed_tokens +
num_external_computed_tokens +
num_prealloc_computed_tokens)
# Total computed tokens (allocated in prior step).
num_computed_tokens = num_prealloc_computed_tokens

encoder_inputs_to_schedule = None
new_encoder_budget = encoder_budget

# P/D: loading remote KV, do not allocate for new work.
if load_kv_async:
assert num_external_computed_tokens > 0
num_new_tokens = 0
# Number of tokens to be scheduled.
else:
Expand Down Expand Up @@ -411,7 +417,8 @@ def schedule(self) -> SchedulerOutput:
# KVConnector: update internal state after allocation.
# This information is used to determine if a load is
# needed for this request.
if self.connector is not None:
if num_external_computed_tokens:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is right. We should let the connector decide what to if num_external_computed_tokens=0

For instance, this will cause a memory leak on the P worker if the D worker has a full prefix cache hit.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @robertgshaw2-redhat I considered this and thought it was ok but got P/D mixed up and you're right.

It will mean we need to rethink some things w.r.t. the multi-connector impl though since we currently cycle through the connectors in get_num_new_matched_tokens until the first one that returns nonzero.

I think this could be addressed by handling that case for nixl in get_num_new_matched_tokens itself, I'll add a change for that.

assert self.connector is not None
self.connector.update_state_after_alloc(
request,
new_computed_blocks + new_blocks,
Expand Down