2 changes: 1 addition & 1 deletion vllm/v1/worker/gpu_model_runner.py
@@ -1272,7 +1272,7 @@ def profile_run(self) -> None:

# For profile, have maximum num_reqs and that collectively have
# maximum num_tokens.
-num_reqs = self.scheduler_config.max_num_seqs
+num_reqs = self.max_num_reqs
num_tokens = self.max_num_tokens
min_tokens_per_req = num_tokens // num_reqs

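As context for the change above: profile_run sizes a dummy batch from these two limits. A minimal sketch of that sizing, assuming max_num_reqs mirrors scheduler_config.max_num_seqs and max_num_tokens mirrors the scheduler's max_num_batched_tokens (the numbers and the remainder handling are illustrative, not the actual implementation):

# Illustrative sketch only; attribute names follow the diff, values are made up.
max_num_reqs = 256        # assumed to mirror scheduler_config.max_num_seqs
max_num_tokens = 8192     # assumed to mirror max_num_batched_tokens

num_reqs = max_num_reqs
num_tokens = max_num_tokens
min_tokens_per_req = num_tokens // num_reqs      # 32 tokens per dummy request

# Distribute the tokens so the dummy batch still totals num_tokens.
num_scheduled_tokens = [min_tokens_per_req] * num_reqs
num_scheduled_tokens[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens) == num_tokens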
101 changes: 101 additions & 0 deletions vllm/v1/worker/tpu_model_runner.py
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
import gc
import time
from typing import TYPE_CHECKING, Optional, cast
from unittest.mock import patch
@@ -737,6 +738,106 @@ def _dummy_run(
inputs_embeds=inputs_embeds,
)

def profile_run(
self,
kv_caches,
num_tokens: int,
) -> None:
# Profile with multimodal encoder & encoder cache.
# TODO: handle encoder-decoder models once we support them.
if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
and self.encoder_cache_size > 0):

# NOTE: Currently model is profiled with a single non-text
# modality with the max possible input tokens even when
# it supports multiple.
max_tokens_by_modality_dict = (
MULTIMODAL_REGISTRY.
get_max_tokens_per_item_by_nonzero_modality(self.model_config))
dummy_data_modality, max_tokens_per_mm_item = max(
max_tokens_by_modality_dict.items(), key=lambda item: item[1])
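# Illustrative (hypothetical) example of the selection above: with
# max_tokens_by_modality_dict == {"image": 576, "audio": 188}, the max()
# over values yields dummy_data_modality == "image" and
# max_tokens_per_mm_item == 576.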

# Check how many items of this modality can be supported by
# the encoder budget.
encoder_budget = min(self.max_num_encoder_input_tokens,
self.encoder_cache_size)

# max_num_mm_items_encoder_budget = cdiv(encoder_budget,
# max_tokens_per_mm_item)

# # Check how many items of this modality can be supported by
# # the decoder budget.
# max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
# self.model_config)[dummy_data_modality]

# # NOTE: We do not consider max_num_batched_tokens on purpose
# # because the multimodal embeddings can be generated in advance
# # and chunked prefilled.
# max_num_mm_items_decoder_budget = self.max_num_reqs * \
# max_mm_items_per_req

# max_num_mm_items = min(max_num_mm_items_encoder_budget,
# max_num_mm_items_decoder_budget)
# TODO(mgoin): Support batching when new kernel lands
max_num_mm_items = 1

logger.info(
"Encoder cache will be initialized with a budget of %s tokens,"
" and profiled with %s %s items of the maximum feature size.",
encoder_budget, max_num_mm_items, dummy_data_modality)

# Create dummy batch of multimodal inputs.
dummy_request_data = self.input_registry.dummy_data_for_profiling(
model_config=self.model_config,
seq_len=self.max_num_tokens,
mm_registry=self.mm_registry,
)
dummy_mm_data = dummy_request_data.multi_modal_data

# Dummy data definition in V0 may contain multiple multimodal items
# (e.g., multiple images) for a single request, therefore here we
# always replicate the first item max_num_mm_items times, since in V1
# they are scheduled to be processed separately.

assert isinstance(dummy_mm_data, MultiModalKwargs), (
"Expected dummy multimodal data to be of type "
f"MultiModalKwargs, got {type(dummy_mm_data)=} instead. "
"This is most likely due to the model not having a merged "
"processor.")

# When models have a merged processor, their dummy data is
# already batched `MultiModalKwargs`, therefore we take the first
# `MultiModalKwargsItem` from the desired modality to profile on.
dummy_mm_item = dummy_mm_data.get_item(
modality=dummy_data_modality, item_index=0)
dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])

batched_dummy_mm_inputs = MultiModalKwargs.batch(
[dummy_mm_kwargs] * max_num_mm_items)
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
batched_dummy_mm_inputs, device=self.device)

# Run multimodal encoder.
dummy_encoder_outputs = self.model.get_multimodal_embeddings(
**batched_dummy_mm_inputs)
assert len(dummy_encoder_outputs) == max_num_mm_items, (
"Expected dimension 0 of encoder outputs to match the number "
f"of multimodal data items: {max_num_mm_items}, got "
f"{len(dummy_encoder_outputs)=} instead. This is most likely "
"due to the 'get_multimodal_embeddings' method of the model "
"not implemented correctly.")

# Cache the dummy encoder outputs.
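# The temporary "tmp" key keeps the outputs referenced so their memory
# counts toward the profiled peak; the cache is cleared again once
# profiling completes (see below).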
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

# Trigger compilation for general shape.
self._dummy_run(kv_caches, num_tokens)

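# mark_step() dispatches the lazily traced XLA graph and
# wait_device_ops() blocks until all outstanding device work has
# finished, so the memory measured by the caller afterwards reflects
# the complete profiling pass.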
xm.mark_step()
xm.wait_device_ops()
self.encoder_cache.clear()
gc.collect()

def capture_model(self) -> None:
"""Compile the model."""

2 changes: 1 addition & 1 deletion vllm/v1/worker/tpu_worker.py
@@ -134,7 +134,7 @@ def determine_available_memory(self) -> int:
self.vllm_config.compilation_config.static_forward_context,
runner_kv_caches)

-self.model_runner._dummy_run(
+self.model_runner.profile_run(
runner_kv_caches,
num_tokens=self.scheduler_config.max_num_batched_tokens,
)
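For reference, a rough sketch of how the peak usage observed during this profiling call could feed into the available-memory estimate that determine_available_memory returns. This is a hypothetical helper, not the actual vLLM implementation; the function name, parameters, and the 90% utilization default are assumptions:

# Hypothetical helper: subtract the profiled peak from the usable fraction
# of device memory to get the bytes left over for KV cache blocks.
def estimate_kv_cache_bytes(total_device_bytes: int,
                            peak_profiled_bytes: int,
                            memory_utilization: float = 0.9) -> int:
    usable_bytes = int(total_device_bytes * memory_utilization)
    return max(0, usable_bytes - peak_profiled_bytes)

# Example: 16 GiB of HBM, 6 GiB peak during profiling, 90% utilization
# target -> roughly 8.4 GiB left for the KV cache.
print(estimate_kv_cache_bytes(16 * 1024**3, 6 * 1024**3))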