
[Model] Adding support for MiniCPM-V #4087

Merged: 70 commits, merged on Jul 25, 2024
Changes from 13 commits
Commits (70)
caac38f
minicpm-v
HwwwwwwwH Apr 15, 2024
a619354
fix format
HwwwwwwwH Apr 15, 2024
6f7e0ef
add minicpmv example
HwwwwwwwH Apr 15, 2024
75c2d31
fix format
HwwwwwwwH Apr 15, 2024
4b4c7f3
add timm import hints
HwwwwwwwH Apr 23, 2024
8e70b68
Merge branch 'main' of github.com:HwwwwwwwH/vllm
HwwwwwwwH Apr 23, 2024
6c953a9
Merge branch 'main' into minicpmv
HwwwwwwwH Apr 23, 2024
189d28e
adapt to new vllm version
HwwwwwwwH Apr 23, 2024
af353f2
add timm dependency to requirements-common.txt, change examples/minicpmv
HwwwwwwwH Apr 26, 2024
1e88026
merge latest main
HwwwwwwwH May 5, 2024
4204a02
merge latest main
HwwwwwwwH May 5, 2024
8b63870
merge latest main
HwwwwwwwH May 5, 2024
cc64a0b
Merge branch 'main' into minicpmv
HwwwwwwwH May 6, 2024
0280936
Merge branch 'main' into minicpmv
HwwwwwwwH May 24, 2024
b01948c
minicpmv_2 init
HwwwwwwwH May 24, 2024
0b1be33
Make changes based on the review
May 27, 2024
93bbc4c
Make changes based on the review
May 27, 2024
a29df42
Merge branch 'main' into minicpmv_2
May 27, 2024
7724d0e
fix
HwwwwwwwH May 27, 2024
fe58513
fix:get model dtype from default_dtype
May 27, 2024
c9aacd8
delete redundant annotations
May 27, 2024
294a989
Merge branch 'main' into minicpmv
Jun 14, 2024
e90c326
add test for mnicpmv
Jun 17, 2024
965829c
Merge branch 'main' into minicpmv
Jun 17, 2024
51cf257
add minicpmv support in <get_full_image_text_prompt>
Jun 19, 2024
fc2dcaf
add test for minicpmv / fix bug in slice image / add minicpmv in get_…
Jun 19, 2024
81d4437
format for minicpmv
HwwwwwwwH Jun 19, 2024
938a741
format minicpmv
HwwwwwwwH Jun 19, 2024
ff8499d
format minicpmv
HwwwwwwwH Jun 19, 2024
123bdf0
changed for image processor
Jul 8, 2024
d9187c8
update processing minicpmv
Jul 10, 2024
da4c965
update processing minicpmv
Jul 10, 2024
833475b
merge new main 7.10.17:20
Jul 10, 2024
3cd25fa
merge new main 7.10.17:20
Jul 10, 2024
3cc2eb7
complete test minicpmv
Jul 11, 2024
976730f
update example of minicpmv
Jul 11, 2024
6152d8e
merge main
Jul 11, 2024
edb98b5
format
Jul 11, 2024
69629a8
add minicpmv2.5
Jul 19, 2024
223cc74
Merge branch 'main' into minicpmv
Jul 19, 2024
d9e01f9
update example
Jul 19, 2024
2d16bbc
update example of minicpmv
Jul 19, 2024
4e427a8
update examles
Jul 19, 2024
bcaf90c
format
Jul 19, 2024
c13e506
add test for minicpmv
Jul 19, 2024
a8274f0
add test for minicpmv
Jul 19, 2024
8de6dc8
modify for merge
Jul 19, 2024
71698d6
delete redundant hints; decrease logprobs in test minicpmv
Jul 19, 2024
cb6a581
fix modelname of test_minicpmv.py
Jul 22, 2024
a210080
adjust timm import
Jul 22, 2024
97e5747
format
Jul 22, 2024
4352143
Update minicpmv_example.py
HwwwwwwwH Jul 22, 2024
a3de850
Merge branch 'main' into minicpmv
Jul 23, 2024
355ab8c
fix cuda memory exceeded while test minicpmv
Jul 23, 2024
c6d8ea5
get version
Jul 23, 2024
e469f86
add version
Jul 23, 2024
8ba9a4b
Merge branch 'main' into minicpmv
Jul 23, 2024
eb46f77
fix test warnings
Jul 24, 2024
e8bfeac
Merge branch 'main' into minicpmv
Jul 24, 2024
dc174d8
add NestedTensors & final
Jul 24, 2024
f85d168
Fix type annotations
DarkLight1337 Jul 24, 2024
d79284a
Apply patches inside model file to avoid changing existing code
DarkLight1337 Jul 24, 2024
107379a
Whitespace
DarkLight1337 Jul 24, 2024
7d1f0ff
merge
Jul 24, 2024
e3790d1
Merge branch 'minicpmv' of github.com:HwwwwwwwH/vllm into minicpmv
Jul 24, 2024
512a5a5
Fix conversion of nested inputs
DarkLight1337 Jul 24, 2024
279abf8
lint
DarkLight1337 Jul 24, 2024
feae37a
Merge branch 'vllm-project:main' into minicpmv
HwwwwwwwH Jul 24, 2024
ca9c8e3
fix openai server
Jul 24, 2024
3523152
Merge branch 'minicpmv' of github.com:HwwwwwwwH/vllm into minicpmv
Jul 24, 2024
143 changes: 143 additions & 0 deletions examples/minicpmv_example.py
@@ -0,0 +1,143 @@
import math

import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoConfig, AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.sequence import MultiModalData


def slice_image(image,
                max_slice_nums=9,
                scale_resolution=448,
                patch_size=14,
                never_split=False):
    original_size = image.size
    original_width, original_height = original_size
    log_ratio = math.log(original_width / original_height)
    ratio = original_width * original_height / (scale_resolution *
                                                scale_resolution)
    multiple = min(math.ceil(ratio), max_slice_nums)

    best_grid = None

    if multiple > 1 and not never_split:
        candidate_split_grids_nums = []
        for i in [multiple - 1, multiple, multiple + 1]:
            if i == 1 or i > max_slice_nums:
                continue
            candidate_split_grids_nums.append(i)

        # source image, down-sampling and ensure divided by patch_size
        candidate_grids = []

        # find best grid
        for split_grids_nums in candidate_split_grids_nums:
            m = 1
            while m <= split_grids_nums:
                if split_grids_nums % m == 0:
                    candidate_grids.append([m, split_grids_nums // m])
                m += 1

        best_grid = [1, 1]
        min_error = float("inf")
        for grid in candidate_grids:
            error = abs(log_ratio - math.log(grid[0] / grid[1]))
            if error < min_error:
                best_grid = grid
                min_error = error

    return best_grid


def get_grid_placeholder(grid, query_num):
    image_placeholder = query_num + 2

    cols = grid[0]
    rows = grid[1]
    slices = 0
    for i in range(rows):
        lines = 0
        for j in range(cols):
            lines += image_placeholder
        if i < rows - 1:
            slices += lines + 1
        else:
            slices += lines
    slice_placeholder = 2 + slices
    return slice_placeholder


class MiniCPMV_VLLM:

    def __init__(self) -> None:
        self.config = AutoConfig.from_pretrained('openbmb/MiniCPM-V-2',
                                                 trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V-2',
                                                       trust_remote_code=True)
        self.llm = LLM(
            model="openbmb/MiniCPM-V-2",
            image_input_type="pixel_values",
            image_token_id=101,
            image_input_shape="1,3,448,448",
            image_feature_size=64,
            gpu_memory_utilization=0.75,
            trust_remote_code=True,
        )

    def get_slice_image_placeholder(self, image):
        image_placeholder = self.config.query_num + 2

        best_grid = slice_image(
            image,
            self.config.max_slice_nums,
            self.config.scale_resolution,
            self.config.patch_size,
        )
        final_placeholder = image_placeholder

        if best_grid is not None:
            final_placeholder += get_grid_placeholder(best_grid,
                                                      self.config.query_num)

        return final_placeholder - 1

    def generate(self, image, question, sampling_params):
        addtion_tokens = self.get_slice_image_placeholder(image)
        image = transforms.Compose([transforms.ToTensor()])(img=image)
        images = torch.stack([image])

        prompt = "<用户><image></image>" + \
            question + \
            "<AI>" + '<unk>' * addtion_tokens

        outputs = self.llm.generate(prompt,
                                    multi_modal_data=MultiModalData(
                                        type=MultiModalData.Type.IMAGE,
                                        data=images),
                                    sampling_params=sampling_params)
        return outputs[0].outputs[0].text


if __name__ == '__main__':
    model = MiniCPMV_VLLM()

    sampling_params = SamplingParams(
        temperature=0.7,
        top_p=0.8,
        top_k=100,
        seed=3472,
        max_tokens=1024,
        min_tokens=150,
        # temperature=0,
        # use_beam_search=True,
        # length_penalty=1.2,
        # best_of=3
    )

    image = Image.open('./example.png').convert('RGB')
    question = "Provide an intricate description of the image."
    response = model.generate(image, question, sampling_params)
    print(response)
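
For reference, a quick worked check of the placeholder arithmetic in the example above, reusing slice_image and get_grid_placeholder from this file. The 1344x1008 image size is arbitrary and query_num = 64 is an assumption (matching image_feature_size above); the real values are read from the MiniCPM-V-2 config at runtime.

# Sanity check of the slice/placeholder math, assuming query_num = 64 and the
# default max_slice_nums=9 / scale_resolution=448 used by slice_image() above.
from PIL import Image

query_num = 64                        # assumed; read from config.query_num at runtime
img = Image.new("RGB", (1344, 1008))  # hypothetical 4:3 input image

best_grid = slice_image(img)          # -> [3, 2] for this aspect ratio
placeholder = query_num + 2           # 66 tokens for the resized source image
placeholder += get_grid_placeholder(best_grid, query_num)  # +399 for the 3x2 slice grid
print(best_grid, placeholder - 1)     # [3, 2] 464 -> number of '<unk>' tokens appended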
1 change: 1 addition & 0 deletions requirements-common.txt
@@ -6,6 +6,7 @@ numpy
 requests
 py-cpuinfo
 transformers >= 4.40.0 # Required for StarCoder2 & Llava, Llama 3.
+timm==0.9.10
 tokenizers >= 0.19.1 # Required for Llama 3.
 fastapi
 openai
1 change: 1 addition & 0 deletions vllm/model_executor/models/__init__.py
@@ -42,6 +42,7 @@
     "MptForCausalLM": ("mpt", "MPTForCausalLM"),
     "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
     "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
+    "MiniCPMV": ("minicpmv", "MiniCPMV"),
     "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),
     "OrionForCausalLM": ("orion", "OrionForCausalLM"),
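
For context, each entry in this mapping pairs a HuggingFace architecture name with a (module name, class name) tuple under vllm/model_executor/models. A minimal sketch of how such a registry can be resolved, illustrative only and not vLLM's actual ModelRegistry code:

# Illustrative resolver for an architecture -> (module, class) registry like
# the mapping above; vLLM's real loader adds caching and error handling.
import importlib

_MODELS = {
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
}

def load_model_cls(architecture: str):
    module_name, cls_name = _MODELS[architecture]
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_name}")
    return getattr(module, cls_name)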
3 changes: 2 additions & 1 deletion vllm/model_executor/models/minicpm.py
@@ -448,9 +448,10 @@ def forward(
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        input_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata)
+                                   attn_metadata, input_embeds)
         return hidden_states

     def compute_logits(self, hidden_states: torch.Tensor,
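
The new input_embeds parameter lets the MiniCPM-V wrapper hand the language model pre-built embeddings, i.e. token embeddings with the vision encoder's outputs already spliced in at the image-placeholder positions, instead of raw token ids. A rough sketch of that pattern with hypothetical helper names, not the code from this PR:

# Illustrative only: splice vision features into token embeddings at the
# positions of the image placeholder token, then pass the result to the LLM
# through the new input_embeds argument.
import torch

def merge_vision_embeddings(input_ids: torch.Tensor,
                            inputs_embeds: torch.Tensor,
                            vision_embeds: torch.Tensor,
                            image_token_id: int) -> torch.Tensor:
    mask = input_ids == image_token_id           # positions of image placeholders
    inputs_embeds[mask] = vision_embeds.to(inputs_embeds.dtype).view(
        -1, inputs_embeds.shape[-1])             # one vision embedding per placeholder
    return inputs_embeds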