
Commit 2a30c6e

Run pre-commit
Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com>
1 parent 0480fa3 commit 2a30c6e
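
This commit applies the repository's formatting hooks across two files. For context, a minimal sketch of reproducing such a pass locally, assuming the pre-commit CLI is installed and a .pre-commit-config.yaml exists at the repo root:

import subprocess

# Run every configured hook against the whole tree, as this commit did.
# Exit code is non-zero when hooks modified files or failed.
result = subprocess.run(["pre-commit", "run", "--all-files"])
print("hooks clean" if result.returncode == 0 else "hooks made changes or failed")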

2 files changed: +26 -38 lines changed


vllm/config/scheduler.py

Lines changed: 26 additions & 37 deletions
@@ -11,11 +11,9 @@
 
 from vllm.config.utils import config
 from vllm.logger import init_logger
-from vllm.utils import (
-    DEFAULT_MAX_NUM_BATCHED_TOKENS,
-    MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
-    POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
-)
+from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
+                        MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
+                        POOLING_MODEL_MAX_NUM_BATCHED_TOKENS)
 
 logger = init_logger(__name__)
 
@@ -166,7 +164,8 @@ def compute_hash(self) -> str:
         # no factors to consider.
         # this config will not affect the computation graph.
         factors: list[Any] = []
-        hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
+        hash_str = hashlib.md5(str(factors).encode(),
+                               usedforsecurity=False).hexdigest()
         return hash_str
 
     def __post_init__(self, is_encoder_decoder: bool) -> None:
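
The compute_hash hunk only rewraps the md5 call; behavior is unchanged. A standalone sketch of the same idiom, hashing the stringified factor list (usedforsecurity=False requires Python 3.9+):

import hashlib

factors: list = []  # no factors: this config never affects the compiled graph
hash_str = hashlib.md5(str(factors).encode(),
                       usedforsecurity=False).hexdigest()
print(hash_str)  # identical factor lists always yield the same digest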
@@ -184,8 +183,7 @@ def __post_init__(self, is_encoder_decoder: bool) -> None:
             self.long_prefill_token_threshold = 0
             logger.info(
                 "Encoder-decoder models do not support chunked prefill nor"
-                " prefix caching; disabling both."
-            )
+                " prefix caching; disabling both.")
 
         if self.max_num_batched_tokens is None:
             if self.enable_chunked_prefill:
@@ -195,8 +193,7 @@ def __post_init__(self, is_encoder_decoder: bool) -> None:
                 # DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
                 # for higher throughput.
                 self.max_num_batched_tokens = max(
-                    self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS
-                )
+                    self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS)
 
             if self.runner_type == "pooling":
                 # Choose specific value for higher throughput
@@ -215,8 +212,8 @@ def __post_init__(self, is_encoder_decoder: bool) -> None:
         # Ensure max_num_batched_tokens does not exceed model limit.
         # Some models (e.g., Whisper) have embeddings tied to max length.
         self.max_num_batched_tokens = min(
-            self.max_num_seqs * self.max_model_len, self.max_num_batched_tokens
-        )
+            self.max_num_seqs * self.max_model_len,
+            self.max_num_batched_tokens)
 
         self.max_num_encoder_input_tokens = self.max_num_batched_tokens
         self.encoder_cache_size = self.max_num_batched_tokens
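
Taken together, the two hunks above preserve the existing budget arithmetic: the batched-token budget is first raised to at least the model length, then capped so it never exceeds what max_num_seqs full-length sequences could consume. A worked sketch with illustrative numbers (the real constants live in vllm.utils):

DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048  # illustrative value only
max_model_len = 8192
max_num_seqs = 2

budget = max(max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS)  # -> 8192
budget = min(max_num_seqs * max_model_len, budget)           # cap is 16384, so still 8192
assert budget == 8192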
@@ -230,7 +227,8 @@ def __post_init__(self, is_encoder_decoder: bool) -> None:
         self.chunked_prefill_enabled = self.enable_chunked_prefill
         if self.max_num_partial_prefills > 1:
             if self.long_prefill_token_threshold == 0:
-                self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
+                self.long_prefill_token_threshold = int(self.max_model_len *
+                                                        0.04)
 
             logger.info(
                 "Concurrent partial prefills enabled with "
@@ -249,29 +247,26 @@ def __post_init__(self, is_encoder_decoder: bool) -> None:
             self.cuda_graph_sizes = [min(self.max_num_seqs * 2, 512)]
 
         if self.async_scheduling:
-            self.scheduler_cls = "vllm.v1.core.sched.async_scheduler.AsyncScheduler"
+            self.scheduler_cls = (
+                "vllm.v1.core.sched.async_scheduler.AsyncScheduler")
 
     @model_validator(mode="after")
     def _verify_args(self) -> Self:
-        if (
-            self.max_num_batched_tokens < self.max_model_len
-            and not self.chunked_prefill_enabled
-        ):
+        if (self.max_num_batched_tokens < self.max_model_len
+                and not self.chunked_prefill_enabled):
             raise ValueError(
                 f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                 f"smaller than max_model_len ({self.max_model_len}). "
                 "This effectively limits the maximum sequence length to "
                 "max_num_batched_tokens and makes vLLM reject longer "
                 "sequences. Please increase max_num_batched_tokens or "
-                "decrease max_model_len."
-            )
+                "decrease max_model_len.")
 
         if self.max_num_batched_tokens < self.max_num_seqs:
             raise ValueError(
                 f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                 "be greater than or equal to max_num_seqs "
-                f"({self.max_num_seqs})."
-            )
+                f"({self.max_num_seqs}).")
 
         if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len:
             logger.warning(
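
Note that scheduler_cls is stored as a dotted import string rather than a class object. A minimal sketch of how such a string is typically resolved at runtime (resolve_cls is a hypothetical helper; vLLM's actual resolution code may differ):

import importlib

def resolve_cls(path: str):
    # Split "pkg.module.ClassName" into module path and attribute name,
    # import the module, and fetch the class from it.
    module_name, _, cls_name = path.rpartition(".")
    return getattr(importlib.import_module(module_name), cls_name)

# e.g. resolve_cls("vllm.v1.core.sched.async_scheduler.AsyncScheduler")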
@@ -285,35 +280,29 @@ def _verify_args(self) -> Self:
             raise ValueError(
                 "num_lookahead_slots "
                 f"({self.num_lookahead_slots}) must be greater than or "
-                "equal to 0."
-            )
+                "equal to 0.")
 
         if self.max_num_partial_prefills < 1:
             raise ValueError(
                 f"max_num_partial_prefills ({self.max_num_partial_prefills}) "
-                "must be greater than or equal to 1."
-            )
+                "must be greater than or equal to 1.")
         elif self.max_num_partial_prefills > 1:
             if not self.chunked_prefill_enabled:
-                raise ValueError(
-                    "Chunked prefill must be enabled to set "
-                    "max_num_partial_prefills > 1."
-                )
+                raise ValueError("Chunked prefill must be enabled to set "
+                                 "max_num_partial_prefills > 1.")
 
             if self.long_prefill_token_threshold > self.max_model_len:
                 raise ValueError(
                     "long_prefill_token_threshold "
                     f"({self.long_prefill_token_threshold}) cannot be greater "
-                    f"than the max_model_len ({self.max_model_len})."
-                )
+                    f"than the max_model_len ({self.max_model_len}).")
 
-        if (self.max_long_partial_prefills < 1) or (
-            self.max_long_partial_prefills > self.max_num_partial_prefills
-        ):
+        if (self.max_long_partial_prefills
+                < 1) or (self.max_long_partial_prefills
+                         > self.max_num_partial_prefills):
             raise ValueError(
                 f"max_long_partial_prefills ({self.max_long_partial_prefills}) "
                 "must be greater than or equal to 1 and less than or equal to "
-                f"max_num_partial_prefills ({self.max_num_partial_prefills})."
-            )
+                f"max_num_partial_prefills ({self.max_num_partial_prefills}).")
 
         return self
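
In sum, _verify_args enforces an ordering on the scheduler knobs; the formatting changes above leave every check intact. A condensed sketch of the invariants, with names matching the config fields and simplified to bare asserts:

def check(max_num_batched_tokens, max_num_seqs, max_model_len,
          max_num_partial_prefills, max_long_partial_prefills,
          long_prefill_token_threshold, chunked_prefill_enabled):
    # Budget must cover the model length unless chunked prefill splits requests.
    assert chunked_prefill_enabled or max_num_batched_tokens >= max_model_len
    # Every scheduled sequence needs at least one token of budget.
    assert max_num_batched_tokens >= max_num_seqs
    assert max_num_partial_prefills >= 1
    assert long_prefill_token_threshold <= max_model_len
    assert 1 <= max_long_partial_prefills <= max_num_partial_prefills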

vllm/engine/arg_utils.py

Lines changed: 0 additions & 1 deletion
@@ -1368,7 +1368,6 @@ def create_engine_config(
             disable_chunked_mm_input=self.disable_chunked_mm_input,
             is_multimodal_model=model_config.is_multimodal_model,
             is_encoder_decoder=model_config.is_encoder_decoder,
-            preemption_mode=self.preemption_mode,
             send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
                              and parallel_config.use_ray),
             policy=self.scheduling_policy,
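
This one-line deletion is the only non-formatting change in the commit: preemption_mode is no longer forwarded when the scheduler config is built, presumably because the field was dropped from the config elsewhere on this branch. As a generic illustration (TinyConfig is a hypothetical stand-in, not vLLM's SchedulerConfig), passing a keyword a dataclass no longer declares raises TypeError, so call sites must be updated in the same change:

from dataclasses import dataclass

@dataclass
class TinyConfig:  # hypothetical stand-in
    max_num_seqs: int = 128

TinyConfig(max_num_seqs=64)  # ok
# TinyConfig(max_num_seqs=64, preemption_mode="swap")  -> TypeError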
