diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 1aa34ee38602..1c54914d182b 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -1,26 +1,20 @@ # SPDX-License-Identifier: Apache-2.0 from collections import deque -from dataclasses import dataclass -from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set, - Tuple, Union) +from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.sampling_params import SamplingParams from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) from vllm.v1.core.kv_cache_manager import KVCacheManager +from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData, + SchedulerOutput) from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs from vllm.v1.metrics.stats import SchedulerStats from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus -if TYPE_CHECKING: - from vllm.multimodal import MultiModalKwargs - from vllm.multimodal.base import PlaceholderRange - logger = init_logger(__name__) @@ -600,80 +594,3 @@ def make_stats(self) -> SchedulerStats: num_waiting_reqs=len(self.waiting), gpu_cache_usage=self.kv_cache_manager.usage, ) - - -@dataclass -class NewRequestData: - - req_id: str - prompt_token_ids: List[int] - prompt: Optional[str] - mm_inputs: List["MultiModalKwargs"] - mm_hashes: List[str] - mm_positions: List["PlaceholderRange"] - sampling_params: SamplingParams - block_ids: List[int] - num_computed_tokens: int - lora_request: Optional[LoRARequest] - - @classmethod - def from_request( - cls, - request: Request, - block_ids: List[int], - num_computed_tokens: int, - ) -> "NewRequestData": - return cls( - req_id=request.request_id, - prompt_token_ids=request.prompt_token_ids, - prompt=request.prompt, - mm_inputs=request.mm_inputs, - mm_hashes=request.mm_hashes, - mm_positions=request.mm_positions, - sampling_params=request.sampling_params, - block_ids=block_ids, - num_computed_tokens=num_computed_tokens, - lora_request=request.lora_request, - ) - - -@dataclass -class CachedRequestData: - - req_id: str - # If resumed_from_preemption is False, new_block_ids will be appended to - # the request's block IDs. If True, new_block_ids will be used as the - # request's block IDs instead of appending to the existing block IDs. - resumed_from_preemption: bool - new_block_ids: List[int] - num_computed_tokens: int - - @classmethod - def from_request( - cls, - request: Request, - resumed_from_preemption: bool, - new_block_ids: List[int], - num_computed_tokens: int, - ) -> "CachedRequestData": - return cls( - req_id=request.request_id, - resumed_from_preemption=resumed_from_preemption, - new_block_ids=new_block_ids, - num_computed_tokens=num_computed_tokens, - ) - - -@dataclass -class SchedulerOutput: - - scheduled_new_reqs: List[NewRequestData] - scheduled_cached_reqs: List[CachedRequestData] - - num_scheduled_tokens: Dict[str, int] - total_num_scheduled_tokens: int - scheduled_encoder_inputs: Dict[str, List[int]] - num_common_prefix_blocks: int - - finished_req_ids: Set[str] - free_encoder_input_ids: List[Tuple[str, int]] diff --git a/vllm/v1/core/scheduler_output.py b/vllm/v1/core/scheduler_output.py new file mode 100644 index 000000000000..990b3dd0ed78 --- /dev/null +++ b/vllm/v1/core/scheduler_output.py @@ -0,0 +1,108 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple + +if TYPE_CHECKING: + from vllm.lora.request import LoRARequest + from vllm.multimodal import MultiModalKwargs + from vllm.multimodal.base import PlaceholderRange + from vllm.sampling_params import SamplingParams + from vllm.v1.request import Request + + +@dataclass +class NewRequestData: + + req_id: str + prompt_token_ids: List[int] + prompt: Optional[str] + mm_inputs: List["MultiModalKwargs"] + mm_hashes: List[str] + mm_positions: List["PlaceholderRange"] + sampling_params: "SamplingParams" + block_ids: List[int] + num_computed_tokens: int + lora_request: Optional["LoRARequest"] + + @classmethod + def from_request( + cls, + request: "Request", + block_ids: List[int], + num_computed_tokens: int, + ) -> "NewRequestData": + return cls( + req_id=request.request_id, + prompt_token_ids=request.prompt_token_ids, + prompt=request.prompt, + mm_inputs=request.mm_inputs, + mm_hashes=request.mm_hashes, + mm_positions=request.mm_positions, + sampling_params=request.sampling_params, + block_ids=block_ids, + num_computed_tokens=num_computed_tokens, + lora_request=request.lora_request, + ) + + +@dataclass +class CachedRequestData: + + req_id: str + # If resumed_from_preemption is False, new_block_ids will be appended to + # the request's block IDs. If True, new_block_ids will be used as the + # request's block IDs instead of appending to the existing block IDs. + resumed_from_preemption: bool + new_block_ids: List[int] + num_computed_tokens: int + + @classmethod + def from_request( + cls, + request: "Request", + resumed_from_preemption: bool, + new_block_ids: List[int], + num_computed_tokens: int, + ) -> "CachedRequestData": + return cls( + req_id=request.request_id, + resumed_from_preemption=resumed_from_preemption, + new_block_ids=new_block_ids, + num_computed_tokens=num_computed_tokens, + ) + + +@dataclass +class SchedulerOutput: + + # List of the requests that are scheduled for the first time. + # We cache the request's data in each worker process, so that we don't + # need to re-send it every scheduling step. + scheduled_new_reqs: List[NewRequestData] + # List of the requests that have been scheduled before. + # Since the request's data is already cached in the worker processes, + # we only send the diff to minimize the communication cost. + scheduled_cached_reqs: List[CachedRequestData] + + # req_id -> num_scheduled_tokens + # Number of tokens scheduled for each request. + num_scheduled_tokens: Dict[str, int] + # Total number of tokens scheduled for all requests. + # Equal to sum(num_scheduled_tokens.values()) + total_num_scheduled_tokens: int + # req_id -> encoder input indices that need processing. + # E.g., if a request has [0, 1], it could mean the vision encoder needs + # to process that the request's 0-th and 1-th images in the current step. + scheduled_encoder_inputs: Dict[str, List[int]] + # Number of common prefix blocks for all requests. + # This can be used for cascade attention. + num_common_prefix_blocks: int + + # Request IDs that are finished in between the previous and the current + # steps. This is used to notify the workers about the finished requests + # so that they can free the cached states for those requests. + finished_req_ids: Set[str] + # List of (req_id, encoder_input_index) tuples. + # Used to free the encoder cache. + free_encoder_input_ids: List[Tuple[str, int]] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index fdbca70bda71..9b1eab613bf7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -36,7 +36,7 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin if TYPE_CHECKING: - from vllm.v1.core.scheduler import SchedulerOutput + from vllm.v1.core.scheduler_output import SchedulerOutput logger = init_logger(__name__) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 0adb69073397..ad53f90b8665 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -18,7 +18,6 @@ from vllm.model_executor import set_random_seed from vllm.platforms import current_platform from vllm.utils import GiB_bytes -from vllm.v1.core.scheduler import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.worker.gpu_model_runner import GPUModelRunner @@ -26,7 +25,7 @@ logger = init_logger(__name__) if TYPE_CHECKING: - from vllm.v1.core.scheduler import SchedulerOutput + from vllm.v1.core.scheduler_output import SchedulerOutput class Worker: