diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 7d7cd0c94dd0..d4887937394b 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -125,3 +125,14 @@ class ModelRunnerOutput: prompt_logprobs_dict={}, pooler_output=[], num_nans_in_logits=None) + +EMPTY_MODEL_RUNNER_WITH_KVC_OUTPUT = ModelRunnerOutput( + req_ids=[], + req_id_to_index={}, + sampled_token_ids=[], + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[], + num_nans_in_logits=None, + kv_connector_output=KVConnectorOutput()) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 7fca245c1bef..084d33af0527 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -27,7 +27,8 @@ from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec -from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_WITH_KVC_OUTPUT, + ModelRunnerOutput) from vllm.v1.utils import report_usage_stats from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.worker_base import WorkerBase @@ -377,9 +378,9 @@ def execute_model( # kv_connector_output if (not kv_connector_output.finished_sending and not kv_connector_output.finished_recving): - return EMPTY_MODEL_RUNNER_OUTPUT + return EMPTY_MODEL_RUNNER_WITH_KVC_OUTPUT - output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) + output = copy.copy(EMPTY_MODEL_RUNNER_WITH_KVC_OUTPUT) output.kv_connector_output = kv_connector_output return output diff --git a/vllm/v1/worker/kv_connector_model_runner_mixin.py b/vllm/v1/worker/kv_connector_model_runner_mixin.py index a03ebe35d8e0..6ccd85e492d7 100644 --- a/vllm/v1/worker/kv_connector_model_runner_mixin.py +++ b/vllm/v1/worker/kv_connector_model_runner_mixin.py @@ -14,8 +14,8 @@ from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase from vllm.forward_context import get_forward_context, set_forward_context from vllm.logger import init_logger -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput, - ModelRunnerOutput) +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_WITH_KVC_OUTPUT, + KVConnectorOutput, ModelRunnerOutput) if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput @@ -68,9 +68,9 @@ def kv_connector_no_forward(scheduler_output: "SchedulerOutput", if (not kv_connector_output.finished_sending and not kv_connector_output.finished_recving): - return EMPTY_MODEL_RUNNER_OUTPUT + return EMPTY_MODEL_RUNNER_WITH_KVC_OUTPUT - output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) + output = copy.copy(EMPTY_MODEL_RUNNER_WITH_KVC_OUTPUT) output.kv_connector_output = kv_connector_output return output