Commit 4c5f632

[Misc] Simplify max tokens in multimodal registry (#27500)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
1 parent b853540 · commit 4c5f632

4 files changed: +17 −38 lines

vllm/multimodal/profiling.py

Lines changed: 5 additions & 2 deletions
@@ -355,7 +355,11 @@ def _get_mm_max_tokens(
             mm_counts=mm_counts,
         )
         if max_tokens_per_item is not None:
-            return max_tokens_per_item
+            return {
+                modality: max_tokens
+                for modality, max_tokens in max_tokens_per_item.items()
+                if mm_counts.get(modality, 0) > 0
+            }
 
         mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
         return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only)
@@ -375,5 +379,4 @@ def get_mm_max_contiguous_tokens(
         This is important to take into account when profiling and
         initializing the encoder cache size.
         """
-
         return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False)
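
The behavioral change in this hunk is that per-item token counts are now filtered against mm_counts, so modalities that are disabled (count 0) or absent no longer appear in the result. A standalone sketch with toy values (plain Python, not vLLM code) of what the new comprehension does:

# Toy sketch of the filtering added above; all values are illustrative.
max_tokens_per_item = {"image": 1024, "video": 8192, "audio": 512}
mm_counts = {"image": 2, "video": 0}  # "video" disabled, "audio" absent

filtered = {
    modality: max_tokens
    for modality, max_tokens in max_tokens_per_item.items()
    if mm_counts.get(modality, 0) > 0
}

assert filtered == {"image": 1024}  # zero-count and absent modalities dropped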

vllm/multimodal/registry.py

Lines changed: 6 additions & 30 deletions
@@ -152,6 +152,7 @@ def get_max_tokens_per_item_by_modality(
         model_config: "ModelConfig",
         *,
         cache: BaseMultiModalProcessorCache | None = None,
+        profiler_limits: Mapping[str, int] | None = None,
     ) -> Mapping[str, int]:
         """
         Get the maximum number of tokens per data item from each modality based
@@ -164,40 +165,15 @@ def get_max_tokens_per_item_by_modality(
         profiler: MultiModalProfiler = MultiModalProfiler(processor)
 
         seq_len = model_config.max_model_len
-        mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
+        profiler_limits = (
+            profiler.get_mm_limits() if profiler_limits is None else profiler_limits
+        )
 
         return profiler.get_mm_max_contiguous_tokens(
             seq_len,
-            {modality: 1 for modality, limit in mm_limits.items() if limit > 0},
-        )
-
-    def get_max_tokens_per_item_by_nonzero_modality(
-        self,
-        model_config: "ModelConfig",
-        *,
-        cache: BaseMultiModalProcessorCache | None = None,
-    ) -> Mapping[str, int]:
-        """
-        Get the maximum number of tokens per data item from each modality based
-        on underlying model configuration, excluding modalities that user
-        explicitly disabled via `limit_mm_per_prompt`.
-
-        Note:
-            This is currently directly used only in V1 for profiling the memory
-            usage of a model.
-        """
-        mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
-        max_tokens_per_item = self.get_max_tokens_per_item_by_modality(
-            model_config,
-            cache=cache,
+            {modality: 1 for modality, limit in profiler_limits.items() if limit > 0},
         )
 
-        return {
-            key: max_tokens_per_mm_item
-            for key, max_tokens_per_mm_item in max_tokens_per_item.items()
-            if mm_limits[key] > 0
-        }
-
     def get_mm_limits_per_prompt(
         self,
         model_config: "ModelConfig",
@@ -369,7 +345,7 @@ def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
         """
         if not model_config.is_encoder_decoder:
             return 0
-        max_tokens = self.get_max_tokens_per_item_by_nonzero_modality(model_config)
+        max_tokens = self.get_max_tokens_per_item_by_modality(model_config)
         if not max_tokens:
             # TODO - this function assumes encoder-decoder models are
             # multimodal. This will need to change when adding support for more
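
With this change, get_max_tokens_per_item_by_nonzero_modality is removed: its filtering now lives inside _get_mm_max_tokens (see the profiling.py hunk above), and get_max_tokens_per_item_by_modality gains an optional profiler_limits mapping that falls back to profiler.get_mm_limits() when the caller supplies nothing. A toy sketch (plain Python, not vLLM code) of how the per-modality counts handed to the profiler are built:

# Toy sketch of the {modality: 1, ...} construction above: every modality
# still enabled (limit > 0) is profiled with exactly one item.
profiler_limits = {"image": 5, "video": 0, "audio": 1}

mm_counts = {m: 1 for m, limit in profiler_limits.items() if limit > 0}

assert mm_counts == {"image": 1, "audio": 1}  # disabled "video" excluded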

vllm/v1/core/encoder_cache_manager.py

Lines changed: 2 additions & 2 deletions
@@ -264,8 +264,8 @@ def compute_encoder_budget(
     from the input sequence.
     """
     if mm_registry.supports_multimodal_inputs(model_config):
-        max_tokens_by_modality = (
-            mm_registry.get_max_tokens_per_item_by_nonzero_modality(model_config)
+        max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
+            model_config
         )
 
         return compute_mm_encoder_budget(

vllm/v1/worker/utils.py

Lines changed: 4 additions & 4 deletions
@@ -42,10 +42,10 @@ def __init__(
 
         self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, cache=cache)
 
-        max_tokens_by_modality = (
-            mm_registry.get_max_tokens_per_item_by_nonzero_modality(
-                model_config, cache=cache
-            )
+        max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
+            model_config,
+            cache=cache,
+            profiler_limits=self.mm_limits,
         )
 
         encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
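
This call site has already computed self.mm_limits, so it now passes that mapping through as profiler_limits rather than letting the registry derive the limits a second time. A minimal sketch of this pass-a-precomputed-default pattern; the function names here are hypothetical stand-ins, and only the optional-argument fallback matches the diff above:

from collections.abc import Mapping

def compute_limits() -> Mapping[str, int]:
    # Stands in for the costly limit computation in the real code path.
    print("computing limits...")
    return {"image": 1, "video": 1}

def max_tokens_per_item(
    limits: Mapping[str, int] | None = None,
) -> Mapping[str, int]:
    # Same fallback shape as the registry diff: recompute only if the
    # caller did not supply a mapping.
    limits = compute_limits() if limits is None else limits
    return {m: 4096 for m, n in limits.items() if n > 0}

limits = compute_limits()            # computed once by the caller
max_tokens_per_item(limits=limits)   # reused; no second computation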
