[VLM] Remove image_input_type from VLM config #5852

Merged: 23 commits, Jul 2, 2024
4 changes: 0 additions & 4 deletions .buildkite/download-images.sh
@@ -8,10 +8,6 @@ set -o pipefail
# aws s3 sync s3://air-example-data-2/vllm_opensource_llava/ images/
mkdir -p images
cd images
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign_pixel_values.pt
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign_image_features.pt
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom_pixel_values.pt
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom_image_features.pt
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign.jpg
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom.jpg

16 changes: 4 additions & 12 deletions docs/requirements-docs.txt
@@ -1,13 +1,5 @@
sphinx == 6.2.1
sphinx-book-theme == 1.0.1
sphinx-copybutton == 0.5.2
myst-parser == 2.0.0
sphinx==6.2.1
sphinx-book-theme==1.0.1
sphinx-copybutton==0.5.2
myst-parser==2.0.0
sphinx-argparse

# packages to install to build the documentation
pydantic
-f https://download.pytorch.org/whl/cpu
torch
py-cpuinfo
transformers
openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
8 changes: 5 additions & 3 deletions docs/source/dev/multimodal/multimodal_index.rst
@@ -9,8 +9,10 @@ vLLM provides experimental support for multi-modal models through the :mod:`vllm
which allows you to pass in multi-modal input alongside text and token prompts.

By default, vLLM models do not support multi-modal inputs. To enable multi-modal support for a model,
you must decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_dummy_data <MultiModalRegistry.register_dummy_data>`,
as well as :meth:`MULTIMODAL_REGISTRY.register_input <MultiModalRegistry.register_input>` for each modality type to support.
you must decorate the model class with :meth:`InputRegistry.register_dummy_data <vllm.inputs.registry.InputRegistry.register_dummy_data>`,
as well as :meth:`MULTIMODAL_REGISTRY.register_input_mapper <MultiModalRegistry.register_input_mapper>` for each modality type to support.

# TODO: Add more instructions on how to do that once embeddings is in.

Module Contents
+++++++++++++++
@@ -29,7 +31,7 @@ Registry
Base Classes
------------

.. autoclass:: vllm.multimodal.MultiModalData
.. autoclass:: vllm.multimodal.MultiModalDataDict
:members:
:show-inheritance:

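For orientation, a registration sketch based on the decorator names in the paragraph above might look roughly as follows. This is not the verbatim vLLM API: the decorator arguments and callback signatures are assumptions for illustration only, and the doc's own TODO notes that fuller instructions are still pending.

# Hypothetical sketch only -- decorator names come from the docs above, but the
# arguments and callback signatures are assumed, not the exact vLLM interface.
from torch import nn

from vllm.inputs import INPUT_REGISTRY
from vllm.multimodal import MULTIMODAL_REGISTRY


def _dummy_data_factory(ctx, seq_len):
    """Assumed signature: return placeholder data used for memory profiling."""
    raise NotImplementedError


def _image_input_mapper(ctx, data):
    """Assumed signature: map the user-supplied "image" value to model tensors."""
    raise NotImplementedError


@MULTIMODAL_REGISTRY.register_input_mapper("image", _image_input_mapper)  # assumed args
@INPUT_REGISTRY.register_dummy_data(_dummy_data_factory)
class MyVisionLanguageModel(nn.Module):
    ...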
11 changes: 7 additions & 4 deletions docs/source/models/vlm.rst
@@ -36,7 +36,6 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM``

llm = LLM(
model="llava-hf/llava-1.5-7b-hf",
image_input_type="pixel_values",
image_token_id=32000,
image_input_shape="1,3,336,336",
image_feature_size=576,
@@ -49,7 +48,12 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM``
To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:

* ``prompt``: The prompt should have a number of ``<image>`` tokens equal to ``image_feature_size``.
* ``multi_modal_data``: This should be an instance of :class:`~vllm.multimodal.image.ImagePixelData` or :class:`~vllm.multimodal.image.ImageFeatureData`.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.

.. note::

``multi_modal_data`` can accept keys and values beyond the builtin ones, as long as a customized plugin is registered through
:class:`vllm.multimodal.MULTIMODAL_REGISTRY`.

.. code-block:: python

@@ -61,7 +65,7 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:

outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": ImagePixelData(image),
"multi_modal_data": {"image": image},
})

for o in outputs:
@@ -93,7 +97,6 @@ Below is an example on how to launch the same ``llava-hf/llava-1.5-7b-hf`` with

python -m vllm.entrypoints.openai.api_server \
--model llava-hf/llava-1.5-7b-hf \
--image-input-type pixel_values \
--image-token-id 32000 \
--image-input-shape 1,3,336,336 \
--image-feature-size 576 \
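Putting the fragments above together: a self-contained offline-inference snippet under the new dict-based ``multi_modal_data`` schema (it mirrors examples/llava_example.py below, so the only assumption is that images/stop_sign.jpg has been fetched by .buildkite/download-images.sh) would look roughly like this:

from PIL import Image

from vllm import LLM

llm = LLM(
    model="llava-hf/llava-1.5-7b-hf",
    image_token_id=32000,
    image_input_shape="1,3,336,336",
    image_feature_size=576,
)

# The prompt carries image_feature_size (576) copies of the <image> token.
prompt = "<image>" * 576 + "\nUSER: What is the content of this image?\nASSISTANT:"
image = Image.open("images/stop_sign.jpg")

outputs = llm.generate({
    "prompt": prompt,
    "multi_modal_data": {"image": image},  # plain dict keyed by modality
})

for o in outputs:
    print(o.outputs[0].text)

The per-modality dict replaces the ImagePixelData / ImageFeatureData wrappers removed elsewhere in this PR.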
56 changes: 8 additions & 48 deletions examples/llava_example.py
@@ -1,84 +1,44 @@
import argparse
import os
import subprocess

import torch
from PIL import Image

from vllm import LLM
from vllm.multimodal.image import ImageFeatureData, ImagePixelData

# The assets are located at `s3://air-example-data-2/vllm_opensource_llava/`.
# You can use `.buildkite/download-images.sh` to download them


def run_llava_pixel_values(*, disable_image_processor: bool = False):
def run_llava():
llm = LLM(
model="llava-hf/llava-1.5-7b-hf",
image_input_type="pixel_values",
image_token_id=32000,
image_input_shape="1,3,336,336",
image_feature_size=576,
disable_image_processor=disable_image_processor,
)

prompt = "<image>" * 576 + (
"\nUSER: What is the content of this image?\nASSISTANT:")

if disable_image_processor:
image = torch.load("images/stop_sign_pixel_values.pt")
else:
image = Image.open("images/stop_sign.jpg")
image = Image.open("images/stop_sign.jpg")

outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": ImagePixelData(image),
"multi_modal_data": {
"image": image
},
})

for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)


def run_llava_image_features():
llm = LLM(
model="llava-hf/llava-1.5-7b-hf",
image_input_type="image_features",
image_token_id=32000,
image_input_shape="1,576,1024",
image_feature_size=576,
)

prompt = "<image>" * 576 + (
"\nUSER: What is the content of this image?\nASSISTANT:")

image: torch.Tensor = torch.load("images/stop_sign_image_features.pt")

outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": ImageFeatureData(image),
})

for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)


def main(args):
if args.type == "pixel_values":
run_llava_pixel_values()
else:
run_llava_image_features()
def main():
run_llava()


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Demo on Llava")
parser.add_argument("--type",
type=str,
choices=["pixel_values", "image_features"],
default="pixel_values",
help="image input type")
args = parser.parse_args()
# Download from s3
s3_bucket_path = "s3://air-example-data-2/vllm_opensource_llava/"
local_directory = "images"
@@ -95,4 +55,4 @@ def main(args):
local_directory,
"--no-sign-request",
])
main(args)
main()
1 change: 0 additions & 1 deletion examples/openai_vision_api_client.py
@@ -3,7 +3,6 @@
Launch the vLLM server with the following command:
python -m vllm.entrypoints.openai.api_server \
--model llava-hf/llava-1.5-7b-hf \
--image-input-type pixel_values \
--image-token-id 32000 \
--image-input-shape 1,3,336,336 \
--image-feature-size 576 \
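The client half of this example is collapsed in the diff; a minimal sketch of the corresponding request, assuming the server launched as in the docstring above is listening on the default port 8000, could look like the following. The image URL is the stop-sign asset used elsewhere in this PR, and the API key is a placeholder.

from openai import OpenAI

# Assumes a vLLM OpenAI-compatible server running locally on the default port 8000.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

image_url = ("https://air-example-data-2.s3.us-west-2.amazonaws.com"
             "/vllm_opensource_llava/stop_sign.jpg")

chat_completion = client.chat.completions.create(
    model="llava-hf/llava-1.5-7b-hf",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url", "image_url": {"url": image_url}},
        ],
    }],
    max_tokens=64,
)
print(chat_completion.choices[0].message.content)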
6 changes: 3 additions & 3 deletions examples/phi3v_example.py
@@ -4,7 +4,6 @@
from PIL import Image

from vllm import LLM, SamplingParams
from vllm.multimodal.image import ImagePixelData


def run_phi3v():
@@ -17,7 +16,6 @@ def run_phi3v():
llm = LLM(
model=model_path,
trust_remote_code=True,
image_input_type="pixel_values",
image_token_id=32044,
image_input_shape="1,3,1008,1344",
image_feature_size=1921,
@@ -35,7 +33,9 @@ def run_phi3v():
outputs = llm.generate(
{
"prompt": prompt,
"multi_modal_data": ImagePixelData(image),
"multi_modal_data": {
"image": image
},
},
sampling_params=sampling_params)
for o in outputs:
38 changes: 8 additions & 30 deletions tests/conftest.py
@@ -17,19 +17,17 @@
AutoTokenizer, BatchEncoding)

from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
from vllm.config import TokenizerPoolConfig
from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
from vllm.inputs import TextPrompt
from vllm.logger import init_logger
from vllm.sequence import SampleLogprobs
from vllm.utils import cuda_device_count_stateless, is_cpu

if TYPE_CHECKING:
from vllm.multimodal import MultiModalData
else:
# it will call torch.cuda.device_count()
MultiModalData = None
from vllm.sequence import SampleLogprobs
from vllm.utils import cuda_device_count_stateless, is_cpu
from vllm.multimodal import MultiModalDataDict

logger = init_logger(__name__)

@@ -51,35 +49,15 @@ def _read_prompts(filename: str) -> List[str]:
class ImageAsset:
name: Literal["stop_sign", "cherry_blossom"]

@cached_property
def pixel_values(self) -> torch.Tensor:
return torch.load(_IMAGE_DIR / f"{self.name}_pixel_values.pt")

@cached_property
def image_features(self) -> torch.Tensor:
return torch.load(_IMAGE_DIR / f"{self.name}_image_features.pt")

@cached_property
def pil_image(self) -> Image.Image:
return Image.open(_IMAGE_DIR / f"{self.name}.jpg")

def for_hf(self) -> Image.Image:
return self.pil_image

def for_vllm(self, vision_config: VisionLanguageConfig) -> MultiModalData:
# don't put this import at the top level
# it will call torch.cuda.device_count()
from vllm.multimodal.image import ImageFeatureData # noqa: F401
from vllm.multimodal.image import ImagePixelData
image_input_type = vision_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType

if image_input_type == ImageInputType.IMAGE_FEATURES:
return ImageFeatureData(self.image_features)
if image_input_type == ImageInputType.PIXEL_VALUES:
return ImagePixelData(self.pil_image)

raise NotImplementedError
def for_vllm(self) -> Dict[str, Any]:
return {"image": self.pil_image}


class _ImageAssetPrompts(TypedDict):
@@ -453,7 +431,7 @@ def generate(
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional[List[MultiModalData]] = None,
images: Optional[List["MultiModalDataDict"]] = None,
) -> List[Tuple[List[List[int]], List[str]]]:
if images is not None:
assert len(prompts) == len(images)
@@ -502,7 +480,7 @@ def generate_greedy(
self,
prompts: List[str],
max_tokens: int,
images: Optional[List[MultiModalData]] = None,
images: Optional[List["MultiModalDataDict"]] = None,
) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts, greedy_params, images=images)
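For orientation, the refactored ImageAsset helpers are consumed as sketched below in the model tests further down; the import path is written out here purely for illustration.

# Illustrative only: shows the shape of the values the refactored helpers return.
from tests.conftest import ImageAsset  # import path assumed for illustration

asset = ImageAsset("stop_sign")
hf_input = asset.for_hf()      # PIL.Image.Image, fed to the HuggingFace runner
vllm_input = asset.for_vllm()  # {"image": <PIL.Image.Image>}, fed to the vLLM runner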
2 changes: 0 additions & 2 deletions tests/entrypoints/openai/test_vision.py
@@ -39,8 +39,6 @@ def server():
"--max-model-len",
"4096",
"--enforce-eager",
"--image-input-type",
"pixel_values",
"--image-token-id",
"32000",
"--image-input-shape",
22 changes: 8 additions & 14 deletions tests/models/test_llava.py
@@ -25,17 +25,11 @@ def iter_llava_configs(model_name: str):
}

for (h, w), f in image_hw_to_feature_size.items():
for input_type, input_shape in [
(VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)),
(VisionLanguageConfig.ImageInputType.IMAGE_FEATURES, (1, f, 1024)),
]:
yield (model_name,
VisionLanguageConfig(image_input_type=input_type,
image_feature_size=f,
image_token_id=32000,
image_input_shape=input_shape,
image_processor=model_name,
image_processor_revision=None))
input_shape = (1, 3, h, w)
yield (model_name,
VisionLanguageConfig(image_feature_size=f,
image_token_id=32000,
image_input_shape=input_shape))


model_and_vl_config = [
@@ -81,8 +75,8 @@ def run_test(

All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalData objects and corresponding
vision language config as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
@@ -104,7 +98,7 @@ def run_test(
# NOTE: `asset.for_vllm` will call `torch.cuda.device_count()`
# we must put it inside the vllm_runner context manager
# i.e. after creating vLLM instance.
vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]
vllm_images = [asset.for_vllm() for asset in image_assets]

vllm_image_prompts = [
p.replace("<image>", "<image>" * vlm_config.image_feature_size)
23 changes: 10 additions & 13 deletions tests/models/test_llava_next.py
@@ -33,16 +33,13 @@ def iter_llava_next_configs(model_name: str):
}

for (h, w), f in image_hw_to_feature_size.items():
for input_type, input_shape in [
(VisionLanguageConfig.ImageInputType.PIXEL_VALUES, (1, 3, h, w)),
]:
yield (model_name,
VisionLanguageConfig(image_input_type=input_type,
image_feature_size=f,
image_token_id=32000,
image_input_shape=input_shape,
image_processor=model_name,
image_processor_revision=None))
input_shape = (1, 3, h, w)
yield (model_name,
VisionLanguageConfig(
image_feature_size=f,
image_token_id=32000,
image_input_shape=input_shape,
))


model_and_vl_config = [
@@ -85,14 +82,14 @@ def test_models(hf_runner, vllm_runner, image_assets, model_and_config,

All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalData objects and corresponding
vision language config as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
model_id, vlm_config = model_and_config
hf_images = [asset.for_hf() for asset in image_assets]
vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]
vllm_images = [asset.for_vllm() for asset in image_assets]

with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,