[VLM] Calculate maximum number of multi-modal tokens by model #6121

Merged (8 commits) on Jul 4, 2024
Changes from 3 commits
68 changes: 48 additions & 20 deletions docs/source/dev/multimodal/adding_multimodal_model.rst
@@ -51,40 +51,68 @@ As usual, follow :ref:`these steps <adding_a_new_model>` to implement the model
2. Register input mappers
-------------------------

For each modality type that the model accepts as input, decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_input_mapper <vllm.multimodal.MultiModalRegistry.register_input_mapper>`.
This decorator accepts a function that maps multi-modal inputs to the keyword arguments you have previously defined in :meth:`~torch.nn.Module.forward`.

.. code-block:: diff

from vllm.model_executor.models.interfaces import SupportsVision
+ from vllm.multimodal import MULTIMODAL_REGISTRY

+ @MULTIMODAL_REGISTRY.register_image_input_mapper()
class YourModelForImage2Seq(nn.Module, SupportsVision):

A default mapper is available for each modality in the core vLLM library. This input mapper will be used if you do not provide your own function.
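
For illustration, a custom mapper is simply a callable that receives an ``InputContext`` and the raw modality data and returns the keyword arguments for ``forward``. The following is a minimal sketch, assuming a dict-style return and placeholder preprocessing:

.. code-block:: python

    import torch

    from vllm.inputs import InputContext


    def your_image_input_mapper(ctx: InputContext, data):
        # Placeholder preprocessing; a real mapper would typically run the
        # HuggingFace image processor configured for this model.
        pixel_values = torch.as_tensor(data)

        # Keyword arguments consumed by your model's ``forward``.
        return {"pixel_values": pixel_values}

The mapper is then passed to the decorator, e.g. ``@MULTIMODAL_REGISTRY.register_image_input_mapper(your_image_input_mapper)``.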

.. seealso::
:ref:`input_processing_pipeline`


3. Register maximum number of multimodal tokens
----------------------------------------------------------

For each modality type that the model accepts as input, calculate the maximum possible number of tokens
and register it via :meth:`MULTIMODAL_REGISTRY.register_max_multimodal_tokens <vllm.multimodal.MultiModalRegistry.register_max_multimodal_tokens>`.

.. code-block:: diff

from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.multimodal import MULTIMODAL_REGISTRY

@MULTIMODAL_REGISTRY.register_image_input_mapper()
+ @MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)

A Contributor commented on this line:
IMO the relationship between register_max_tokens and register_dummy_data is a bit intricate. There needs to be a certain level of consistency here, which is hard to get right. Should we mention something here?

Copy link
Member Author

@DarkLight1337 DarkLight1337 Jul 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I currently have a note in register_dummy_data that mentions it should use the max number of tokens from each modality. Is that sufficient?

A Member replied:
IMO the two should be tied together for consistency - see my comment below in phi3v.py.

@INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
class YourModelForImage2Seq(nn.Module, SupportsVision):

Here are some examples:

- Image inputs (static feature size): `LLaVA-1.5 Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava.py>`__
- Image inputs (dynamic feature size): `LLaVA-NeXT Model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llava_next.py>`__
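
For instance, a model with a fixed-resolution vision encoder might register a calculation along the lines of this sketch (the resolution, patch size, and function name are illustrative). Since the registered value may be either an ``int`` or a callable taking an ``InputContext``, a plain constant works just as well:

.. code-block:: python

    from vllm.inputs import InputContext


    def get_max_your_model_image_tokens(ctx: InputContext) -> int:
        # Illustrative fixed-resolution encoder: one token per patch of a
        # 336x336 image split into 14x14 patches -> (336 // 14) ** 2 = 576.
        image_size, patch_size = 336, 14
        return (image_size // patch_size) ** 2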

.. seealso::
:ref:`input_processing_pipeline`


4. (Optional) Register dummy data
---------------------------------

During startup, dummy data is passed to the vLLM model to allocate memory. This only consists of text input by default, which may not be applicable to multi-modal models.
In such cases, you can define your own dummy data by registering a factory method via :meth:`INPUT_REGISTRY.register_dummy_data <vllm.inputs.registry.InputRegistry.register_dummy_data>`.

.. code-block:: diff

from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.multimodal import MULTIMODAL_REGISTRY

@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
+ @INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
class YourModelForImage2Seq(nn.Module, SupportsVision):

.. note::
The dummy data should have the maximum possible number of multi-modal tokens, as described in the previous step.
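
As a rough sketch (assuming the factory returns a pair of sequence data and multi-modal data; the token id, sizes, and exact types are illustrative and should follow your model and vLLM version):

.. code-block:: python

    from PIL import Image

    from vllm.inputs import InputContext
    from vllm.sequence import SequenceData

    IMAGE_TOKEN_ID = 32000         # illustrative placeholder token id
    MAX_IMAGE_TOKENS = 576         # illustrative; keep in sync with step 3
    DUMMY_IMAGE_SIZE = (336, 336)  # illustrative maximum input resolution


    def dummy_data_for_your_model(ctx: InputContext, seq_len: int):
        # Fill the sequence with the maximum number of image placeholder
        # tokens from step 3, then pad with zeros up to ``seq_len``.
        token_ids = [IMAGE_TOKEN_ID] * MAX_IMAGE_TOKENS
        token_ids += [0] * max(0, seq_len - MAX_IMAGE_TOKENS)
        seq_data = SequenceData(token_ids)

        # Pair it with a dummy image at the largest supported resolution.
        image = Image.new("RGB", DUMMY_IMAGE_SIZE, color=0)
        return seq_data, {"image": image}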

Here are some examples:

@@ -95,7 +123,7 @@ Here are some examples:
:ref:`input_processing_pipeline`


5. (Optional) Register input processor
--------------------------------------

Sometimes, there is a need to process inputs at the :class:`~vllm.LLMEngine` level before they are passed to the model executor.
@@ -104,15 +132,15 @@ You can register input processors via :meth:`INPUT_REGISTRY.register_input_proce

.. code-block:: diff

from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.interfaces import SupportsVision
from vllm.multimodal import MULTIMODAL_REGISTRY

@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(<your_calculation>)
@INPUT_REGISTRY.register_dummy_data(<your_dummy_data_factory>)
+ @INPUT_REGISTRY.register_input_processor(<your_input_processor>)
class YourModelForImage2Seq(nn.Module, SupportsVision):

A common use case of input processors is inserting placeholder tokens to leverage the vLLM framework for attention mask generation.
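
As a hedged sketch of that pattern (the image token id, repeat count, and field handling below are illustrative assumptions, not a prescribed API):

.. code-block:: python

    from vllm.inputs import InputContext, LLMInputs

    IMAGE_TOKEN_ID = 32000   # illustrative placeholder token id
    MAX_IMAGE_TOKENS = 576   # illustrative; match the step 3 calculation


    def input_processor_for_your_model(ctx: InputContext, llm_inputs: LLMInputs):
        multi_modal_data = llm_inputs.get("multi_modal_data")
        if multi_modal_data is None:
            # Text-only prompt: nothing to expand.
            return llm_inputs

        # Expand each image placeholder into one token per image feature so
        # that attention mask generation sees the correct sequence length.
        new_token_ids = []
        for token_id in llm_inputs["prompt_token_ids"]:
            if token_id == IMAGE_TOKEN_ID:
                new_token_ids.extend([IMAGE_TOKEN_ID] * MAX_IMAGE_TOKENS)
            else:
                new_token_ids.append(token_id)

        return LLMInputs(prompt_token_ids=new_token_ids,
                         prompt=llm_inputs.get("prompt"),
                         multi_modal_data=multi_modal_data)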
Here are some examples:
18 changes: 4 additions & 14 deletions docs/source/models/vlm.rst
@@ -25,13 +25,8 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM``

.. important::
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified as we now calculate that
internally for each model.


To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:
@@ -104,13 +99,8 @@ Below is an example on how to launch the same ``llava-hf/llava-1.5-7b-hf`` with

.. important::
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified as we now calculate that
internally for each model.

To consume the server, you can use the OpenAI client like in the example below:

4 changes: 4 additions & 0 deletions vllm/model_executor/models/clip.py
@@ -35,6 +35,10 @@ def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
patch_size=hf_config.patch_size)


def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
return get_clip_image_feature_size(hf_config)
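
# Note: for a CLIP-style encoder this bound is fully determined by the config,
# roughly (image_size // patch_size) ** 2 patch positions; e.g. a 336px image
# with 14px patches yields 24 * 24 = 576 patches.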


def dummy_seq_data_for_clip(
hf_config: CLIPVisionConfig,
seq_len: int,
14 changes: 13 additions & 1 deletion vllm/model_executor/models/llava.py
@@ -21,7 +21,7 @@
from vllm.sequence import IntermediateTensors, SamplerOutput

from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_max_clip_image_tokens, input_processor_for_clip)
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings

@@ -62,6 +62,17 @@ class LlavaImagePixelInputs(TypedDict):
LlavaImageInputs = LlavaImagePixelInputs


def get_max_llava_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config

if isinstance(vision_config, CLIPVisionConfig):
return get_max_clip_image_tokens(vision_config)

msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)


def dummy_data_for_llava(ctx: InputContext, seq_len: int):
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
@@ -102,6 +113,7 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsVision):
12 changes: 12 additions & 0 deletions vllm/model_executor/models/llava_next.py
@@ -127,6 +127,17 @@ def get_llava_next_image_feature_size(
raise NotImplementedError(msg)


def get_max_llava_next_image_tokens(ctx: InputContext):
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
dummy_height = dummy_width = 448

return get_llava_next_image_feature_size(
ctx.get_hf_config(LlavaNextConfig),
input_height=dummy_height,
input_width=dummy_width,
)


def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
hf_config = ctx.get_hf_config(LlavaNextConfig)
vision_config = hf_config.vision_config
@@ -198,6 +209,7 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
12 changes: 12 additions & 0 deletions vllm/model_executor/models/phi3v.py
@@ -321,6 +321,17 @@ def get_phi3v_image_feature_size(
+ (new_height // 336 + 1) * 12


def get_max_phi3v_image_tokens(ctx: InputContext):

A Member commented on this line:
get_max_phi3v_image_tokens and dummy_data_for_phi3v are both based on dummy_height, dummy_width = 8000, 50, so we should make these module-level constants in this file for consistency. I think this will suffice for now, and in the future we can establish a more structured protocol between the multi-modal feature size and the dummy data.

@ywang96 (Member) replied on Jul 4, 2024:
I've made #6146 to address this.
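
For reference, the suggested refactor is roughly this sketch (constant names are illustrative; see #6146 for the actual change):

    # Shared module-level constants read by both helpers instead of
    # repeating the literals in each function.
    MAX_IMAGE_FEATURE_SIZE_HEIGHT = 8000
    MAX_IMAGE_FEATURE_SIZE_WIDTH = 50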

# Result in the max possible feature size (h:w = 16:1)
dummy_height, dummy_width = 8000, 50

return get_phi3v_image_feature_size(
ctx.get_hf_config(PretrainedConfig),
input_height=dummy_height,
input_width=dummy_width,
)


def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
# Result in the max possible feature size (h:w = 16:1)
dummy_height, dummy_width = 8000, 50
@@ -429,6 +440,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsVision):
65 changes: 65 additions & 0 deletions vllm/multimodal/base.py
@@ -101,6 +101,10 @@ class MultiModalDataBuiltins(TypedDict, total=False):
:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers."""

MultiModalTokensCalc = Union[int, Callable[[InputContext], int]]
"""Calculate the maximum number of multimodal tokens input to the language
model. This does not include the tokens that correspond to the input text."""

N = TypeVar("N", bound=Type[nn.Module])


@@ -117,6 +121,7 @@ class MultiModalPlugin(ABC):

def __init__(self) -> None:
self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {}
self._max_mm_tokens: Dict[Type[nn.Module], MultiModalTokensCalc] = {}

@abstractmethod
def get_data_key(self) -> str:
@@ -188,3 +193,63 @@ def map_input(self, model_config: ModelConfig,
f"model class {model_cls.__name__}.")

return mapper(InputContext(model_config), data)

def register_max_multimodal_tokens(
self,
max_mm_tokens: MultiModalTokensCalc,
):
"""
Register the maximum number of multi-modal tokens input to the
language model for a model class.

See also:
:ref:`adding_a_new_multimodal_model`
"""

def wrapper(model_cls: N) -> N:
if model_cls in self._max_mm_tokens:
logger.warning(
"Model class %s already calculates maximum number of "
"tokens in %s. It is overwritten by the new one.",
model_cls, self)

self._max_mm_tokens[model_cls] = max_mm_tokens

return model_cls

return wrapper

def get_max_multimodal_tokens(
self,
model_config: ModelConfig,
default: int,
):
"""
Get the maximum number of multi-modal tokens
for profiling the memory usage of a model.

If this registry is not applicable to the model,
instead return the ``default`` value.

The model is identified by ``model_config``.

See also:
:ref:`adding_a_new_multimodal_model`
"""
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture

model_cls, _ = get_model_architecture(model_config)

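# This registry applies to the model only if it registered an input mapper
# for this modality; otherwise fall back to the caller-provided default.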
if model_cls not in self._input_mappers:
return default

max_mm_tokens = self._max_mm_tokens.get(model_cls)
if max_mm_tokens is None:
raise KeyError(f"No maximum number of multi-modal tokens is given "
f"for model class {model_cls.__name__} in {self}.")

if isinstance(max_mm_tokens, int):
return max_mm_tokens

return max_mm_tokens(InputContext(model_config))