Skip to content

Commit

Permalink
[Model] Adding support for MiniCPM-V (vllm-project#4087)
Browse files Browse the repository at this point in the history
Signed-off-by: Alvant <alvasian@yandex.ru>
  • Loading branch information
HwwwwwwwH authored and Alvant committed Oct 26, 2024
1 parent a5c5247 commit 3b9cecd
Show file tree
Hide file tree
Showing 11 changed files with 942 additions and 18 deletions.
2 changes: 2 additions & 0 deletions docs/source/dev/multimodal/multimodal_index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ Registry
Base Classes
------------

.. autodata:: vllm.multimodal.NestedTensors

.. autodata:: vllm.multimodal.BatchedTensors

.. autoclass:: vllm.multimodal.MultiModalDataBuiltins
Expand Down
4 changes: 4 additions & 0 deletions docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,10 @@ Vision Language Models
- Phi-3-Vision
- :code:`microsoft/Phi-3-vision-128k-instruct`, etc.
-
* - :code:`MiniCPM-V`
- MiniCPM-V
- :code:`openbmb/MiniCPM-V-2`, :code:`openbmb/MiniCPM-Llama3-V-2_5`, etc.
-

If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` and :ref:`Enabling Multimodal Inputs <enabling_multimodal_inputs>`
Expand Down
53 changes: 53 additions & 0 deletions examples/minicpmv_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset

# 2.0
# MODEL_NAME = "HwwwH/MiniCPM-V-2"
# 2.5
MODEL_NAME = "openbmb/MiniCPM-Llama3-V-2_5"

image = ImageAsset("stop_sign").pil_image.convert("RGB")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
llm = LLM(model=MODEL_NAME,
gpu_memory_utilization=1,
trust_remote_code=True,
max_model_len=4096)

messages = [{
'role':
'user',
'content':
'(<image>./</image>)\n' + "What's the content of the image?"
}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
# 2.0
# stop_token_ids = [tokenizer.eos_id]
# 2.5
stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]

sampling_params = SamplingParams(
stop_token_ids=stop_token_ids,
# temperature=0.7,
# top_p=0.8,
# top_k=100,
# seed=3472,
max_tokens=1024,
# min_tokens=150,
temperature=0,
use_beam_search=True,
# length_penalty=1.2,
best_of=3)

outputs = llm.generate({
"prompt": prompt,
"multi_modal_data": {
"image": image
}
},
sampling_params=sampling_params)
print(outputs[0].outputs[0].text)
11 changes: 6 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch.nn.functional as F
from PIL import Image
from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
AutoTokenizer, BatchEncoding)
AutoTokenizer, BatchEncoding, BatchFeature)

from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
Expand Down Expand Up @@ -133,7 +133,7 @@ def image_assets() -> _ImageAssets:
return IMAGE_ASSETS


_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding)
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature)


class HfRunner:
Expand Down Expand Up @@ -339,7 +339,6 @@ def generate_greedy_logprobs_limit(
processor_kwargs["images"] = images[i]

inputs = self.processor(**processor_kwargs)
input_ids = inputs.input_ids

output = self.model.generate(
**self.wrap_device(inputs),
Expand Down Expand Up @@ -381,7 +380,7 @@ def generate_greedy_logprobs_limit(

all_logprobs.append(seq_logprobs_lst)
seq_ids = output.sequences[0]
output_len = seq_ids.shape[0] - input_ids.shape[1]
output_len = len(seq_logprobs_lst)
output_ids = seq_ids[-output_len:]
all_output_ids.append(output_ids.tolist())
all_output_strs.append(self.tokenizer.decode(output_ids))
Expand Down Expand Up @@ -514,10 +513,12 @@ def generate_greedy_logprobs(
max_tokens: int,
num_logprobs: int,
images: Optional[List[Image.Image]] = None,
stop_token_ids: Optional[List[int]] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
greedy_logprobs_params = SamplingParams(temperature=0.0,
max_tokens=max_tokens,
logprobs=num_logprobs)
logprobs=num_logprobs,
stop_token_ids=stop_token_ids)
outputs = self.generate_w_logprobs(prompts,
greedy_logprobs_params,
images=images)
Expand Down
163 changes: 163 additions & 0 deletions tests/models/test_minicpmv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
from collections import UserDict
from typing import List, Optional, Tuple, Type

import pytest
import torch
import torch.types
from transformers import BatchFeature

from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs

from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close

pytestmark = pytest.mark.vlm

# The image token is placed before "user" on purpose so that the test can pass
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" \
"(<image>./</image>)\nWhat's the content of the image?<|eot_id|>" \
"<|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501
"cherry_blossom":
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" \
"(<image>./</image>)\nWhat is the season?<|eot_id|>" \
"<|start_header_id|>assistant<|end_header_id|>\n\n"
})

models = ["openbmb/MiniCPM-Llama3-V-2_5"]


def trunc_hf_output(hf_output: Tuple[List[int], str,
Optional[SampleLogprobs]]):
output_ids, output_str, out_logprobs = hf_output
if output_str.endswith("<|eot_id|>"):
output_str = output_str.split("<|eot_id|>")[0]
return output_ids, output_str, out_logprobs


target_dtype = "half"


def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
images = [asset.pil_image for asset in image_assets]

inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).

# max_model_len should be greater than image_feature_size
with vllm_runner(model,
max_model_len=4096,
max_num_seqs=1,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
tokenizer = vllm_model.model.get_tokenizer()
stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=vllm_images,
stop_token_ids=stop_token_ids)
for prompts, vllm_images in inputs_per_image
]

with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad():

class NestedInputs(UserDict):

def __init__(self, model_inputs: BatchFeature):
super().__init__({"model_inputs": model_inputs})

self.model_inputs = model_inputs

def to(self, device: torch.types.Device):
return NestedInputs(self.model_inputs.to(device))

hf_processor = hf_model.processor
hf_model.processor = lambda **kw: NestedInputs(
hf_processor(**kw) # type: ignore
)

hf_outputs_per_image = [
hf_model.generate_greedy_logprobs_limit(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=hf_images,
tokenizer=tokenizer)
for prompts, hf_images in inputs_per_image
]

for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
vllm_outputs_per_image):
check_logprobs_close(
outputs_0_lst=[
trunc_hf_output(hf_output) for hf_output in hf_outputs
],
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
dtype: str, max_tokens: int, num_logprobs: int) -> None:
run_test(
hf_runner,
vllm_runner,
image_assets,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
1 change: 1 addition & 0 deletions vllm/model_executor/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"MiniCPMV": ("minicpmv", "MiniCPMV"),
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
Expand Down
4 changes: 3 additions & 1 deletion vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,9 +418,11 @@ def forward(
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
input_embeds: Optional[torch.Tensor] = None
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
attn_metadata, intermediate_tensors,
input_embeds)
return model_output

def compute_logits(self, hidden_states: torch.Tensor,
Expand Down
3 changes: 2 additions & 1 deletion vllm/model_executor/models/minicpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,10 +463,11 @@ def forward(
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
input_embeds: Optional[torch.Tensor] = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, input_embeds)
return hidden_states

def compute_logits(self, hidden_states: torch.Tensor,
Expand Down
Loading

0 comments on commit 3b9cecd

Please sign in to comment.