
Commit 0480fa3

Make is_encoder_decoder an init var
Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com>
1 parent b1a5140 commit 0480fa3

File tree

2 files changed: +57, -36 lines


tests/v1/core/test_scheduler.py

Lines changed: 2 additions & 0 deletions
@@ -1921,6 +1921,8 @@ def test_chunked_prefill_disabled_for_encoder_decoder(
         is_encoder_decoder=is_encoder_decoder,
     )
 
+    # `is_encoder_decoder` should only be used during construction of the config
+    assert not hasattr(scheduler_config, "is_encoder_decoder")
     _validate_chunked_prefill_settings_for_encoder_decoder(
         scheduler_config, is_encoder_decoder, expect_enabled)
 
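For reference, this is the contract the new assertion pins down, sketched directly against `SchedulerConfig`. It assumes the config can be built from its defaults plus this one flag; the commit's test constructs it through the existing test setup instead.

```python
from vllm.config.scheduler import SchedulerConfig

# `is_encoder_decoder` can still be passed to the constructor...
cfg = SchedulerConfig(is_encoder_decoder=True)

# ...but it is consumed by __post_init__ and never stored on the instance,
assert not hasattr(cfg, "is_encoder_decoder")

# while the settings it controls are applied as before.
assert cfg.enable_chunked_prefill is False
assert cfg.chunked_prefill_enabled is False
assert cfg.long_prefill_token_threshold == 0
```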

vllm/config/scheduler.py

Lines changed: 55 additions & 36 deletions
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import hashlib
-from dataclasses import field
+from dataclasses import InitVar, field
 from typing import Any, Literal, Union
 
 from pydantic import SkipValidation, model_validator
@@ -11,9 +11,11 @@
 
 from vllm.config.utils import config
 from vllm.logger import init_logger
-from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
-                        MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
-                        POOLING_MODEL_MAX_NUM_BATCHED_TOKENS)
+from vllm.utils import (
+    DEFAULT_MAX_NUM_BATCHED_TOKENS,
+    MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
+    POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
+)
 
 logger = init_logger(__name__)
 

@@ -84,8 +86,12 @@ class SchedulerConfig:
     is_multimodal_model: bool = False
     """True if the model is multimodal."""
 
-    is_encoder_decoder: bool = False
-    """True if the model is an encoder-decoder model."""
+    is_encoder_decoder: InitVar[bool] = False
+    """True if the model is an encoder-decoder model.
+
+    Note: This is stored in the ModelConfig, and is used only here to
+    disable chunked prefill and prefix caching for encoder-decoder models.
+    """
 
     # TODO (ywang96): Make this configurable.
     max_num_encoder_input_tokens: int = field(init=False)
@@ -160,26 +166,26 @@ def compute_hash(self) -> str:
         # no factors to consider.
         # this config will not affect the computation graph.
         factors: list[Any] = []
-        hash_str = hashlib.md5(str(factors).encode(),
-                               usedforsecurity=False).hexdigest()
+        hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
         return hash_str
 
-    def __post_init__(self) -> None:
+    def __post_init__(self, is_encoder_decoder: bool) -> None:
         if self.max_model_len is None:
             self.max_model_len = 8192
 
         if self.max_num_seqs is None:
             self.max_num_seqs = 128
 
-        if self.is_encoder_decoder:
+        if is_encoder_decoder:
             # Chunked prefill should be disabled for encoder-decoder models.
             self.disable_chunked_mm_input = True
             self.chunked_prefill_enabled = False
             self.enable_chunked_prefill = False
             self.long_prefill_token_threshold = 0
             logger.info(
                 "Encoder-decoder models do not support chunked prefill nor"
-                " prefix caching; disabling both.")
+                " prefix caching; disabling both."
+            )
 
         if self.max_num_batched_tokens is None:
             if self.enable_chunked_prefill:
@@ -189,7 +195,8 @@ def __post_init__(self) -> None:
                 # DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
                 # for higher throughput.
                 self.max_num_batched_tokens = max(
-                    self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS)
+                    self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS
+                )
 
             if self.runner_type == "pooling":
                 # Choose specific value for higher throughput
@@ -208,29 +215,31 @@ def __post_init__(self) -> None:
             # Ensure max_num_batched_tokens does not exceed model limit.
             # Some models (e.g., Whisper) have embeddings tied to max length.
             self.max_num_batched_tokens = min(
-                self.max_num_seqs * self.max_model_len,
-                self.max_num_batched_tokens)
+                self.max_num_seqs * self.max_model_len, self.max_num_batched_tokens
+            )
 
         self.max_num_encoder_input_tokens = self.max_num_batched_tokens
         self.encoder_cache_size = self.max_num_batched_tokens
 
         if self.enable_chunked_prefill:
             logger.info(
                 "Chunked prefill is enabled with max_num_batched_tokens=%d.",
-                self.max_num_batched_tokens)
+                self.max_num_batched_tokens,
+            )
 
         self.chunked_prefill_enabled = self.enable_chunked_prefill
         if self.max_num_partial_prefills > 1:
             if self.long_prefill_token_threshold == 0:
-                self.long_prefill_token_threshold = int(self.max_model_len *
-                                                        0.04)
+                self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
 
             logger.info(
                 "Concurrent partial prefills enabled with "
                 "max_num_partial_prefills=%d, max_long_partial_prefills=%d, "
                 "long_prefill_token_threshold=%d",
-                self.max_num_partial_prefills, self.max_long_partial_prefills,
-                self.long_prefill_token_threshold)
+                self.max_num_partial_prefills,
+                self.max_long_partial_prefills,
+                self.long_prefill_token_threshold,
+            )
 
         # NOTE: Default set cuda_graph_sizes to [min(max_num_seqs * 2, 512)].
         # This avoids OOM in tight memory scenarios with small max_num_seqs,
@@ -240,61 +249,71 @@ def __post_init__(self) -> None:
             self.cuda_graph_sizes = [min(self.max_num_seqs * 2, 512)]
 
         if self.async_scheduling:
-            self.scheduler_cls = (
-                "vllm.v1.core.sched.async_scheduler.AsyncScheduler")
+            self.scheduler_cls = "vllm.v1.core.sched.async_scheduler.AsyncScheduler"
 
-    @model_validator(mode='after')
+    @model_validator(mode="after")
     def _verify_args(self) -> Self:
-        if (self.max_num_batched_tokens < self.max_model_len
-                and not self.chunked_prefill_enabled):
+        if (
+            self.max_num_batched_tokens < self.max_model_len
+            and not self.chunked_prefill_enabled
+        ):
             raise ValueError(
                 f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                 f"smaller than max_model_len ({self.max_model_len}). "
                 "This effectively limits the maximum sequence length to "
                 "max_num_batched_tokens and makes vLLM reject longer "
                 "sequences. Please increase max_num_batched_tokens or "
-                "decrease max_model_len.")
+                "decrease max_model_len."
+            )
 
         if self.max_num_batched_tokens < self.max_num_seqs:
             raise ValueError(
                 f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                 "be greater than or equal to max_num_seqs "
-                f"({self.max_num_seqs}).")
+                f"({self.max_num_seqs})."
+            )
 
         if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len:
             logger.warning(
                 "max_num_batched_tokens (%d) exceeds max_num_seqs "
                 "* max_model_len (%d). This may lead to unexpected behavior.",
                 self.max_num_batched_tokens,
-                self.max_num_seqs * self.max_model_len)
+                self.max_num_seqs * self.max_model_len,
+            )
 
         if self.num_lookahead_slots < 0:
             raise ValueError(
                 "num_lookahead_slots "
                 f"({self.num_lookahead_slots}) must be greater than or "
-                "equal to 0.")
+                "equal to 0."
+            )
 
         if self.max_num_partial_prefills < 1:
             raise ValueError(
                 f"max_num_partial_prefills ({self.max_num_partial_prefills}) "
-                "must be greater than or equal to 1.")
+                "must be greater than or equal to 1."
+            )
         elif self.max_num_partial_prefills > 1:
             if not self.chunked_prefill_enabled:
-                raise ValueError("Chunked prefill must be enabled to set "
-                                 "max_num_partial_prefills > 1.")
+                raise ValueError(
+                    "Chunked prefill must be enabled to set "
+                    "max_num_partial_prefills > 1."
+                )
 
             if self.long_prefill_token_threshold > self.max_model_len:
                 raise ValueError(
                     "long_prefill_token_threshold "
                     f"({self.long_prefill_token_threshold}) cannot be greater "
-                    f"than the max_model_len ({self.max_model_len}).")
+                    f"than the max_model_len ({self.max_model_len})."
+                )
 
-        if (self.max_long_partial_prefills
-                < 1) or (self.max_long_partial_prefills
-                         > self.max_num_partial_prefills):
+        if (self.max_long_partial_prefills < 1) or (
+            self.max_long_partial_prefills > self.max_num_partial_prefills
+        ):
             raise ValueError(
                 f"max_long_partial_prefills ({self.max_long_partial_prefills}) "
                 "must be greater than or equal to 1 and less than or equal to "
-                f"max_num_partial_prefills ({self.max_num_partial_prefills}).")
+                f"max_num_partial_prefills ({self.max_num_partial_prefills})."
+            )
 
         return self
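The change leans on the standard `dataclasses.InitVar` mechanism: an `InitVar` pseudo-field is accepted by the generated `__init__` and forwarded to `__post_init__`, but it is not stored as an instance attribute and is not reported by `dataclasses.fields()`. A minimal standalone sketch of that behavior (a toy class with names borrowed from `SchedulerConfig`, not vLLM code; `SchedulerConfig` itself is a pydantic dataclass, which follows the same convention):

```python
from dataclasses import InitVar, dataclass, fields


@dataclass
class ToyConfig:
    # Regular field: stored on the instance.
    enable_chunked_prefill: bool = True
    # InitVar: accepted by __init__ and handed to __post_init__,
    # but not turned into an instance attribute.
    is_encoder_decoder: InitVar[bool] = False

    def __post_init__(self, is_encoder_decoder: bool) -> None:
        if is_encoder_decoder:
            # The flag only influences other settings during construction.
            self.enable_chunked_prefill = False


cfg = ToyConfig(is_encoder_decoder=True)
assert cfg.enable_chunked_prefill is False
assert "is_encoder_decoder" not in vars(cfg)                     # never stored
assert all(f.name != "is_encoder_decoder" for f in fields(cfg))  # not a field
```

This avoids duplicating state that, per the new docstring, already lives in ModelConfig.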
