[VLM] Calculate maximum number of multi-modal tokens by model (vllm-p…
DarkLight1337 authored and jimpang committed Jul 8, 2024
1 parent 795dd5d commit 0a33b42
Showing 12 changed files with 260 additions and 90 deletions.
68 changes: 48 additions & 20 deletions docs/source/dev/multimodal/adding_multimodal_model.rst
@@ -51,40 +51,68 @@ As usual, follow :ref:`these steps <adding_a_new_model>` to implement the model
2. Register input mappers
-------------------------

- For each modality type to support, decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_input_mapper <vllm.multimodal.MultiModalRegistry.register_input_mapper>`.
+ For each modality type that the model accepts as input, decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_input_mapper <vllm.multimodal.MultiModalRegistry.register_input_mapper>`.
This decorator accepts a function that maps multi-modal inputs to the keyword arguments you have previously defined in :meth:`~torch.nn.Module.forward`.

.. code-block:: diff

      from vllm.model_executor.models.interfaces import SupportsVision
    + from vllm.multimodal import MULTIMODAL_REGISTRY

    + @MULTIMODAL_REGISTRY.register_image_input_mapper()
      class YourModelForImage2Seq(nn.Module, SupportsVision):
A default mapper is available for each modality in the core vLLM library. This input mapper will be used if you do not provide your own function.
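As an aside, a custom mapper might look roughly like the sketch below. This is an illustration only and not part of this commit: the function name, the returned keyword name and the exact mapper signature are assumptions, so check the registry API of your vLLM version.

.. code-block:: python

    # Hypothetical custom image input mapper -- an illustrative sketch only.
    # The signature (InputContext, raw data) -> dict of forward() kwargs is an
    # assumption; consult the MultiModalRegistry docs for the exact types.
    import numpy as np
    import torch
    from PIL import Image

    from vllm.inputs import InputContext


    def my_image_input_mapper(ctx: InputContext, data: Image.Image) -> dict:
        image = data.convert("RGB").resize((336, 336))
        arr = np.asarray(image, dtype=np.float32) / 255.0       # (H, W, 3)
        pixel_values = torch.from_numpy(arr).permute(2, 0, 1)   # (3, H, W)
        return {"pixel_values": pixel_values.unsqueeze(0)}      # (1, 3, H, W)

    # Assumed usage: pass the function to the decorator above the model class,
    # e.g. @MULTIMODAL_REGISTRY.register_image_input_mapper(my_image_input_mapper)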

.. seealso::
    :ref:`input_processing_pipeline`


- 3. (Optional) Register dummy data
+ 3. Register maximum number of multimodal tokens
+ ----------------------------------------------------------

For each modality type that the model accepts as input, calculate the maximum possible number of tokens
and register it via :meth:`MULTIMODAL_REGISTRY.register_max_multimodal_tokens <vllm.multimodal.MultiModalRegistry.register_max_multimodal_tokens>`.

.. code-block:: diff

      from vllm.inputs import INPUT_REGISTRY
      from vllm.model_executor.models.interfaces import SupportsVision
      from vllm.multimodal import MULTIMODAL_REGISTRY

      @MULTIMODAL_REGISTRY.register_image_input_mapper()
    + @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
      @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
      class YourModelForImage2Seq(nn.Module, SupportsVision):
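To make ``<your_calculation>`` concrete, here is an illustrative sketch (not taken from this commit) for a model whose vision tower is a fixed-size ViT. The ``InputContext`` signature mirrors ``get_max_llava_image_tokens`` further down in this diff, while the config fields are assumptions:

.. code-block:: python

    # Hypothetical <your_calculation>; assumes a CLIP/ViT-style vision config
    # exposing `image_size` and `patch_size`.
    from transformers import PretrainedConfig

    from vllm.inputs import InputContext


    def get_max_your_model_image_tokens(ctx: InputContext) -> int:
        hf_config = ctx.get_hf_config(PretrainedConfig)
        vision_config = hf_config.vision_config

        patches_per_side = vision_config.image_size // vision_config.patch_size
        return patches_per_side * patches_per_side  # e.g. (336 // 14) ** 2 = 576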
Here are some examples:

- Image inputs (static feature size): `LLaVA-1.5 Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava.py>`__
- Image inputs (dynamic feature size): `LLaVA-NeXT Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava_next.py>`__

.. seealso::
    :ref:`input_processing_pipeline`


4. (Optional) Register dummy data
---------------------------------

During startup, dummy data is passed to the vLLM model to allocate memory. This only consists of text input by default, which may not be applicable to multi-modal models.
In such cases, you can define your own dummy data by registering a factory method via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_dummy_data>`.

.. code-block:: diff

      from vllm.inputs import INPUT_REGISTRY
      from vllm.model_executor.models.interfaces import SupportsVision
      from vllm.multimodal import MULTIMODAL_REGISTRY

      @MULTIMODAL_REGISTRY.register_image_input_mapper()
      @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
    + @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
      class YourModelForImage2Seq(nn.Module, SupportsVision):
.. note::
    The dummy data should have the maximum possible number of multi-modal tokens, as described in the previous step.
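To illustrate (a hedged sketch, not part of this commit): a factory typically pairs a sequence filled with image-placeholder token IDs with a maximum-size dummy image. The token ID, image size and multi-modal payload format below are assumptions and vary between models and vLLM versions:

.. code-block:: python

    # Hypothetical dummy-data factory -- illustration only.
    from PIL import Image

    from vllm.inputs import InputContext
    from vllm.sequence import SequenceData

    IMAGE_TOKEN_ID = 32000    # assumed placeholder token ID
    MAX_IMAGE_TOKENS = 576    # should match the value registered in step 3
    IMAGE_SIZE = 336          # assumed input resolution of the vision tower


    def dummy_data_for_your_model(ctx: InputContext, seq_len: int):
        # Front-load the sequence with the worst-case number of image tokens
        # so memory profiling accounts for them.
        token_ids = [IMAGE_TOKEN_ID] * MAX_IMAGE_TOKENS
        token_ids += [0] * (seq_len - MAX_IMAGE_TOKENS)
        seq_data = SequenceData(token_ids)

        # A maximum-size dummy image; the expected payload type depends on the
        # vLLM version (a PIL image is shown here purely as an example).
        dummy_image = Image.new("RGB", (IMAGE_SIZE, IMAGE_SIZE), color=0)
        return seq_data, dummy_image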

Here are some examples:

@@ -95,7 +123,7 @@ Here are some examples:
:ref:`input_processing_pipeline`


- 4. (Optional) Register input processor
+ 5. (Optional) Register input processor
--------------------------------------

Sometimes, there is a need to process inputs at the :class:`~vllm.LLMEngine` level before they are passed to the model executor.
@@ -104,15 +132,15 @@ You can register input processors via :meth:`INPUT_REGISTRY.register_input_proce

.. code-block:: diff

      from vllm.inputs import INPUT_REGISTRY
      from vllm.model_executor.models.interfaces import SupportsVision
      from vllm.multimodal import MULTIMODAL_REGISTRY

      @MULTIMODAL_REGISTRY.register_image_input_mapper()
      @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
      @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
    + @INPUT_REGISTRY.register_input_processor(<your_input_processor>)
      class YourModelForImage2Seq(nn.Module, SupportsVision):
A common use case of input processors is inserting placeholder tokens to leverage the vLLM framework for attention mask generation.
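To make that concrete, the sketch below (illustrative only, not part of this commit) expands a single image placeholder token into the full run of placeholders sized in step 3; the token ID, the count and the exact ``LLMInputs`` fields are assumptions:

.. code-block:: python

    # Hypothetical input processor -- illustration only.
    from vllm.inputs import InputContext, LLMInputs

    IMAGE_TOKEN_ID = 32000    # assumed placeholder token ID
    MAX_IMAGE_TOKENS = 576    # assumed feature size per image


    def input_processor_for_your_model(ctx: InputContext,
                                       llm_inputs: LLMInputs) -> LLMInputs:
        if llm_inputs.get("multi_modal_data") is None:
            return llm_inputs  # text-only request, nothing to do

        # Repeat the placeholder so the scheduler and attention masks already
        # reserve space for the image embeddings.
        new_token_ids = []
        for token_id in llm_inputs["prompt_token_ids"]:
            if token_id == IMAGE_TOKEN_ID:
                new_token_ids.extend([IMAGE_TOKEN_ID] * MAX_IMAGE_TOKENS)
            else:
                new_token_ids.append(token_id)

        return LLMInputs(prompt_token_ids=new_token_ids,
                         prompt=llm_inputs.get("prompt"),
                         multi_modal_data=llm_inputs["multi_modal_data"])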
Here are some examples:
18 changes: 4 additions & 14 deletions docs/source/models/vlm.rst
@@ -25,13 +25,8 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM``
.. important::
    We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
-   the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified, and internally we will construct data structures for
-   every model to perform profiling with.
-
-   This work is still ongoing. In the meantime, we internally hardcode ``image_feature_size = 3000`` through
-   :meth:`MULTIMODAL_REGISTRY.get_num_input_tokens <vllm.multimodal.MultiModalRegistry.get_num_input_tokens>`
-   for every model to be conservative in terms of GPU memory consumption. This hardcoded value will be replaced
-   with a more accurate profiling strategy in the future.
+   the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified as we now calculate that
+   internally for each model.


To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:
@@ -104,13 +99,8 @@ Below is an example on how to launch the same ``llava-hf/llava-1.5-7b-hf`` with
.. important::
    We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
-   the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified, and internally we will construct data structures for
-   every model to perform profiling with.
-
-   This work is still ongoing. In the meantime, we internally hardcode ``image_feature_size = 3000`` through
-   :meth:`MULTIMODAL_REGISTRY.get_num_input_tokens <vllm.multimodal.MultiModalRegistry.get_num_input_tokens>`
-   for every model to be conservative in terms of GPU memory consumption. This hardcoded value will be replaced
-   with a more accurate profiling strategy in the future.
+   the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified as we now calculate that
+   internally for each model.

To consume the server, you can use the OpenAI client like in the example below:
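The snippet itself is collapsed in this view; a typical request with the OpenAI Python client looks roughly like the sketch below (illustrative only; the base URL and image URL are placeholders):

.. code-block:: python

    # Illustrative only; the actual example from vlm.rst is collapsed above.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    chat_response = client.chat.completions.create(
        model="llava-hf/llava-1.5-7b-hf",
        messages=[{
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                {"type": "image_url",
                 "image_url": {"url": "https://example.com/some-image.jpg"}},
            ],
        }],
    )
    print(chat_response.choices[0].message.content)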

2 changes: 1 addition & 1 deletion vllm/inputs/registry.py
@@ -51,7 +51,7 @@ def get_hf_config(self, hf_config_type: Type[C]) -> C:
    additionally checking its type.

    Raises:
-       ValueError: If the model is not of the specified type.
+       TypeError: If the model is not of the specified type.
    """

    hf_config = self.model_config.hf_config
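For context (an illustration, not part of this diff): model code reaches this helper through the InputContext, as the LLaVA changes further down do, and relies on the type check to fail fast.

# Illustrative usage, mirroring llava.py later in this commit.
hf_config = ctx.get_hf_config(LlavaConfig)  # returned config verified to be a LlavaConfig
# A mismatched configuration type now raises TypeError rather than ValueError.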
4 changes: 4 additions & 0 deletions vllm/model_executor/models/clip.py
@@ -35,6 +35,10 @@ def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
patch_size=hf_config.patch_size)


def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
    return get_clip_image_feature_size(hf_config)
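For reference (an illustration, not part of this diff): the CLIP ViT-L/14 tower at 336x336 used by LLaVA-1.5 has image_size=336 and patch_size=14, so this evaluates to:

# (336 // 14) ** 2 == 24 * 24 == 576 image tokens per image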


def dummy_seq_data_for_clip(
    hf_config: CLIPVisionConfig,
    seq_len: int,
14 changes: 13 additions & 1 deletion vllm/model_executor/models/llava.py
@@ -21,7 +21,7 @@
from vllm.sequence import IntermediateTensors, SamplerOutput

from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
-                    input_processor_for_clip)
+                    get_max_clip_image_tokens, input_processor_for_clip)
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings

@@ -62,6 +62,17 @@ class LlavaImagePixelInputs(TypedDict):
LlavaImageInputs = LlavaImagePixelInputs


def get_max_llava_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        return get_max_clip_image_tokens(vision_config)

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def dummy_data_for_llava(ctx: InputContext, seq_len: int):
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config
@@ -102,6 +113,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsVision):
12 changes: 12 additions & 0 deletions vllm/model_executor/models/llava_next.py
@@ -127,6 +127,17 @@ def get_llava_next_image_feature_size(
    raise NotImplementedError(msg)


def get_max_llava_next_image_tokens(ctx: InputContext):
    # Result in the max possible feature size (2x2 grid of 336x336px tiles)
    dummy_height = dummy_width = 448
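    # (Illustrative note, not part of this commit: with LLaVA-NeXT's default
    #  grid pinpoints, a 448x448 input selects the 672x672 target resolution,
    #  i.e. a 2x2 arrangement of 336x336 tiles, which maximizes the token count.)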

    return get_llava_next_image_feature_size(
        ctx.get_hf_config(LlavaNextConfig),
        input_height=dummy_height,
        input_width=dummy_width,
    )


def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
    hf_config = ctx.get_hf_config(LlavaNextConfig)
    vision_config = hf_config.vision_config
@@ -198,6 +209,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
12 changes: 12 additions & 0 deletions vllm/model_executor/models/phi3v.py
@@ -321,6 +321,12 @@ def get_phi3v_image_feature_size(
+ (new_height // 336 + 1) * 12


def get_max_phi3v_image_tokens(ctx: InputContext):
    # Result in the max possible feature size (h:w = 16:1)
    dummy_height, dummy_width = 8000, 50

    return get_phi3v_image_feature_size(
        ctx.get_hf_config(PretrainedConfig),
        input_height=dummy_height,
        input_width=dummy_width,
    )


def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
    # Result in the max possible feature size (h:w = 16:1)
    dummy_height, dummy_width = 8000, 50
@@ -429,6 +440,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsVision):