
Commit 0c6e40b

[Refactor] Simplify code for MM budget (#23310)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent 2e2000f commit 0c6e40b

File tree

4 files changed: +57 −68 lines

vllm/v1/core/encoder_cache_manager.py

Lines changed: 32 additions & 24 deletions
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
+from collections.abc import Mapping
 from typing import TYPE_CHECKING
 
 from vllm.logger import init_logger
@@ -188,35 +188,47 @@ def compute_encoder_budget(
     - Space budget for encoder cache size, in unit of number of tokens
       in the input sequence.
     """
+    if mm_registry.supports_multimodal_inputs(model_config):
+        max_tokens_by_modality = mm_registry \
+            .get_max_tokens_per_item_by_nonzero_modality(model_config)
 
-    if not mm_registry.supports_multimodal_inputs(model_config):
-        return 0, 0
+        return compute_mm_encoder_budget(
+            scheduler_config,
+            max_tokens_by_modality,
+        )
 
-    # TODO: handle encoder-decoder models once we support them.
-    (
-        encoder_compute_budget,
-        encoder_cache_size,
-    ) = _compute_encoder_budget_multimodal(
-        model_config,
-        scheduler_config,
-        mm_registry,
-    )
+    return compute_text_encoder_budget(scheduler_config)
 
-    return encoder_compute_budget, encoder_cache_size
 
+def compute_text_encoder_budget(
+        scheduler_config: "SchedulerConfig") -> tuple[int, int]:
+    """Compute the encoder cache budget based on the model and scheduler
+    configurations for a text-only model.
 
-def _compute_encoder_budget_multimodal(
-    model_config: "ModelConfig",
+    Args:
+        scheduler_config: Scheduler configuration.
+
+    Returns:
+    - Compute budget for encoder execution, in unit of number of tokens
+      in the input sequence.
+    - Space budget for encoder cache size, in unit of number of tokens
+      in the input sequence.
+    """
+    # Currently text-only encoder-decoder models are not supported
+    return 0, 0
+
+
+def compute_mm_encoder_budget(
     scheduler_config: "SchedulerConfig",
-    mm_registry: MultiModalRegistry,
+    max_tokens_by_modality: Mapping[str, int],
 ) -> tuple[int, int]:
     """Compute the encoder cache budget based on the model and scheduler
     configurations for a multimodal model.
 
     Args:
-        model_config: Model configuration.
         scheduler_config: Scheduler configuration.
-        mm_registry: Provides information about the token cost.
+        max_tokens_by_modality: The maximum number of tokens for each
+            non-text modality.
 
     Returns:
     - Compute budget for encoder execution, in unit of number of tokens
@@ -225,18 +237,14 @@ def _compute_encoder_budget_multimodal(
       in the input sequence.
     """
 
-    max_tokens_by_modality_dict = mm_registry \
-        .get_max_tokens_per_item_by_nonzero_modality(model_config)
-
-    if not max_tokens_by_modality_dict:
+    if not max_tokens_by_modality:
         logger.warning(
             "All non-text modalities supported by the model have been "
            "explicitly disabled via limit_mm_per_prompt. Encoder cache will "
            "not be initialized.")
        return 0, 0
 
-    _, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(),
-                                    key=lambda item: item[1])
+    max_tokens_per_mm_item = max(max_tokens_by_modality.values())
 
    if (scheduler_config.disable_chunked_mm_input and max_tokens_per_mm_item
            > scheduler_config.max_num_batched_tokens):
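For reference, here is a minimal, self-contained sketch of the selection logic shown in the last hunk. It is not vLLM code: the function name and the token counts below are illustrative only; it just mirrors the empty-mapping guard and the simplified max() over the mapping values.

from collections.abc import Mapping


def pick_max_tokens_per_mm_item(max_tokens_by_modality: Mapping[str, int]) -> int:
    # Mirrors the guard above: if every non-text modality was disabled via
    # limit_mm_per_prompt, there is nothing to budget for.
    if not max_tokens_by_modality:
        return 0
    # New form in the diff: take the largest per-item token count directly
    # from the values (the old code used max(dict.items(), key=...) and
    # discarded the key).
    return max(max_tokens_by_modality.values())


# Made-up per-item token counts, purely for demonstration.
print(pick_max_tokens_per_mm_item({"image": 576, "video": 4096}))  # 4096
print(pick_max_tokens_per_mm_item({}))                             # 0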

vllm/v1/worker/gpu_model_runner.py

Lines changed: 5 additions & 13 deletions
@@ -341,10 +341,7 @@ def __init__(
             self.model_config,
             self.scheduler_config,
             self.mm_registry,
-            max_model_len=self.max_model_len,
-            max_num_reqs=self.max_num_reqs,
-        ) if self.supports_mm_inputs \
-            else None)
+        ) if self.supports_mm_inputs else None)
 
         self.reorder_batch_threshold: Optional[int] = None
 
@@ -669,7 +666,7 @@ def _dummy_mm_kwargs(self, num_seqs: int) -> BatchedTensorInputs:
         mm_budget = self.mm_budget
         assert mm_budget is not None
 
-        dummy_modality, _ = mm_budget.get_modality_with_max_tokens()
+        dummy_modality = mm_budget.get_modality_with_max_tokens()
 
         return self._get_mm_dummy_batch(dummy_modality, num_seqs)
 
@@ -2595,14 +2592,9 @@ def profile_run(self) -> None:
             # NOTE: Currently model is profiled with a single non-text
             # modality with the max possible input tokens even when
             # it supports multiple.
-            (
-                dummy_modality,
-                max_tokens,
-            ) = mm_budget.get_modality_with_max_tokens()
-            (
-                max_mm_items_per_prompt,
-                max_mm_items_per_batch,
-            ) = mm_budget.get_max_items(dummy_modality, max_tokens)
+            dummy_modality = mm_budget.get_modality_with_max_tokens()
+            max_mm_items_per_batch = mm_budget \
+                .max_items_per_batch_by_modality[dummy_modality]
 
             logger.info(
                 "Encoder cache will be initialized with a budget of "

vllm/v1/worker/tpu_model_runner.py

Lines changed: 3 additions & 10 deletions
@@ -292,8 +292,6 @@ def __init__(
             self.model_config,
             self.scheduler_config,
             self.mm_registry,
-            max_model_len=self.max_model_len,
-            max_num_reqs=self.max_num_reqs,
         ) if self.supports_mm_inputs else None)
 
         if not self.use_spmd:
@@ -1545,14 +1543,9 @@ def profile_run(
             # NOTE: Currently model is profiled with a single non-text
             # modality with the max possible input tokens even when
             # it supports multiple.
-            (
-                dummy_modality,
-                max_tokens,
-            ) = mm_budget.get_modality_with_max_tokens()
-            (
-                max_mm_items_per_prompt,
-                max_mm_items_per_batch,
-            ) = mm_budget.get_max_items(dummy_modality, max_tokens)
+            dummy_modality = mm_budget.get_modality_with_max_tokens()
+            max_mm_items_per_batch = mm_budget \
+                .max_items_per_batch_by_modality[dummy_modality]
 
             logger.info(
                 "Encoder cache will be initialized with a budget of "

vllm/v1/worker/utils.py

Lines changed: 17 additions & 21 deletions
@@ -12,7 +12,7 @@
 from vllm.model_executor.models.utils import extract_layer_index
 from vllm.multimodal.registry import MultiModalRegistry
 from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
-from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
+from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
 from vllm.v1.kv_cache_interface import KVCacheGroupSpec
 
 if TYPE_CHECKING:
@@ -27,35 +27,32 @@ def __init__(
         model_config: ModelConfig,
         scheduler_config: SchedulerConfig,
         mm_registry: MultiModalRegistry,
-        *,
-        max_model_len: int,
-        max_num_reqs: int,
     ) -> None:
         super().__init__()
 
         self.model_config = model_config
         self.scheduler_config = scheduler_config
         self.mm_registry = mm_registry
 
-        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
-            model_config=model_config,
-            scheduler_config=scheduler_config,
-            mm_registry=mm_registry,
+        self.max_model_len = model_config.max_model_len
+        self.max_num_reqs = scheduler_config.max_num_seqs
+
+        self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config)
+
+        max_tokens_by_modality = mm_registry \
+            .get_max_tokens_per_item_by_nonzero_modality(model_config)
+
+        encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
+            scheduler_config,
+            max_tokens_by_modality,
         )
 
-        self.max_num_encoder_input_tokens = encoder_compute_budget
+        self.encoder_compute_budget = encoder_compute_budget
         self.encoder_cache_size = encoder_cache_size
-        self.max_model_len = max_model_len
-        self.max_num_reqs = max_num_reqs
-
-        self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config)
 
         max_items_per_prompt_by_modality = dict[str, int]()
         max_items_per_batch_by_modality = dict[str, int]()
 
-        max_tokens_by_modality = mm_registry \
-            .get_max_tokens_per_item_by_nonzero_modality(model_config)
-
         for modality, max_tokens in max_tokens_by_modality.items():
             (
                 max_items_per_prompt,
@@ -69,15 +66,14 @@ def __init__(
         self.max_items_per_prompt_by_modality = max_items_per_prompt_by_modality
         self.max_items_per_batch_by_modality = max_items_per_batch_by_modality
 
-    def get_modality_with_max_tokens(self) -> tuple[str, int]:
+    def get_modality_with_max_tokens(self) -> str:
         max_tokens_by_modality = self.max_tokens_by_modality
-        modality, max_tokens = max(max_tokens_by_modality.items(),
-                                   key=lambda item: item[1])
+        modality, _ = max(max_tokens_by_modality.items(), key=lambda x: x[1])
 
-        return modality, max_tokens
+        return modality
 
     def get_encoder_budget(self) -> int:
-        return min(self.max_num_encoder_input_tokens, self.encoder_cache_size)
+        return min(self.encoder_compute_budget, self.encoder_cache_size)
 
     def get_max_items(
         self,
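A small self-contained illustration of the budget rule that get_encoder_budget keeps after the attribute rename (attribute names follow the diff; the numbers below are arbitrary, not vLLM defaults): the usable encoder budget is the smaller of the per-step compute budget and the cache size.

def effective_encoder_budget(encoder_compute_budget: int,
                             encoder_cache_size: int) -> int:
    # The encoder can neither process more tokens per step than the compute
    # budget allows nor cache more tokens than fit in the encoder cache.
    return min(encoder_compute_budget, encoder_cache_size)


# Arbitrary example: an 8192-token compute budget capped by a 4096-token cache.
assert effective_encoder_budget(8192, 4096) == 4096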
