
Commit a04b109

WoosukKwon authored and kwang1012 committed
[V1][Minor] Move scheduler outputs to a separate file (vllm-project#13062)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent 8260f58 commit a04b109

File tree: 4 files changed (+113, -89 lines)


vllm/v1/core/scheduler.py

Lines changed: 3 additions & 86 deletions
@@ -1,26 +1,20 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from collections import deque
-from dataclasses import dataclass
-from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set,
-                    Tuple, Union)
+from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 from vllm.config import CacheConfig, LoRAConfig, ModelConfig, SchedulerConfig
 from vllm.logger import init_logger
-from vllm.lora.request import LoRARequest
-from vllm.sampling_params import SamplingParams
 from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
                                                 compute_encoder_budget)
 from vllm.v1.core.kv_cache_manager import KVCacheManager
+from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData,
+                                           SchedulerOutput)
 from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
 from vllm.v1.metrics.stats import SchedulerStats
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 
-if TYPE_CHECKING:
-    from vllm.multimodal import MultiModalKwargs
-    from vllm.multimodal.base import PlaceholderRange
-
 logger = init_logger(__name__)
 
 
@@ -600,80 +594,3 @@ def make_stats(self) -> SchedulerStats:
             num_waiting_reqs=len(self.waiting),
             gpu_cache_usage=self.kv_cache_manager.usage,
         )
-
-
-@dataclass
-class NewRequestData:
-
-    req_id: str
-    prompt_token_ids: List[int]
-    prompt: Optional[str]
-    mm_inputs: List["MultiModalKwargs"]
-    mm_hashes: List[str]
-    mm_positions: List["PlaceholderRange"]
-    sampling_params: SamplingParams
-    block_ids: List[int]
-    num_computed_tokens: int
-    lora_request: Optional[LoRARequest]
-
-    @classmethod
-    def from_request(
-        cls,
-        request: Request,
-        block_ids: List[int],
-        num_computed_tokens: int,
-    ) -> "NewRequestData":
-        return cls(
-            req_id=request.request_id,
-            prompt_token_ids=request.prompt_token_ids,
-            prompt=request.prompt,
-            mm_inputs=request.mm_inputs,
-            mm_hashes=request.mm_hashes,
-            mm_positions=request.mm_positions,
-            sampling_params=request.sampling_params,
-            block_ids=block_ids,
-            num_computed_tokens=num_computed_tokens,
-            lora_request=request.lora_request,
-        )
-
-
-@dataclass
-class CachedRequestData:
-
-    req_id: str
-    # If resumed_from_preemption is False, new_block_ids will be appended to
-    # the request's block IDs. If True, new_block_ids will be used as the
-    # request's block IDs instead of appending to the existing block IDs.
-    resumed_from_preemption: bool
-    new_block_ids: List[int]
-    num_computed_tokens: int
-
-    @classmethod
-    def from_request(
-        cls,
-        request: Request,
-        resumed_from_preemption: bool,
-        new_block_ids: List[int],
-        num_computed_tokens: int,
-    ) -> "CachedRequestData":
-        return cls(
-            req_id=request.request_id,
-            resumed_from_preemption=resumed_from_preemption,
-            new_block_ids=new_block_ids,
-            num_computed_tokens=num_computed_tokens,
-        )
-
-
-@dataclass
-class SchedulerOutput:
-
-    scheduled_new_reqs: List[NewRequestData]
-    scheduled_cached_reqs: List[CachedRequestData]
-
-    num_scheduled_tokens: Dict[str, int]
-    total_num_scheduled_tokens: int
-    scheduled_encoder_inputs: Dict[str, List[int]]
-    num_common_prefix_blocks: int
-
-    finished_req_ids: Set[str]
-    free_encoder_input_ids: List[Tuple[str, int]]
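Note: with this commit the scheduler-output dataclasses live in vllm/v1/core/scheduler_output.py, and scheduler.py simply imports them from there. Downstream code is expected to use the new module directly, as the worker diffs below do; a minimal sketch of the new import path (nothing beyond what the diff itself introduces):

from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData,
                                           SchedulerOutput)

Since scheduler.py still imports these names, the old "from vllm.v1.core.scheduler import SchedulerOutput" path should continue to resolve at runtime, but the new module is the canonical location and lets type-only consumers avoid the scheduler's heavier dependencies.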

vllm/v1/core/scheduler_output.py

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
+
+if TYPE_CHECKING:
+    from vllm.lora.request import LoRARequest
+    from vllm.multimodal import MultiModalKwargs
+    from vllm.multimodal.base import PlaceholderRange
+    from vllm.sampling_params import SamplingParams
+    from vllm.v1.request import Request
+
+
+@dataclass
+class NewRequestData:
+
+    req_id: str
+    prompt_token_ids: List[int]
+    prompt: Optional[str]
+    mm_inputs: List["MultiModalKwargs"]
+    mm_hashes: List[str]
+    mm_positions: List["PlaceholderRange"]
+    sampling_params: "SamplingParams"
+    block_ids: List[int]
+    num_computed_tokens: int
+    lora_request: Optional["LoRARequest"]
+
+    @classmethod
+    def from_request(
+        cls,
+        request: "Request",
+        block_ids: List[int],
+        num_computed_tokens: int,
+    ) -> "NewRequestData":
+        return cls(
+            req_id=request.request_id,
+            prompt_token_ids=request.prompt_token_ids,
+            prompt=request.prompt,
+            mm_inputs=request.mm_inputs,
+            mm_hashes=request.mm_hashes,
+            mm_positions=request.mm_positions,
+            sampling_params=request.sampling_params,
+            block_ids=block_ids,
+            num_computed_tokens=num_computed_tokens,
+            lora_request=request.lora_request,
+        )
+
+
+@dataclass
+class CachedRequestData:
+
+    req_id: str
+    # If resumed_from_preemption is False, new_block_ids will be appended to
+    # the request's block IDs. If True, new_block_ids will be used as the
+    # request's block IDs instead of appending to the existing block IDs.
+    resumed_from_preemption: bool
+    new_block_ids: List[int]
+    num_computed_tokens: int
+
+    @classmethod
+    def from_request(
+        cls,
+        request: "Request",
+        resumed_from_preemption: bool,
+        new_block_ids: List[int],
+        num_computed_tokens: int,
+    ) -> "CachedRequestData":
+        return cls(
+            req_id=request.request_id,
+            resumed_from_preemption=resumed_from_preemption,
+            new_block_ids=new_block_ids,
+            num_computed_tokens=num_computed_tokens,
+        )
+
+
+@dataclass
+class SchedulerOutput:
+
+    # List of the requests that are scheduled for the first time.
+    # We cache the request's data in each worker process, so that we don't
+    # need to re-send it every scheduling step.
+    scheduled_new_reqs: List[NewRequestData]
+    # List of the requests that have been scheduled before.
+    # Since the request's data is already cached in the worker processes,
+    # we only send the diff to minimize the communication cost.
+    scheduled_cached_reqs: List[CachedRequestData]
+
+    # req_id -> num_scheduled_tokens
+    # Number of tokens scheduled for each request.
+    num_scheduled_tokens: Dict[str, int]
+    # Total number of tokens scheduled for all requests.
+    # Equal to sum(num_scheduled_tokens.values())
+    total_num_scheduled_tokens: int
+    # req_id -> encoder input indices that need processing.
+    # E.g., if a request has [0, 1], it could mean the vision encoder needs
+    # to process that the request's 0-th and 1-th images in the current step.
+    scheduled_encoder_inputs: Dict[str, List[int]]
+    # Number of common prefix blocks for all requests.
+    # This can be used for cascade attention.
+    num_common_prefix_blocks: int
+
+    # Request IDs that are finished in between the previous and the current
+    # steps. This is used to notify the workers about the finished requests
+    # so that they can free the cached states for those requests.
+    finished_req_ids: Set[str]
+    # List of (req_id, encoder_input_index) tuples.
+    # Used to free the encoder cache.
+    free_encoder_input_ids: List[Tuple[str, int]]
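For orientation, here is a minimal sketch (not vLLM's actual worker code) of how a worker-side consumer could apply a SchedulerOutput, following the field comments above: new requests are cached in full, previously scheduled requests arrive as a diff whose new_block_ids either extend or replace the stored block IDs depending on resumed_from_preemption, and finished_req_ids tells the worker which cached states to free. The ReqState holder and req_states dict are hypothetical.

from dataclasses import dataclass, field
from typing import Dict, List

from vllm.v1.core.scheduler_output import SchedulerOutput


@dataclass
class ReqState:
    # Hypothetical per-request state cached in the worker process.
    block_ids: List[int] = field(default_factory=list)
    num_computed_tokens: int = 0


req_states: Dict[str, ReqState] = {}


def apply_scheduler_output(out: SchedulerOutput) -> None:
    # Free cached state for requests that finished since the previous step.
    for req_id in out.finished_req_ids:
        req_states.pop(req_id, None)

    # First-time requests: cache their full data in the worker.
    for new_req in out.scheduled_new_reqs:
        req_states[new_req.req_id] = ReqState(
            block_ids=list(new_req.block_ids),
            num_computed_tokens=new_req.num_computed_tokens,
        )

    # Previously scheduled requests: only a diff is sent.
    for cached_req in out.scheduled_cached_reqs:
        state = req_states[cached_req.req_id]
        if cached_req.resumed_from_preemption:
            # A resumed request gets a fresh set of block IDs.
            state.block_ids = list(cached_req.new_block_ids)
        else:
            # Otherwise the new block IDs extend the existing ones.
            state.block_ids.extend(cached_req.new_block_ids)
        state.num_computed_tokens = cached_req.num_computed_tokens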

vllm/v1/worker/gpu_model_runner.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
 if TYPE_CHECKING:
-    from vllm.v1.core.scheduler import SchedulerOutput
+    from vllm.v1.core.scheduler_output import SchedulerOutput
 
 logger = init_logger(__name__)

vllm/v1/worker/gpu_worker.py

Lines changed: 1 addition & 2 deletions
@@ -18,15 +18,14 @@
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
 from vllm.utils import GiB_bytes
-from vllm.v1.core.scheduler import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
 logger = init_logger(__name__)
 
 if TYPE_CHECKING:
-    from vllm.v1.core.scheduler import SchedulerOutput
+    from vllm.v1.core.scheduler_output import SchedulerOutput
 
 
 class Worker:
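Both worker files reference SchedulerOutput only in type annotations, so the import sits under if TYPE_CHECKING: and the annotation is written as a string; nothing from the scheduler package is imported at runtime. A small illustration of that pattern (the function below is hypothetical; only the import idiom mirrors the diffs above):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; TYPE_CHECKING is False at runtime.
    from vllm.v1.core.scheduler_output import SchedulerOutput


def run_one_step(scheduler_output: "SchedulerOutput") -> None:
    # The string annotation means this function can be defined and called
    # without ever importing scheduler_output at runtime.
    ...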
