diff --git a/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch b/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch index da340e7e..2c3fd6f0 100644 --- a/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch +++ b/ucm/integration/vllm/patch/0.9.2/vllm-adapt.patch @@ -1,29 +1,30 @@ -From 555ba9e4920445381aecda262b9146342e92eeee Mon Sep 17 00:00:00 2001 -From: hek14 <1023129548@qq.com> -Date: Fri, 26 Sep 2025 09:51:36 +0800 -Subject: [PATCH] UCM adaptor +From acc0c1a279e77568d92dfc7a36dc1dbd3acdf0ca Mon Sep 17 00:00:00 2001 +From: flesher0813 <1208954694@qq.com> +Date: Fri, 17 Oct 2025 21:01:17 +0800 +Subject: [PATCH] [Feat] support aggregate and load failure --- - vllm/attention/layer.py | 45 ++++- - .../kv_transfer/kv_connector/utils.py | 113 ++++++++++++ + vllm/attention/layer.py | 45 +++- + .../kv_transfer/kv_connector/utils.py | 113 +++++++++ .../kv_transfer/kv_connector/v1/base.py | 9 + + .../kv_connector/v1/multi_connector.py | 6 + .../v1/shared_storage_connector.py | 7 +- vllm/v1/core/block_pool.py | 2 +- vllm/v1/core/kv_cache_manager.py | 11 +- - vllm/v1/core/sched/output.py | 3 + - vllm/v1/core/sched/scheduler.py | 164 +++++++++++++++++- + vllm/v1/core/sched/output.py | 5 + + vllm/v1/core/sched/scheduler.py | 217 ++++++++++++++++-- vllm/v1/core/single_type_kv_cache_manager.py | 3 + - vllm/v1/executor/multiproc_executor.py | 30 +++- - vllm/v1/outputs.py | 5 + + vllm/v1/executor/multiproc_executor.py | 30 ++- + vllm/v1/outputs.py | 7 +- vllm/v1/request.py | 2 +- vllm/v1/worker/block_table.py | 13 ++ - vllm/v1/worker/gpu_input_batch.py | 9 + - vllm/v1/worker/gpu_model_runner.py | 120 +++++++++++-- - vllm/v1/worker/gpu_worker.py | 25 ++- - 16 files changed, 524 insertions(+), 37 deletions(-) + vllm/v1/worker/gpu_input_batch.py | 14 ++ + vllm/v1/worker/gpu_model_runner.py | 104 +++++++-- + vllm/v1/worker/gpu_worker.py | 25 +- + 17 files changed, 564 insertions(+), 49 deletions(-) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py -index f0ad68b16..2acde35d8 100644 +index f0ad68b16..26cdf0445 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -2,7 +2,6 @@ @@ -247,21 +248,38 @@ index 5cbc8ca31..8556a979e 100644 + + return result_future diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py -index f80b5eba2..8891246e6 100644 +index f80b5eba2..39d8fa389 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -201,6 +201,15 @@ class KVConnectorBase_V1(ABC): """ return None, None -+ def get_block_ids_with_load_errors(self) -> Optional[set[int]]: ++ def get_block_ids_with_load_errors(self) -> set[int]: + """ + Get the set of block IDs that failed to load. + Returns: + Optional[set[int]]: A set of block IDs that encountered load errors. + Returns None if no errors occurred during load. 
+ """ -+ return None ++ return set() ++ + # ============================== + # Scheduler-side methods + # ============================== +diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +index be3c23399..c4fedb3a7 100644 +--- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py ++++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +@@ -129,6 +129,12 @@ class MultiConnector(KVConnectorBase_V1): + + return finished_sending or None, finished_recving or None + ++ def get_block_ids_with_load_errors(self) -> set[int]: ++ agg_block_ids: set[int] = set() ++ for c in self._connectors: ++ agg_block_ids |= c.get_block_ids_with_load_errors() ++ return agg_block_ids + # ============================== # Scheduler-side methods @@ -305,7 +323,7 @@ index d21f94727..1800665c7 100644 new_full_blocks = blocks[num_cached_blocks:num_full_blocks] assert len(block_hashes) >= num_cached_blocks diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py -index 6937455e7..c36a25bc5 100644 +index 6937455e7..0099f4a0f 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -3,7 +3,7 @@ @@ -348,10 +366,26 @@ index 6937455e7..c36a25bc5 100644 if new_computed_blocks is not None: new_computed_block_list = new_computed_blocks.blocks diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py -index d34f39327..141d750b3 100644 +index d34f39327..fff0eeb42 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py -@@ -155,3 +155,6 @@ class SchedulerOutput: +@@ -93,6 +93,7 @@ class CachedRequestData: + new_token_ids: list[list[int]] + new_block_ids: list[tuple[list[int], ...]] + num_computed_tokens: list[int] ++ num_output_tokens: list[int] + + @property + def num_reqs(self) -> int: +@@ -106,6 +107,7 @@ class CachedRequestData: + new_token_ids=[], + new_block_ids=[], + num_computed_tokens=[], ++ num_output_tokens=[], + ) + + +@@ -155,3 +157,6 @@ class SchedulerOutput: # KV Cache Connector metadata. kv_connector_metadata: Optional[KVConnectorMetadata] = None @@ -359,7 +393,7 @@ index d34f39327..141d750b3 100644 + # modified slots by sparse algorithm + req_sparsed_slots: dict[str, int] = None diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py -index fe552db74..6a9d4b4b9 100644 +index fe552db74..aa172e943 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -34,6 +34,8 @@ from vllm.v1.outputs import ModelRunnerOutput @@ -390,7 +424,15 @@ index fe552db74..6a9d4b4b9 100644 self.kv_event_publisher = EventPublisherFactory.create( self.kv_events_config, -@@ -201,8 +209,13 @@ class Scheduler(SchedulerInterface): +@@ -118,6 +126,7 @@ class Scheduler(SchedulerInterface): + + # KV Connector: requests in process of async KV loading or recving + self.finished_recving_kv_req_ids: set[str] = set() ++ self.failed_recving_kv_req_ids: set[str] = set() + + # Encoder-related. + # Calculate encoder cache size if applicable +@@ -201,8 +210,13 @@ class Scheduler(SchedulerInterface): # First, schedule the RUNNING requests. 
req_index = 0 @@ -404,7 +446,7 @@ index fe552db74..6a9d4b4b9 100644 num_new_tokens = (request.num_tokens_with_spec - request.num_computed_tokens) -@@ -250,7 +263,8 @@ class Scheduler(SchedulerInterface): +@@ -250,7 +264,8 @@ class Scheduler(SchedulerInterface): request, num_new_tokens, num_draft_tokens=num_draft_tokens, @@ -414,7 +456,7 @@ index fe552db74..6a9d4b4b9 100644 if new_blocks is None: # The request cannot be scheduled. # Preempt the lowest-priority request. -@@ -337,6 +351,10 @@ class Scheduler(SchedulerInterface): +@@ -337,6 +352,10 @@ class Scheduler(SchedulerInterface): break request = self.waiting.peek_request() @@ -425,7 +467,7 @@ index fe552db74..6a9d4b4b9 100644 # KVTransfer: skip request if still waiting for remote kvs. if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: -@@ -446,6 +464,7 @@ class Scheduler(SchedulerInterface): +@@ -446,6 +465,7 @@ class Scheduler(SchedulerInterface): new_computed_blocks, num_lookahead_tokens=self.num_lookahead_tokens, delay_cache_blocks=load_kv_async, @@ -433,7 +475,7 @@ index fe552db74..6a9d4b4b9 100644 ) if new_blocks is None: # The request cannot be scheduled. -@@ -559,6 +578,7 @@ class Scheduler(SchedulerInterface): +@@ -559,6 +579,7 @@ class Scheduler(SchedulerInterface): scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, scheduled_encoder_inputs=scheduled_encoder_inputs, num_common_prefix_blocks=num_common_prefix_blocks, @@ -441,7 +483,31 @@ index fe552db74..6a9d4b4b9 100644 # finished_req_ids is an existing state in the scheduler, # instead of being newly scheduled in this step. # It contains the request IDs that are finished in between -@@ -745,16 +765,31 @@ class Scheduler(SchedulerInterface): +@@ -620,6 +641,7 @@ class Scheduler(SchedulerInterface): + new_token_ids: list[list[int]] = [] + new_block_ids: list[tuple[list[int], ...]] = [] + num_computed_tokens: list[int] = [] ++ num_output_tokens: list[int] = [] + + for req in itertools.chain(running_reqs, resumed_reqs): + req_id = req.request_id +@@ -637,6 +659,7 @@ class Scheduler(SchedulerInterface): + new_token_ids.append(token_ids) + new_block_ids.append(req_to_new_block_ids[req_id]) + num_computed_tokens.append(req.num_computed_tokens) ++ num_output_tokens.append(len(req.output_token_ids)) + # Because resumed_reqs is usually empty, it is more efficient to do + # in-place appending so that we don't need to allocate a new list. + resumed_from_preemption = [False] * len(running_reqs) +@@ -648,6 +671,7 @@ class Scheduler(SchedulerInterface): + new_token_ids=new_token_ids, + new_block_ids=new_block_ids, + num_computed_tokens=num_computed_tokens, ++ num_output_tokens=num_output_tokens, + ) + + def _try_schedule_encoder_inputs( +@@ -745,16 +769,29 @@ class Scheduler(SchedulerInterface): num_scheduled_tokens = scheduler_output.num_scheduled_tokens pooler_outputs = model_runner_output.pooler_output num_nans_in_logits = model_runner_output.num_nans_in_logits @@ -451,12 +517,12 @@ index fe552db74..6a9d4b4b9 100644 outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) spec_decoding_stats: Optional[SpecDecodingStats] = None -+ recovered_req_ids = None ++ failed_kv_load_req_ids = None + if invalid_block_ids: + # These blocks contain externally computed tokens that failed to + # load. Identify affected requests and adjust their computed token + # count to trigger recomputation of the invalid blocks. 
-+ recovered_req_ids = self._handle_invalid_blocks(invalid_block_ids) ++ failed_kv_load_req_ids = self._handle_invalid_blocks(invalid_block_ids) + # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below # loop can be a performance bottleneck. We should do our best to avoid @@ -464,30 +530,27 @@ index fe552db74..6a9d4b4b9 100644 for request in self.running: req_id = request.request_id + # self.req_meta.stage == SequenceStage.PREFILL and self.req_meta.is_last_chunk -+ -+ -+ if recovered_req_ids and req_id in recovered_req_ids: ++ if failed_kv_load_req_ids and req_id in failed_kv_load_req_ids: + # Skip requests that were recovered from KV load failure + new_running.append(request) + continue num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) if num_tokens_scheduled == 0: # The request was not scheduled in this step. -@@ -792,6 +827,13 @@ class Scheduler(SchedulerInterface): +@@ -792,6 +829,12 @@ class Scheduler(SchedulerInterface): new_token_ids = generated_token_ids kv_transfer_params = None + if model_runner_output.finished_dumping is not None: + request.succeed_dumped_blocks.extend(model_runner_output.finished_dumping.get(req_id, [])) + is_prefill = request.num_output_tokens == 0 -+ is_last_chunk = (num_tokens_scheduled + request.num_computed_tokens >= request.num_prompt_tokens) -+ if is_prefill and is_last_chunk: -+ self.connector.connector.commit(request.succeed_dumped_blocks, True) ++ if is_prefill: ++ self.connector.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True) + # Append generated tokens and check for stop. Note that if # a request is still being prefilled, we expect the model runner # to return empty token ids for the request. -@@ -842,7 +884,6 @@ class Scheduler(SchedulerInterface): +@@ -842,7 +885,6 @@ class Scheduler(SchedulerInterface): spec_token_ids[req_index]) else: request.spec_token_ids = spec_token_ids[req_index] @@ -495,7 +558,7 @@ index fe552db74..6a9d4b4b9 100644 # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) if new_token_ids or pooler_output is not None \ -@@ -869,6 +910,7 @@ class Scheduler(SchedulerInterface): +@@ -869,6 +911,7 @@ class Scheduler(SchedulerInterface): if not stopped: new_running.append(request) @@ -503,7 +566,7 @@ index fe552db74..6a9d4b4b9 100644 self.running = new_running # KV Connector: update state for finished KV Transfers. -@@ -927,6 +969,8 @@ class Scheduler(SchedulerInterface): +@@ -927,6 +970,8 @@ class Scheduler(SchedulerInterface): def add_request(self, request: Request) -> None: self.waiting.add_request(request) self.requests[request.request_id] = request @@ -512,7 +575,7 @@ index fe552db74..6a9d4b4b9 100644 if self.log_stats: request.record_event(EngineCoreEventType.QUEUED) -@@ -976,6 +1020,8 @@ class Scheduler(SchedulerInterface): +@@ -976,6 +1021,8 @@ class Scheduler(SchedulerInterface): def _free_request(self, request: Request) -> Optional[dict[str, Any]]: assert request.is_finished() @@ -521,34 +584,101 @@ index fe552db74..6a9d4b4b9 100644 delay_free_blocks, kv_xfer_params = self._connector_finished(request) self.encoder_cache_manager.free(request) -@@ -1113,3 +1159,117 @@ class Scheduler(SchedulerInterface): +@@ -1078,18 +1125,31 @@ class Scheduler(SchedulerInterface): + if request.request_id not in self.finished_recving_kv_req_ids: + return False + +- # Now that the blocks are ready, actually cache them. 
+- (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id) +- num_computed_tokens = len(block_ids) * self.block_size +- # Handle the case where num request tokens less then one block. +- num_computed_tokens = min(num_computed_tokens, request.num_tokens) +- if num_computed_tokens == request.num_tokens: +- num_computed_tokens -= 1 +- # This will cache the blocks iff caching is enabled. +- self.kv_cache_manager.cache_blocks(request, num_computed_tokens) +- +- # Update the request state for scheduling. +- request.num_computed_tokens = num_computed_tokens ++ if request.request_id in self.failed_recving_kv_req_ids: ++ # Request had KV load failures; num_computed_tokens was already ++ # updated in _update_requests_with_invalid_blocks ++ if request.num_computed_tokens: ++ # Cache any valid computed tokens. ++ self.kv_cache_manager.cache_blocks(request, ++ request.num_computed_tokens) ++ else: ++ # No valid computed tokens, release allocated blocks. ++ # There may be a local cache hit on retry. ++ self.kv_cache_manager.free(request) ++ self.failed_recving_kv_req_ids.remove(request.request_id) ++ else: ++ # Now that the blocks are ready, actually cache them. ++ (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id) ++ num_computed_tokens = len(block_ids) * self.block_size ++ # Handle the case where num request tokens less then one block. ++ num_computed_tokens = min(num_computed_tokens, request.num_tokens) ++ if num_computed_tokens == request.num_tokens: ++ num_computed_tokens -= 1 ++ # This will cache the blocks iff caching is enabled. ++ self.kv_cache_manager.cache_blocks(request, num_computed_tokens) ++ ++ # Update the request state for scheduling. ++ request.num_computed_tokens = num_computed_tokens + + # Return that we are ready. + self.finished_recving_kv_req_ids.remove(request.request_id) +@@ -1113,3 +1173,132 @@ class Scheduler(SchedulerInterface): for req_id in (model_runner_output.finished_sending or ()): logger.debug("Finished sending KV transfer for request %s", req_id) self._free_blocks(self.requests[req_id]) + + def _update_requests_with_invalid_blocks( + self, requests: Iterable[Request], -+ invalid_block_ids: set[int]) -> tuple[set[Request], int, set[int]]: -+ affected_requests: set[Request] = set() -+ num_tokens_to_reschedule = 0 ++ invalid_block_ids: set[int]) -> tuple[set[str], int]: ++ """ ++ Identify and update requests affected by invalid KV cache blocks. ++ This method scans the given requests, detects those with invalid blocks ++ and adjusts their `num_computed_tokens` to the longest valid prefix. ++ For observability, it also accumulates the total number of tokens that ++ will need to be recomputed across all affected requests. ++ Args: ++ requests: The set of requests to scan for invalid blocks. ++ invalid_block_ids: IDs of invalid blocks. ++ Returns: ++ tuple: ++ - affected_req_ids (set[str]): IDs of requests impacted by ++ invalid blocks. ++ - total_affected_tokens (int): Total number of tokens that must ++ be recomputed across all affected requests (for observability). ++ """ ++ affected_req_ids: set[str] = set() ++ total_affected_tokens = 0 + # If a block is invalid and shared by multiple requests in the batch, -+ # all requests must be rescheduled, but only the first will recompute ++ # these requests must be rescheduled, but only the first will recompute + # it. This set tracks blocks already marked for recomputation. 
+ marked_invalid_block_ids: set[int] = set() + for request in requests: + is_affected = False + marked_invalid_block = False + req_id = request.request_id -+ req_block_ids = self.kv_cache_manager.get_block_ids(req_id)[0] ++ # TODO (davidb): add support for hybrid memory allocator ++ (req_block_ids, ) = self.kv_cache_manager.get_block_ids(req_id) + # We iterate only over blocks that may contain externally computed + # tokens -+ if request.num_cached_tokens > 0: -+ req_num_computed_blocks = (request.num_cached_tokens + -+ self.block_size - -+ 1) // self.block_size ++ if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: ++ # Async loading. If num_computed_tokens is set it implies we ++ # already processed some block failures for it in a prior step ++ req_num_computed_tokens = ( ++ request.num_computed_tokens if req_id ++ in self.failed_recving_kv_req_ids else len(req_block_ids) * ++ self.block_size) + else: -+ req_num_computed_blocks = len(req_block_ids) ++ # Sync loading. num_computed_tokens includes new tokens ++ req_num_computed_tokens = request.num_cached_tokens + ++ req_num_computed_blocks = (req_num_computed_tokens + ++ self.block_size - 1) // self.block_size + for idx, block_id in zip(range(req_num_computed_blocks), + req_block_ids): + @@ -562,6 +692,8 @@ index fe552db74..6a9d4b4b9 100644 + # and was already marked for recomputation. + # This means this request can still consider this block + # as computed when rescheduled. ++ # Currently this only applies to sync loading; Async ++ # loading does not yet support block sharing + continue + + marked_invalid_block_ids.add(block_id) @@ -572,23 +704,26 @@ index fe552db74..6a9d4b4b9 100644 + continue + + marked_invalid_block = True -+ num_tokens_to_reschedule += request.num_computed_tokens ++ # Truncate the computed tokens at the first failed block + request.num_computed_tokens = idx * self.block_size -+ num_tokens_to_reschedule -= request.num_computed_tokens ++ total_affected_tokens += (req_num_computed_tokens - ++ request.num_computed_tokens) + + if is_affected: + if not marked_invalid_block: + # All invalid blocks of this request are shared with + # previous requests and will be recomputed by them. + # Revert to considering only cached tokens as computed. 
-+ num_tokens_to_reschedule += (request.num_computed_tokens - -+ request.num_cached_tokens) ++ # Currently this only applies to sync loading; Async ++ # loading does not yet support block sharing ++ total_affected_tokens += (request.num_computed_tokens - ++ request.num_cached_tokens) + request.num_computed_tokens = request.num_cached_tokens + -+ affected_requests.add(request) ++ affected_req_ids.add(request.request_id) ++ ++ return (affected_req_ids, total_affected_tokens) + -+ return (affected_requests, num_tokens_to_reschedule, -+ marked_invalid_block_ids) + + def _handle_invalid_blocks(self, invalid_block_ids: set[int]) -> set[str]: + total_requests_to_reschedule = 0 @@ -598,47 +733,34 @@ index fe552db74..6a9d4b4b9 100644 + async_load_reqs = ( + req for req in self.waiting + if req.status == RequestStatus.WAITING_FOR_REMOTE_KVS) -+ (affected_requests, num_tokens_to_reschedule, -+ marked_invalid_block_ids) = ( -+ self._update_requests_with_invalid_blocks(async_load_reqs, -+ invalid_block_ids)) ++ async_affected_req_ids, num_tokens_to_reschedule = ( ++ self._update_requests_with_invalid_blocks(async_load_reqs, ++ invalid_block_ids)) + -+ total_requests_to_reschedule += len(affected_requests) ++ total_requests_to_reschedule += len(async_affected_req_ids) + total_tokens_to_reschedule += num_tokens_to_reschedule + -+ for request in affected_requests: -+ if request.num_computed_tokens: -+ # Cache any valid computed tokens. -+ self.kv_cache_manager.cache_blocks(request, -+ request.num_computed_tokens) -+ else: -+ # No valid computed tokens, release allocated blocks. -+ # There may be a local cache hit on retry. -+ self.kv_cache_manager.free(request) -+ -+ request.status = RequestStatus.WAITING -+ -+ # Remove async loaded invalid blocks already handled, -+ # as they cannot be shared with running requests. -+ invalid_block_ids.difference_update(marked_invalid_block_ids) ++ # Mark requests with async KV load failures; they will be rescheduled ++ # once loading completes ++ self.failed_recving_kv_req_ids |= async_affected_req_ids + + # --- Handle sync KV loads (running requests) --- -+ affected_requests, num_tokens_to_reschedule, _ = ( ++ sync_affected_req_ids, num_tokens_to_reschedule = ( + self._update_requests_with_invalid_blocks(self.running, + invalid_block_ids)) + -+ total_requests_to_reschedule += len(affected_requests) ++ total_requests_to_reschedule += len(sync_affected_req_ids) + total_tokens_to_reschedule += num_tokens_to_reschedule + + if total_requests_to_reschedule: -+ logger.info( ++ logger.warning( + "Recovered from KV load failure: " + "%d request(s) rescheduled (%d tokens affected).", + total_requests_to_reschedule, total_tokens_to_reschedule) + + # Return the IDs of affected running requests to skip in + # update_from_output. 
-+ return {r.request_id for r in affected_requests} ++ return sync_affected_req_ids diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index 5b4718038..28bd4618a 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py @@ -716,9 +838,18 @@ index b06b7cc80..61cd7110f 100644 def collective_rpc(self, method: Union[str, Callable], diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py -index f78623f57..c7b4100e3 100644 +index f78623f57..16af8dbce 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py +@@ -1,7 +1,7 @@ + # SPDX-License-Identifier: Apache-2.0 + # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +-from dataclasses import dataclass ++from dataclasses import dataclass, field + from typing import NamedTuple, Optional + + import torch @@ -107,6 +107,11 @@ class ModelRunnerOutput: # [req_ids] finished_sending: Optional[set[str]] = None @@ -727,7 +858,7 @@ index f78623f57..c7b4100e3 100644 + + # IDs of externally computed KV blocks that failed to load. + # Requests referencing these blocks should be rescheduled to recompute them. -+ invalid_block_ids: Optional[set[int]] = None ++ invalid_block_ids: set[int] = field(default_factory=set) # req_id -> num_nans_in_logits num_nans_in_logits: Optional[dict[str, int]] = None @@ -776,41 +907,55 @@ index 8f4e8d64c..f45e39f5c 100644 for i, block_table in enumerate(self.block_tables): block_table.add_row(block_ids[i], row_idx) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py -index 1a79d72be..0e65c98f5 100644 +index 1a79d72be..eb23c261a 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py -@@ -46,6 +46,11 @@ class CachedRequestState: - - def __post_init__(self): - self.num_prompt_tokens = len(self.prompt_token_ids) -+ # 'last_generator_offset' and 'last_gelen_last_output_token_ids' are -+ # used to allow safe rollback in case a sampled token turns out to be -+ # invalid (e.g., due to KV load errors). -+ self.last_generator_offset = 0 if self.generator else None -+ self.len_last_output_token_ids = len(self.output_token_ids) - - @property - def num_tokens(self) -> int: -@@ -201,6 +206,7 @@ class InputBatch: - # NOTE(woosuk): The indices of the requests that do not have their own - # generator should not be included in the dictionary. - self.generators: dict[int, torch.Generator] = {} -+ self.generators_last_offset: dict[int, int] = {} - - self.num_logprobs: dict[str, int] = {} - # NOTE(rob): num_prompt_logprobs only includes reqs -@@ -335,6 +341,9 @@ class InputBatch: - # do not have their own generator. 
- if request.generator is not None: - self.generators[req_index] = request.generator -+ assert (request.last_generator_offset is not None) -+ self.generators_last_offset[ -+ req_index] = request.last_generator_offset - - if sampling_params.logprobs is not None: - self.num_logprobs[req_id] = sampling_params.logprobs +@@ -96,6 +96,9 @@ class InputBatch: + pin_memory=False, + ) + self.token_ids_cpu = self.token_ids_cpu_tensor.numpy() ++ self.is_token_ids = torch.zeros( ++ (max_num_reqs, max_model_len), device="cpu", dtype=bool, pin_memory=False ++ ) + self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) + self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) + self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) +@@ -286,8 +289,14 @@ class InputBatch: + req_index, :num_prompt_tokens] = request.prompt_token_ids + start_idx = num_prompt_tokens + end_idx = start_idx + len(request.output_token_ids) ++ if request.prompt_token_ids is not None: ++ self.token_ids_cpu[req_index, :num_prompt_tokens] = request.prompt_token_ids ++ self.is_token_ids[req_index, :num_prompt_tokens] = True ++ else: ++ self.is_token_ids[req_index, :num_prompt_tokens] = False + self.token_ids_cpu[req_index, + start_idx:end_idx] = request.output_token_ids ++ self.is_token_ids[req_index, start_idx:end_idx] = True + # Number of token ids in token_ids_cpu. + # NOTE(woosuk): This may include spec decode tokens. + self.num_tokens[req_index] = request.num_tokens +@@ -472,6 +481,8 @@ class InputBatch: + tmp = self.token_ids_cpu[i1, ...].copy() + self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...] + self.token_ids_cpu[i2, ...] = tmp ++ ++ self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...] + + swap_dict_values(self.generators, i1, i2) + swap_dict_values(self.bad_words_token_ids, i1, i2) +@@ -542,6 +553,9 @@ class InputBatch: + num_tokens = self.num_tokens[last_req_index] + self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[ + last_req_index, :num_tokens] ++ self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[ ++ last_req_index, :num_tokens ++ ] + self.num_tokens[empty_index] = num_tokens + self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[ + last_req_index] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py -index 5a26e88db..17b3d1c79 100644 +index 5a26e88db..e1c9252a4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -72,6 +72,9 @@ from ..sample.logits_processor import LogitsProcessorManager @@ -831,7 +976,7 @@ index 5a26e88db..17b3d1c79 100644 self.requests.pop(req_id, None) self.encoder_cache.pop(req_id, None) # Remove the finished requests from the persistent batch. -@@ -468,13 +472,33 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -468,11 +472,14 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Update the states of the running/resumed requests. is_last_rank = get_pp_group().is_last_rank req_data = scheduler_output.scheduled_cached_reqs @@ -841,40 +986,31 @@ index 5a26e88db..17b3d1c79 100644 num_computed_tokens = req_data.num_computed_tokens[i] new_block_ids = req_data.new_block_ids[i] resumed_from_preemption = req_data.resumed_from_preemption[i] ++ num_output_tokens = req_data.num_output_tokens[i] + is_sparsed_request = req_sparsed_slots[req_id] != INVALID_SLOT # Update the cached states. -+ if (num_computed_tokens <= req_state.num_computed_tokens): -+ # The request was rescheduled after a KV load failure. 
Clear -+ # the last sampled tokens and rewind the generator state -+ len_output_token_ids = len(req_state.output_token_ids) -+ del req_state.output_token_ids[req_state. -+ len_last_output_token_ids:] -+ if req_state.generator: -+ req_state.generator.set_offset( -+ req_state.last_generator_offset) + req_state.num_computed_tokens = num_computed_tokens +@@ -492,17 +499,32 @@ class GPUModelRunner(LoRAModelRunnerMixin): + elif num_new_tokens > 0: + req_state.output_token_ids.extend( + new_token_ids[-num_new_tokens:]) ++ elif num_output_tokens < len(req_state.output_token_ids): ++ # Some output tokens were discarded due to a sync-KV-load ++ # failure. Align the cached state. ++ del req_state.output_token_ids[num_output_tokens:] ++ + req_index = self.input_batch.req_id_to_index.get(req_id) + if req_index is not None: -+ len_last_sampled = (len_output_token_ids - -+ req_state.len_last_output_token_ids) -+ end_idx = self.input_batch.num_tokens_no_spec[ -+ req_index] - len_last_sampled ++ old_end_idx = self.input_batch.num_tokens_no_spec[ ++ req_index] ++ end_idx = self.input_batch.num_prompt_tokens[ ++ req_index] + num_output_tokens + self.input_batch.num_tokens[req_index] = end_idx + self.input_batch.num_tokens_no_spec[req_index] = end_idx -+ - req_state.num_computed_tokens = num_computed_tokens ++ self.input_batch.is_token_ids[req_index, ++ end_idx:old_end_idx] = False - if not is_last_rank: -@@ -493,16 +517,22 @@ class GPUModelRunner(LoRAModelRunnerMixin): - req_state.output_token_ids.extend( - new_token_ids[-num_new_tokens:]) - -+ req_state.len_last_output_token_ids = len( -+ req_state.output_token_ids) -+ if req_state.generator: -+ req_state.last_generator_offset = ( -+ req_state.generator.get_offset()) -+ # Update the block IDs. - if not resumed_from_preemption: - # Append the new blocks to the existing block IDs. @@ -894,15 +1030,7 @@ index 5a26e88db..17b3d1c79 100644 req_index = self.input_batch.req_id_to_index.get(req_id) if req_index is None: -@@ -512,9 +542,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): - req_ids_to_add.append(req_id) - continue - -+ if req_state.generator: -+ assert (req_state.last_generator_offset is not None) -+ self.input_batch.generators_last_offset[ -+ req_index] = req_state.last_generator_offset -+ +@@ -515,6 +537,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Update the persistent batch. self.input_batch.num_computed_tokens_cpu[req_index] = ( num_computed_tokens) @@ -911,7 +1039,7 @@ index 5a26e88db..17b3d1c79 100644 self.input_batch.block_table.append_row(new_block_ids, req_index) # For the last rank, we don't need to update the token_ids_cpu -@@ -623,6 +660,19 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -623,6 +647,19 @@ class GPUModelRunner(LoRAModelRunnerMixin): if self.uses_mrope: self._calc_mrope_positions(scheduler_output) @@ -931,7 +1059,7 @@ index 5a26e88db..17b3d1c79 100644 # Get token indices. # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] -@@ -652,11 +702,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -652,11 +689,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): # block_size. 
block_table_indices = ( req_indices * block_table.max_num_blocks_per_req + @@ -945,7 +1073,7 @@ index 5a26e88db..17b3d1c79 100644 np.add( block_numbers * block_size, block_offsets, -@@ -666,9 +716,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -666,9 +703,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.query_start_loc_np[0] = 0 self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens @@ -960,7 +1088,7 @@ index 5a26e88db..17b3d1c79 100644 # Copy the tensors to the GPU. self.input_ids[:total_num_scheduled_tokens].copy_( -@@ -680,6 +732,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -680,6 +719,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): non_blocking=True) else: # Common case (1D positions) @@ -969,7 +1097,7 @@ index 5a26e88db..17b3d1c79 100644 self.positions[:total_num_scheduled_tokens].copy_( self.positions_cpu[:total_num_scheduled_tokens], non_blocking=True) -@@ -1370,6 +1424,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -1370,6 +1411,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): skip_cuda_graphs=skip_cuda_graphs, ): self.maybe_setup_kv_connector(scheduler_output) @@ -977,7 +1105,7 @@ index 5a26e88db..17b3d1c79 100644 model_output = self.model( input_ids=input_ids, -@@ -1378,9 +1433,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -1378,9 +1420,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): inputs_embeds=inputs_embeds, ) @@ -991,17 +1119,7 @@ index 5a26e88db..17b3d1c79 100644 if self.use_aux_hidden_state_outputs: hidden_states, aux_hidden_states = model_output -@@ -1474,7 +1532,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): - # This relies on cuda-specific torch-internal impl details - generator = self.input_batch.generators.get(i) - if generator is not None: -- generator.set_offset(generator.get_offset() - 4) -+ generator.set_offset( -+ self.input_batch.generators_last_offset.get(i)) - # Record the index of the request that should not be sampled, - # so that we could clear the sampled tokens before returning. 
- discard_sampled_tokens_req_indices.append(i) -@@ -1563,6 +1622,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -1563,6 +1608,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): finished_sending=finished_sending, finished_recving=finished_recving, num_nans_in_logits=num_nans_in_logits, @@ -1010,7 +1128,7 @@ index 5a26e88db..17b3d1c79 100644 ) def propose_draft_token_ids( -@@ -1693,13 +1754,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -1693,13 +1740,16 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.maybe_setup_kv_connector(scheduler_output) finished_sending, finished_recving = ( self.get_finished_kv_transfers(scheduler_output)) @@ -1028,7 +1146,7 @@ index 5a26e88db..17b3d1c79 100644 return output @staticmethod -@@ -1719,9 +1783,28 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -1719,9 +1769,28 @@ class GPUModelRunner(LoRAModelRunnerMixin): kv_connector.start_load_kv(get_forward_context()) @staticmethod @@ -1059,7 +1177,7 @@ index 5a26e88db..17b3d1c79 100644 @staticmethod def get_finished_kv_transfers( -@@ -1732,6 +1815,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): +@@ -1732,6 +1801,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): scheduler_output.finished_req_ids) return None, None @@ -1072,7 +1190,7 @@ index 5a26e88db..17b3d1c79 100644 self, sampled_token_ids: list[list[int]], diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py -index 9e7e44d06..d52a49a2e 100644 +index 9e7e44d06..d9666d102 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -1,6 +1,7 @@ @@ -1141,19 +1259,4 @@ index 9e7e44d06..d52a49a2e 100644 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): -- 2.50.1.windows.1 -diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py -index 6a9d4b4b9..ae06cf9eb 100644 ---- a/vllm/v1/core/sched/scheduler.py -+++ b/vllm/v1/core/sched/scheduler.py -@@ -830,9 +830,8 @@ class Scheduler(SchedulerInterface): - if model_runner_output.finished_dumping is not None: - request.succeed_dumped_blocks.extend(model_runner_output.finished_dumping.get(req_id, [])) - is_prefill = request.num_output_tokens == 0 -- is_last_chunk = (num_tokens_scheduled + request.num_computed_tokens >= request.num_prompt_tokens) -- if is_prefill and is_last_chunk: -- self.connector.connector.commit(request.succeed_dumped_blocks, True) -+ if is_prefill: -+ self.connector.connector.commit(model_runner_output.finished_dumping.get(req_id, []), True) - - # Append generated tokens and check for stop. Note that if - # a request is still being prefilled, we expect the model runner + diff --git a/ucm/integration/vllm/uc_connector.py b/ucm/integration/vllm/uc_connector.py index af56cf7b..4fe0bd0d 100644 --- a/ucm/integration/vllm/uc_connector.py +++ b/ucm/integration/vllm/uc_connector.py @@ -489,6 +489,7 @@ def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: done_recving: set[str] = set() for req_id, tasks in self._need_load_reqs.items(): if req_id in self._load_failed_reqs: + done_recving.add(req_id) continue unfinished_tasks = [] for task in tasks: @@ -509,9 +510,10 @@ def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: ) self._load_failed_reqs.add(req_id) break - if not unfinished_tasks: - done_recving.add(req_id) - self._need_load_reqs[req_id] = unfinished_tasks + if unfinished_tasks: + self._need_load_reqs[req_id] = unfinished_tasks + continue + done_recving.add(req_id) # remove the finished requests for req_id in list(done_recving):
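
Note for reviewers (not part of the patch): a minimal, self-contained sketch of the recovery rule that the new scheduler-side code (`_update_requests_with_invalid_blocks` / `_handle_invalid_blocks`) applies when the model runner reports `invalid_block_ids`. A request's `num_computed_tokens` is truncated to the longest prefix of its block table that contains no failed block, and the remainder is rescheduled for recomputation. The helper name and simplified arguments below are illustrative stand-ins, not the patched vLLM APIs.

    # Illustrative sketch only: truncate a request's computed-token count at the
    # first KV cache block that failed to load, mirroring the rule used in
    # _update_requests_with_invalid_blocks above.
    def longest_valid_prefix_tokens(req_block_ids: list[int],
                                    num_computed_tokens: int,
                                    invalid_block_ids: set[int],
                                    block_size: int) -> int:
        """Return num_computed_tokens truncated at the first invalid block."""
        num_computed_blocks = (num_computed_tokens + block_size - 1) // block_size
        for idx, block_id in enumerate(req_block_ids[:num_computed_blocks]):
            if block_id in invalid_block_ids:
                # Everything from this block onward must be recomputed.
                return idx * block_size
        return num_computed_tokens

    # Example: with block_size=16 and 5 computed blocks (80 tokens), a load
    # failure on the 3rd block (index 2) leaves 2 * 16 = 32 valid computed
    # tokens; the request is rescheduled to recompute the rest.
    assert longest_valid_prefix_tokens([10, 11, 12, 13, 14], 80, {12}, 16) == 32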