[Model] Adding support for MiniCPM-V #4087

Merged 70 commits on Jul 25, 2024

Commits
caac38f
minicpm-v
HwwwwwwwH Apr 15, 2024
a619354
fix format
HwwwwwwwH Apr 15, 2024
6f7e0ef
add minicpmv example
HwwwwwwwH Apr 15, 2024
75c2d31
fix format
HwwwwwwwH Apr 15, 2024
4b4c7f3
add timm import hints
HwwwwwwwH Apr 23, 2024
8e70b68
Merge branch 'main' of github.com:HwwwwwwwH/vllm
HwwwwwwwH Apr 23, 2024
6c953a9
Merge branch 'main' into minicpmv
HwwwwwwwH Apr 23, 2024
189d28e
adapt to new vllm version
HwwwwwwwH Apr 23, 2024
af353f2
add timm dependency to requirements-common.txt, change examples/minicpmv
HwwwwwwwH Apr 26, 2024
1e88026
merge latest main
HwwwwwwwH May 5, 2024
4204a02
merge latest main
HwwwwwwwH May 5, 2024
8b63870
merge latest main
HwwwwwwwH May 5, 2024
cc64a0b
Merge branch 'main' into minicpmv
HwwwwwwwH May 6, 2024
0280936
Merge branch 'main' into minicpmv
HwwwwwwwH May 24, 2024
b01948c
minicpmv_2 init
HwwwwwwwH May 24, 2024
0b1be33
Make changes based on the review
May 27, 2024
93bbc4c
Make changes based on the review
May 27, 2024
a29df42
Merge branch 'main' into minicpmv_2
May 27, 2024
7724d0e
fix
HwwwwwwwH May 27, 2024
fe58513
fix:get model dtype from default_dtype
May 27, 2024
c9aacd8
delete redundant annotations
May 27, 2024
294a989
Merge branch 'main' into minicpmv
Jun 14, 2024
e90c326
add test for mnicpmv
Jun 17, 2024
965829c
Merge branch 'main' into minicpmv
Jun 17, 2024
51cf257
add minicpmv support in <get_full_image_text_prompt>
Jun 19, 2024
fc2dcaf
add test for minicpmv / fix bug in slice image / add minicpmv in get_…
Jun 19, 2024
81d4437
format for minicpmv
HwwwwwwwH Jun 19, 2024
938a741
format minicpmv
HwwwwwwwH Jun 19, 2024
ff8499d
format minicpmv
HwwwwwwwH Jun 19, 2024
123bdf0
changed for image processor
Jul 8, 2024
d9187c8
update processing minicpmv
Jul 10, 2024
da4c965
update processing minicpmv
Jul 10, 2024
833475b
merge new main 7.10.17:20
Jul 10, 2024
3cd25fa
merge new main 7.10.17:20
Jul 10, 2024
3cc2eb7
complete test minicpmv
Jul 11, 2024
976730f
update example of minicpmv
Jul 11, 2024
6152d8e
merge main
Jul 11, 2024
edb98b5
format
Jul 11, 2024
69629a8
add minicpmv2.5
Jul 19, 2024
223cc74
Merge branch 'main' into minicpmv
Jul 19, 2024
d9e01f9
update example
Jul 19, 2024
2d16bbc
update example of minicpmv
Jul 19, 2024
4e427a8
update examles
Jul 19, 2024
bcaf90c
format
Jul 19, 2024
c13e506
add test for minicpmv
Jul 19, 2024
a8274f0
add test for minicpmv
Jul 19, 2024
8de6dc8
modify for merge
Jul 19, 2024
71698d6
delete redundant hints; decrease logprobs in test minicpmv
Jul 19, 2024
cb6a581
fix modelname of test_minicpmv.py
Jul 22, 2024
a210080
adjust timm import
Jul 22, 2024
97e5747
format
Jul 22, 2024
4352143
Update minicpmv_example.py
HwwwwwwwH Jul 22, 2024
a3de850
Merge branch 'main' into minicpmv
Jul 23, 2024
355ab8c
fix cuda memory exceeded while test minicpmv
Jul 23, 2024
c6d8ea5
get version
Jul 23, 2024
e469f86
add version
Jul 23, 2024
8ba9a4b
Merge branch 'main' into minicpmv
Jul 23, 2024
eb46f77
fix test warnings
Jul 24, 2024
e8bfeac
Merge branch 'main' into minicpmv
Jul 24, 2024
dc174d8
add NestedTensors & final
Jul 24, 2024
f85d168
Fix type annotations
DarkLight1337 Jul 24, 2024
d79284a
Apply patches inside model file to avoid changing existing code
DarkLight1337 Jul 24, 2024
107379a
Whitespace
DarkLight1337 Jul 24, 2024
7d1f0ff
merge
Jul 24, 2024
e3790d1
Merge branch 'minicpmv' of github.com:HwwwwwwwH/vllm into minicpmv
Jul 24, 2024
512a5a5
Fix conversion of nested inputs
DarkLight1337 Jul 24, 2024
279abf8
lint
DarkLight1337 Jul 24, 2024
feae37a
Merge branch 'vllm-project:main' into minicpmv
HwwwwwwwH Jul 24, 2024
ca9c8e3
fix openai server
Jul 24, 2024
3523152
Merge branch 'minicpmv' of github.com:HwwwwwwwH/vllm into minicpmv
Jul 24, 2024
2 changes: 2 additions & 0 deletions docs/source/dev/multimodal/multimodal_index.rst
@@ -40,6 +40,8 @@ Registry
Base Classes
------------

.. autodata:: vllm.multimodal.NestedTensors

.. autodata:: vllm.multimodal.BatchedTensors

.. autoclass:: vllm.multimodal.MultiModalDataBuiltins
4 changes: 4 additions & 0 deletions docs/source/models/supported_models.rst
@@ -206,6 +206,10 @@ Vision Language Models
- Phi-3-Vision
- :code:`microsoft/Phi-3-vision-128k-instruct`, etc.
-
* - :code:`MiniCPMV`
- MiniCPM-V
- :code:`openbmb/MiniCPM-V-2`, :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc.
-

If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` and :ref:`Enabling Multimodal Inputs <enabling_multimodal_inputs>`
53 changes: 53 changes: 53 additions & 0 deletions examples/minicpmv_example.py
@@ -0,0 +1,53 @@
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

# 2.0
# MODEL_NAME = "HwwwH/MiniCPM-V-2"
# 2.5
MODEL_NAME = "openbmb/MiniCPM-Llama3-V-2_5"

image = ImageAsset("stop_sign").pil_image.convert("RGB")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
llm = LLM(model=MODEL_NAME,
gpu_memory_utilization=1,
trust_remote_code=True,
max_model_len=4096)

messages = [{
'role':
'user',
'content':
'(<image>./</image>)\n' + "What's the content of the image?"
}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
# 2.0
# stop_token_ids = [tokenizer.eos_id]
# 2.5
stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]

sampling_params = SamplingParams(
stop_token_ids=stop_token_ids,
# temperature=0.7,
# top_p=0.8,
# top_k=100,
# seed=3472,
max_tokens=1024,
# min_tokens=150,
temperature=0,
use_beam_search=True,
# length_penalty=1.2,
best_of=3)

outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": {
"image": image
}
},
sampling_params=sampling_params)
print(outputs[0].outputs[0].text)
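
The example above targets MiniCPM-Llama3-V-2_5; the commented-out lines show what changes for MiniCPM-V-2. To make the version-dependent stop tokens explicit, a small helper can be used (a minimal sketch based only on the comments in this example; the helper name is ours, not part of the PR):

# Illustrative helper, assuming the tokenizer exposes `eos_id` (and `eot_id`
# for the Llama-3-based 2.5 model) as in the example above.
def get_stop_token_ids(tokenizer, version: str):
    if version == "2.0":
        # MiniCPM-V-2 only needs the EOS token.
        return [tokenizer.eos_id]
    # MiniCPM-Llama3-V-2_5 additionally terminates turns with <|eot_id|>.
    return [tokenizer.eos_id, tokenizer.eot_id]

stop_token_ids = get_stop_token_ids(tokenizer, version="2.5")
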
11 changes: 6 additions & 5 deletions tests/conftest.py
@@ -11,7 +11,7 @@
import torch.nn.functional as F
from PIL import Image
from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
AutoTokenizer, BatchEncoding)
AutoTokenizer, BatchEncoding, BatchFeature)

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
@@ -133,7 +133,7 @@ def image_assets() -> _ImageAssets:
return IMAGE_ASSETS


_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature)


class HfRunner:
@@ -339,7 +339,6 @@ def generate_greedy_logprobs_limit(
processor_kwargs["images"] = images[i]

inputs = self.processor(**processor_kwargs)
input_ids = inputs.input_ids

output = self.model.generate(
**self.wrap_device(inputs),
@@ -381,7 +380,7 @@ def generate_greedy_logprobs_limit(

all_logprobs.append(seq_logprobs_lst)
seq_ids = output.sequences[0]
output_len = seq_ids.shape[0] - input_ids.shape[1]
output_len = len(seq_logprobs_lst)
output_ids = seq_ids[-output_len:]
all_output_ids.append(output_ids.tolist())
all_output_strs.append(self.tokenizer.decode(output_ids))
@@ -514,10 +513,12 @@ def generate_greedy_logprobs(
max_tokens: int,
num_logprobs: int,
images: Optional[List[Image.Image]] = None,
stop_token_ids: Optional[List[int]] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
greedy_logprobs_params = SamplingParams(temperature=0.0,
max_tokens=max_tokens,
logprobs=num_logprobs)
logprobs=num_logprobs,
stop_token_ids=stop_token_ids)
outputs = self.generate_w_logprobs(prompts,
greedy_logprobs_params,
images=images)
163 changes: 163 additions & 0 deletions tests/models/test_minicpmv.py
@@ -0,0 +1,163 @@
from collections import UserDict
from typing import List, Optional, Tuple, Type

import pytest
import torch
import torch.types
from transformers import BatchFeature

from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs

from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close

pytestmark = pytest.mark.vlm

# The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" \
"(<image>./</image>)\nWhat's the content of the image?<|eot_id|>" \
"<|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
"cherry_blossom":
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" \
"(<image>./</image>)\nWhat is the season?<|eot_id|>" \
"<|start_header_id|>assistant<|end_header_id|>\n\n"
})

models = ["openbmb/MiniCPM-Llama3-V-2_5"]


def trunc_hf_output(hf_output: Tuple[List[int], str,
Optional[SampleLogprobs]]):
output_ids, output_str, out_logprobs = hf_output
if output_str.endswith("<|eot_id|>"):
output_str = output_str.split("<|eot_id|>")[0]
return output_ids, output_str, out_logprobs


target_dtype = "half"


def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test are under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
images = [asset.pil_image for asset in image_assets]

inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).

# max_model_len should be greater than image_feature_size
with vllm_runner(model,
max_model_len=4096,
max_num_seqs=1,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
tokenizer = vllm_model.model.get_tokenizer()
stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=vllm_images,
stop_token_ids=stop_token_ids)
for prompts, vllm_images in inputs_per_image
]

with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad():

class NestedInputs(UserDict):

def __init__(self, model_inputs: BatchFeature):
super().__init__({"model_inputs": model_inputs})

self.model_inputs = model_inputs

def to(self, device: torch.types.Device):
return NestedInputs(self.model_inputs.to(device))

hf_processor = hf_model.processor
hf_model.processor = lambda **kw: NestedInputs(
hf_processor(**kw) # type: ignore
)

hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=hf_images,
tokenizer=tokenizer)
for prompts, hf_images in inputs_per_image
]

for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
vllm_outputs_per_image):
check_logprobs_close(
outputs_0_lst=[
trunc_hf_output(hf_output) for hf_output in hf_outputs
],
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
dtype: str, max_tokens: int, num_logprobs: int) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
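
The new test module is gated behind the `vlm` pytest marker (`pytestmark = pytest.mark.vlm` at the top of the file), so it can be selected independently of the rest of the suite. A usage sketch, assuming the test dependencies from the vLLM repository are installed:

# Run only the MiniCPM-V model test via its marker (illustrative invocation;
# the exact CI wiring may differ).
import pytest

pytest.main(["-m", "vlm", "tests/models/test_minicpmv.py", "-q"])
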
1 change: 1 addition & 0 deletions vllm/model_executor/models/__init__.py
@@ -50,6 +50,7 @@
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"MiniCPMV": ("minicpmv", "MiniCPMV"),
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
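
Registering "MiniCPMV" here is what lets the loader find the new model: vLLM matches the `architectures` field of the Hugging Face config against this table and imports the implementation lazily. A rough sketch of that lookup with assumed names (the real logic in vllm/model_executor/models/__init__.py is more involved):

# Illustrative only; everything except the registry entry itself is assumed.
import importlib

_MODELS = {
    "MiniCPMV": ("minicpmv", "MiniCPMV"),  # entry added by this PR
}

def load_model_class(architecture: str):
    module_name, class_name = _MODELS[architecture]
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_name}")
    return getattr(module, class_name)
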
4 changes: 3 additions & 1 deletion vllm/model_executor/models/llama.py
@@ -418,9 +418,11 @@ def forward(
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
input_embeds: Optional[torch.Tensor] = None
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
attn_metadata, intermediate_tensors,
input_embeds)
return model_output

def compute_logits(self, hidden_states: torch.Tensor,
3 changes: 2 additions & 1 deletion vllm/model_executor/models/minicpm.py
@@ -463,10 +463,11 @@ def forward(
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
input_embeds: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, input_embeds)
return hidden_states

def compute_logits(self, hidden_states: torch.Tensor,
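
Both LlamaForCausalLM and MiniCPMForCausalLM gain an optional `input_embeds` argument so the MiniCPM-V wrapper can pass precomputed embeddings, with the vision features already merged in, straight to the language model instead of token ids. A minimal sketch of that merge step, with assumed helper and parameter names (not the PR's actual implementation):

# Sketch only: scatter image features into the text embeddings at the image
# placeholder positions, then call the LM forward with `input_embeds`.
import torch

def merge_multimodal_embeddings(input_ids: torch.Tensor,
                                text_embeds: torch.Tensor,
                                image_embeds: torch.Tensor,
                                image_token_id: int) -> torch.Tensor:
    merged = text_embeds.clone()
    mask = input_ids == image_token_id            # image placeholder positions
    merged[mask] = image_embeds.to(merged.dtype)  # one feature vector per slot
    return merged
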