diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index fe552db74e2f..eb91598efe8c 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -842,7 +842,6 @@ def update_from_output(
                         spec_token_ids[req_index])
                 else:
                     request.spec_token_ids = spec_token_ids[req_index]
-
             # Get prompt logprobs for this request.
             prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
             if new_token_ids or pooler_output is not None \
@@ -869,6 +868,10 @@
 
             if not stopped:
                 new_running.append(request)
+
+            if model_runner_output.finished_dumping is not None:
+                request.succeed_dumped_blocks.extend(model_runner_output.finished_dumping.get(req_id, []))
+
         self.running = new_running
 
         # KV Connector: update state for finished KV Transfers.
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index f78623f571b2..c8388baed21f 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -107,6 +107,7 @@ class ModelRunnerOutput:
     # [req_ids]
     finished_sending: Optional[set[str]] = None
     finished_recving: Optional[set[str]] = None
+    finished_dumping: Optional[dict[str, list[str]]] = None
 
     # req_id -> num_nans_in_logits
     num_nans_in_logits: Optional[dict[str, int]] = None
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index 9b96f4599f92..825b77bba1b6 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -102,7 +102,7 @@ def __init__(
         # State
         # The number of tokens with prefix cache hits.
         self.num_cached_tokens = -1
-
+        self.succeed_dumped_blocks: list[str] = []
         # The number of NaNs in logits. A value greater than 0
         # indicates that the output is corrupted
         self.num_nans_in_logits = 0
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 5a26e88db1f7..14278bb6a7ee 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1378,7 +1378,7 @@ def execute_model(
             inputs_embeds=inputs_embeds,
         )
 
-        self.maybe_wait_for_kv_save()
+        finished_dumping = self.maybe_wait_for_kv_save()
         finished_sending, finished_recving = (
             self.get_finished_kv_transfers(scheduler_output))
 
@@ -1563,6 +1563,7 @@ def execute_model(
             finished_sending=finished_sending,
             finished_recving=finished_recving,
             num_nans_in_logits=num_nans_in_logits,
+            finished_dumping=finished_dumping
         )
 
     def propose_draft_token_ids(
@@ -1719,9 +1720,9 @@ def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
             kv_connector.start_load_kv(get_forward_context())
 
     @staticmethod
-    def maybe_wait_for_kv_save() -> None:
+    def maybe_wait_for_kv_save() -> Optional[dict[str, list[str]]]:
         if has_kv_transfer_group():
-            get_kv_transfer_group().wait_for_save()
+            return get_kv_transfer_group().wait_for_save()
 
     @staticmethod
     def get_finished_kv_transfers(
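
For context, a minimal sketch of the contract this diff assumes: `maybe_wait_for_kv_save()` now propagates whatever the connector's `wait_for_save()` returns, so a connector that wants the scheduler to track dumped blocks is expected to return a `req_id -> block ids` mapping. The `InMemoryDumpConnector` class below is hypothetical and not part of this change; only the `wait_for_save() -> Optional[dict[str, list[str]]]` shape and the scheduler-side `succeed_dumped_blocks.extend(...)` mirror the diff.

```python
from typing import Optional


class InMemoryDumpConnector:
    """Hypothetical stand-in for the object returned by get_kv_transfer_group()."""

    def __init__(self) -> None:
        # req_id -> block ids whose KV dump completed since the last step.
        self._finished_dumping: dict[str, list[str]] = {}

    def record_dumped_block(self, req_id: str, block_id: str) -> None:
        self._finished_dumping.setdefault(req_id, []).append(block_id)

    def wait_for_save(self) -> Optional[dict[str, list[str]]]:
        # Would block here until in-flight dumps finish, then hand the
        # per-request block ids back to the model runner, which forwards
        # them as ModelRunnerOutput.finished_dumping.
        finished, self._finished_dumping = self._finished_dumping, {}
        return finished or None


# Scheduler side (mirrors the update_from_output() change above):
# if model_runner_output.finished_dumping is not None:
#     request.succeed_dumped_blocks.extend(
#         model_runner_output.finished_dumping.get(req_id, []))
```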