@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import hashlib
-from dataclasses import field
+from dataclasses import InitVar, field
 from typing import Any, Literal, Union
 
 from pydantic import SkipValidation, model_validator
@@ -11,9 +11,11 @@
 
 from vllm.config.utils import config
 from vllm.logger import init_logger
-from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
-                        MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
-                        POOLING_MODEL_MAX_NUM_BATCHED_TOKENS)
+from vllm.utils import (
+    DEFAULT_MAX_NUM_BATCHED_TOKENS,
+    MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
+    POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
+)
 
 logger = init_logger(__name__)
 
@@ -84,8 +86,12 @@ class SchedulerConfig:
     is_multimodal_model: bool = False
     """True if the model is multimodal."""
 
-    is_encoder_decoder: bool = False
-    """True if the model is an encoder-decoder model."""
+    is_encoder_decoder: InitVar[bool] = False
+    """True if the model is an encoder-decoder model.
+
+    Note: This is stored in the ModelConfig, and is used only here to
+    disable chunked prefill and prefix caching for encoder-decoder models.
+    """
 
     # TODO (ywang96): Make this configurable.
     max_num_encoder_input_tokens: int = field(init=False)
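The key change in this hunk is the switch from a stored field to a `dataclasses.InitVar`. An `InitVar` is accepted by the generated `__init__` and forwarded to `__post_init__`, but it is never kept as a field on the instance. A minimal stdlib-only sketch of the pattern (toy class, not vLLM's actual config machinery):

```python
from dataclasses import InitVar, dataclass, fields


@dataclass
class ToyConfig:
    # Pseudo-field: accepted by __init__ and passed to __post_init__,
    # but excluded from the dataclass fields.
    is_encoder_decoder: InitVar[bool] = False
    chunked_prefill_enabled: bool = True

    def __post_init__(self, is_encoder_decoder: bool) -> None:
        if is_encoder_decoder:
            self.chunked_prefill_enabled = False


cfg = ToyConfig(is_encoder_decoder=True)
print(cfg.chunked_prefill_enabled)    # False
print([f.name for f in fields(cfg)])  # ['chunked_prefill_enabled']
```

This is also why `__post_init__` gains an `is_encoder_decoder` parameter in the hunks below: the dataclass machinery hands the init-only value to it instead of storing it.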
@@ -160,26 +166,26 @@ def compute_hash(self) -> str:
         # no factors to consider.
         # this config will not affect the computation graph.
         factors: list[Any] = []
-        hash_str = hashlib.md5(str(factors).encode(),
-                               usedforsecurity=False).hexdigest()
+        hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
         return hash_str
 
-    def __post_init__(self) -> None:
+    def __post_init__(self, is_encoder_decoder: bool) -> None:
         if self.max_model_len is None:
             self.max_model_len = 8192
 
         if self.max_num_seqs is None:
             self.max_num_seqs = 128
 
-        if self.is_encoder_decoder:
+        if is_encoder_decoder:
             # Chunked prefill should be disabled for encoder-decoder models.
             self.disable_chunked_mm_input = True
             self.chunked_prefill_enabled = False
             self.enable_chunked_prefill = False
             self.long_prefill_token_threshold = 0
             logger.info(
                 "Encoder-decoder models do not support chunked prefill nor"
-                " prefix caching; disabling both.")
+                " prefix caching; disabling both."
+            )
 
         if self.max_num_batched_tokens is None:
             if self.enable_chunked_prefill:
@@ -189,7 +195,8 @@ def __post_init__(self) -> None:
                 # DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
                 # for higher throughput.
                 self.max_num_batched_tokens = max(
-                    self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS)
+                    self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS
+                )
 
             if self.runner_type == "pooling":
                 # Choose specific value for higher throughput
@@ -208,29 +215,31 @@ def __post_init__(self) -> None:
             # Ensure max_num_batched_tokens does not exceed model limit.
             # Some models (e.g., Whisper) have embeddings tied to max length.
             self.max_num_batched_tokens = min(
-                self.max_num_seqs * self.max_model_len,
-                self.max_num_batched_tokens)
+                self.max_num_seqs * self.max_model_len, self.max_num_batched_tokens
+            )
 
         self.max_num_encoder_input_tokens = self.max_num_batched_tokens
         self.encoder_cache_size = self.max_num_batched_tokens
 
         if self.enable_chunked_prefill:
             logger.info(
                 "Chunked prefill is enabled with max_num_batched_tokens=%d.",
-                self.max_num_batched_tokens)
+                self.max_num_batched_tokens,
+            )
 
         self.chunked_prefill_enabled = self.enable_chunked_prefill
         if self.max_num_partial_prefills > 1:
             if self.long_prefill_token_threshold == 0:
-                self.long_prefill_token_threshold = int(self.max_model_len *
-                                                        0.04)
+                self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
 
             logger.info(
                 "Concurrent partial prefills enabled with "
                 "max_num_partial_prefills=%d, max_long_partial_prefills=%d, "
                 "long_prefill_token_threshold=%d",
-                self.max_num_partial_prefills, self.max_long_partial_prefills,
-                self.long_prefill_token_threshold)
+                self.max_num_partial_prefills,
+                self.max_long_partial_prefills,
+                self.long_prefill_token_threshold,
+            )
 
         # NOTE: Default set cuda_graph_sizes to [min(max_num_seqs * 2, 512)].
         # This avoids OOM in tight memory scenarios with small max_num_seqs,
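The clamp and the "long prefill" threshold above are easier to read with concrete numbers plugged in. An illustrative standalone rendering of the same arithmetic (values mirror the fallbacks earlier in `__post_init__`, but are not authoritative defaults):

```python
# Illustrative values only.
max_num_seqs = 128
max_model_len = 8192
max_num_batched_tokens = 8192

# Clamp: a batch can never need more tokens than every sequence slot
# filled to the full context length.
max_num_batched_tokens = min(max_num_seqs * max_model_len, max_num_batched_tokens)

# A prefill counts as "long" once it exceeds 4% of the context window.
long_prefill_token_threshold = int(max_model_len * 0.04)

print(max_num_batched_tokens, long_prefill_token_threshold)  # 8192 327
```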
@@ -240,61 +249,71 @@ def __post_init__(self) -> None:
             self.cuda_graph_sizes = [min(self.max_num_seqs * 2, 512)]
 
         if self.async_scheduling:
-            self.scheduler_cls = (
-                "vllm.v1.core.sched.async_scheduler.AsyncScheduler")
+            self.scheduler_cls = "vllm.v1.core.sched.async_scheduler.AsyncScheduler"
 
-    @model_validator(mode='after')
+    @model_validator(mode="after")
     def _verify_args(self) -> Self:
-        if (self.max_num_batched_tokens < self.max_model_len
-                and not self.chunked_prefill_enabled):
+        if (
+            self.max_num_batched_tokens < self.max_model_len
+            and not self.chunked_prefill_enabled
+        ):
             raise ValueError(
                 f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                 f"smaller than max_model_len ({self.max_model_len}). "
                 "This effectively limits the maximum sequence length to "
                 "max_num_batched_tokens and makes vLLM reject longer "
                 "sequences. Please increase max_num_batched_tokens or "
-                "decrease max_model_len.")
+                "decrease max_model_len."
+            )
 
         if self.max_num_batched_tokens < self.max_num_seqs:
             raise ValueError(
                 f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                 "be greater than or equal to max_num_seqs "
-                f"({self.max_num_seqs}).")
+                f"({self.max_num_seqs})."
+            )
 
         if self.max_num_batched_tokens > self.max_num_seqs * self.max_model_len:
             logger.warning(
                 "max_num_batched_tokens (%d) exceeds max_num_seqs "
                 "* max_model_len (%d). This may lead to unexpected behavior.",
                 self.max_num_batched_tokens,
-                self.max_num_seqs * self.max_model_len)
+                self.max_num_seqs * self.max_model_len,
+            )
 
         if self.num_lookahead_slots < 0:
             raise ValueError(
                 "num_lookahead_slots "
                 f"({self.num_lookahead_slots}) must be greater than or "
-                "equal to 0.")
+                "equal to 0."
+            )
 
         if self.max_num_partial_prefills < 1:
             raise ValueError(
                 f"max_num_partial_prefills ({self.max_num_partial_prefills}) "
-                "must be greater than or equal to 1.")
+                "must be greater than or equal to 1."
+            )
         elif self.max_num_partial_prefills > 1:
             if not self.chunked_prefill_enabled:
-                raise ValueError("Chunked prefill must be enabled to set "
-                                 "max_num_partial_prefills > 1.")
+                raise ValueError(
+                    "Chunked prefill must be enabled to set "
+                    "max_num_partial_prefills > 1."
+                )
 
             if self.long_prefill_token_threshold > self.max_model_len:
                 raise ValueError(
                     "long_prefill_token_threshold "
                     f"({self.long_prefill_token_threshold}) cannot be greater "
-                    f"than the max_model_len ({self.max_model_len}).")
+                    f"than the max_model_len ({self.max_model_len})."
+                )
 
-        if (self.max_long_partial_prefills
-                < 1) or (self.max_long_partial_prefills
-                         > self.max_num_partial_prefills):
+        if (self.max_long_partial_prefills < 1) or (
+            self.max_long_partial_prefills > self.max_num_partial_prefills
+        ):
             raise ValueError(
                 f"max_long_partial_prefills ({self.max_long_partial_prefills}) "
                 "must be greater than or equal to 1 and less than or equal to "
-                f"max_num_partial_prefills ({self.max_num_partial_prefills}).")
+                f"max_num_partial_prefills ({self.max_num_partial_prefills})."
+            )
 
         return self
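End to end, the observable effect of the commit is that `is_encoder_decoder` becomes a constructor-only argument of `SchedulerConfig`. A hedged usage sketch (it assumes `SchedulerConfig` is constructible with defaults in your vLLM version; required arguments may differ):

```python
from dataclasses import fields

from vllm.config import SchedulerConfig

# Passing the flag still works, and __post_init__ reacts to it...
cfg = SchedulerConfig(is_encoder_decoder=True)
assert cfg.enable_chunked_prefill is False
assert cfg.long_prefill_token_threshold == 0

# ...but the flag itself is no longer stored: it is an InitVar, so it
# does not appear among the dataclass fields (it lives on ModelConfig).
assert "is_encoder_decoder" not in {f.name for f in fields(SchedulerConfig)}
```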