vllm/multimodal/profiling.py (5 additions, 2 deletions)

@@ -355,7 +355,11 @@ def _get_mm_max_tokens(
             mm_counts=mm_counts,
         )
         if max_tokens_per_item is not None:
-            return max_tokens_per_item
+            return {
+                modality: max_tokens
+                for modality, max_tokens in max_tokens_per_item.items()
+                if mm_counts.get(modality, 0) > 0
+            }
 
         mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
         return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only)

@@ -375,5 +379,4 @@ def get_mm_max_contiguous_tokens(
         This is important to take into account when profiling and
         initializing the encoder cache size.
         """
-
         return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False)
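The change above filters the fast-path result so that modalities requested with a zero count no longer appear in the returned map, keeping it consistent with the dummy-input path below it, which builds its inputs from the same `mm_counts`. A minimal runnable sketch of that filtering (standalone; the helper name and token counts are illustrative, not vLLM code):

```python
from collections.abc import Mapping


def filter_max_tokens(
    max_tokens_per_item: Mapping[str, int],
    mm_counts: Mapping[str, int],
) -> dict[str, int]:
    # Keep only modalities the request can actually contain; a modality
    # missing from mm_counts (or present with count 0) is dropped.
    return {
        modality: max_tokens
        for modality, max_tokens in max_tokens_per_item.items()
        if mm_counts.get(modality, 0) > 0
    }


# Illustrative token counts: video, disabled via a zero count, is dropped.
print(filter_max_tokens({"image": 576, "video": 4096}, {"image": 1, "video": 0}))
# -> {'image': 576}
```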
vllm/multimodal/registry.py (6 additions, 30 deletions)

@@ -152,6 +152,7 @@ def get_max_tokens_per_item_by_modality(
         model_config: "ModelConfig",
         *,
         cache: BaseMultiModalProcessorCache | None = None,
+        profiler_limits: Mapping[str, int] | None = None,
     ) -> Mapping[str, int]:
         """
         Get the maximum number of tokens per data item from each modality based

@@ -164,40 +165,15 @@
         profiler: MultiModalProfiler = MultiModalProfiler(processor)
 
         seq_len = model_config.max_model_len
-        mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
+        profiler_limits = (
+            profiler.get_mm_limits() if profiler_limits is None else profiler_limits
+        )
 
         return profiler.get_mm_max_contiguous_tokens(
             seq_len,
-            {modality: 1 for modality, limit in mm_limits.items() if limit > 0},
-        )
-
-    def get_max_tokens_per_item_by_nonzero_modality(
-        self,
-        model_config: "ModelConfig",
-        *,
-        cache: BaseMultiModalProcessorCache | None = None,
-    ) -> Mapping[str, int]:
-        """
-        Get the maximum number of tokens per data item from each modality based
-        on underlying model configuration, excluding modalities that user
-        explicitly disabled via `limit_mm_per_prompt`.
-
-        Note:
-            This is currently directly used only in V1 for profiling the memory
-            usage of a model.
-        """
-        mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
-        max_tokens_per_item = self.get_max_tokens_per_item_by_modality(
-            model_config,
-            cache=cache,
+            {modality: 1 for modality, limit in profiler_limits.items() if limit > 0},
         )
 
-        return {
-            key: max_tokens_per_mm_item
-            for key, max_tokens_per_mm_item in max_tokens_per_item.items()
-            if mm_limits[key] > 0
-        }
-
     def get_mm_limits_per_prompt(
         self,
         model_config: "ModelConfig",

@@ -369,7 +345,7 @@ def get_encdec_max_encoder_len(self, model_config: "ModelConfig") -> int:
         """
         if not model_config.is_encoder_decoder:
             return 0
-        max_tokens = self.get_max_tokens_per_item_by_nonzero_modality(model_config)
+        max_tokens = self.get_max_tokens_per_item_by_modality(model_config)
         if not max_tokens:
             # TODO - this function assumes encoder-decoder models are
             # multimodal. This will need to change when adding support for more
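With `get_max_tokens_per_item_by_nonzero_modality` removed, its zero-limit filtering now lives in the surviving method through the new optional `profiler_limits` argument, which falls back to `profiler.get_mm_limits()` when the caller passes nothing. A self-contained sketch of that default-or-override pattern (the helper and hard-coded limits are stand-ins, not the real registry):

```python
from collections.abc import Mapping


def _default_limits() -> Mapping[str, int]:
    # Stand-in for profiler.get_mm_limits(); real limits come from the
    # model's processing info.
    return {"image": 1, "video": 1}


def max_tokens_per_item_by_modality(
    *,
    profiler_limits: Mapping[str, int] | None = None,
) -> Mapping[str, int]:
    # New keyword: callers holding limits pass them in; others get the default.
    profiler_limits = (
        _default_limits() if profiler_limits is None else profiler_limits
    )
    # Profile one item per modality, skipping zero-limit (disabled) ones.
    mm_counts = {m: 1 for m, limit in profiler_limits.items() if limit > 0}
    fake_max = {"image": 576, "video": 4096}  # illustrative token counts
    return {m: fake_max[m] for m in mm_counts}


print(max_tokens_per_item_by_modality())  # {'image': 576, 'video': 4096}
print(max_tokens_per_item_by_modality(profiler_limits={"image": 1, "video": 0}))
# -> {'image': 576}: a disabled modality never reaches the profiler
```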
vllm/v1/core/encoder_cache_manager.py (2 additions, 2 deletions)

@@ -264,8 +264,8 @@ def compute_encoder_budget(
         from the input sequence.
     """
     if mm_registry.supports_multimodal_inputs(model_config):
-        max_tokens_by_modality = (
-            mm_registry.get_max_tokens_per_item_by_nonzero_modality(model_config)
+        max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
+            model_config
         )
 
         return compute_mm_encoder_budget(
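`compute_encoder_budget` consumes the per-item maxima from the registry. As rough intuition (a simplified sketch, not `compute_mm_encoder_budget`'s actual formula, which also weighs scheduler and cache configuration): the encoder cache must at least fit the largest single item of any enabled modality.

```python
def rough_encoder_budget(max_tokens_by_modality: dict[str, int]) -> int:
    # Text-only model (or every modality disabled): no encoder cache needed.
    if not max_tokens_by_modality:
        return 0
    # The cache must hold at least the biggest single item's embeddings.
    return max(max_tokens_by_modality.values())


print(rough_encoder_budget({"image": 576, "video": 4096}))  # -> 4096
print(rough_encoder_budget({}))                             # -> 0
```

Because disabled modalities are now filtered out upstream, a zero-limit modality can no longer inflate this maximum.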
vllm/v1/worker/utils.py (4 additions, 4 deletions)

@@ -42,10 +42,10 @@ def __init__(
 
         self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, cache=cache)
 
-        max_tokens_by_modality = (
-            mm_registry.get_max_tokens_per_item_by_nonzero_modality(
-                model_config, cache=cache
-            )
+        max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
+            model_config,
+            cache=cache,
+            profiler_limits=self.mm_limits,
        )
 
         encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
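The worker already derives `self.mm_limits` a few lines earlier, so forwarding it as `profiler_limits` spares the registry a second derivation. A small standalone sketch of that reuse (hypothetical helper names; the print stands in for the repeated work):

```python
from collections.abc import Mapping


def derive_mm_limits() -> Mapping[str, int]:
    # Stand-in for mm_registry.get_mm_limits_per_prompt(...); in the real
    # code this walks model and processor config, so repeating it is waste.
    print("deriving limits")
    return {"image": 1, "video": 0}


def max_tokens_by_modality(
    profiler_limits: Mapping[str, int] | None = None,
) -> Mapping[str, int]:
    limits = derive_mm_limits() if profiler_limits is None else profiler_limits
    return {m: 1 for m, v in limits.items() if v > 0}


mm_limits = derive_mm_limits()     # printed once: computed by the worker
max_tokens_by_modality(mm_limits)  # reused; no second "deriving limits"
```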