-
Notifications
You must be signed in to change notification settings - Fork 169
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
【PPMix No.10】 add minicpmv-2_6 (#825)
Co-authored-by: luyao-cv <1367355728>
- Loading branch information
Showing
13 changed files
with
3,056 additions
and
0 deletions.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# from paddlenlp.transformers import AutoTokenizer | ||
from paddlenlp.transformers import Qwen2Tokenizer | ||
from paddlemix.models.minicpm_v.tokenization_minicpmv_fast import MiniCPMVTokenizerFast | ||
from paddlemix.models.minicpm_v.modeling_minicpmv import MiniCPMV | ||
from PIL import Image | ||
MODEL_NAME = "openbmb/MiniCPM-V-2_6"

# Load the checkpoint in bfloat16 and switch it to evaluation mode.
model = MiniCPMV.from_pretrained(MODEL_NAME, dtype="bfloat16")
model = model.eval()
tokenizer = MiniCPMVTokenizerFast.from_pretrained(MODEL_NAME)

# The demo image travels inside the message content, which is why the
# explicit `image` argument of model.chat is left as None below.
image = Image.open('paddlemix/demo_images/c89b9daf907cb47481e6a1f77.jpg').convert('RGB')
question = "识别图中所有文字,无需添加标点。"
msgs = [{'role': 'user', 'content': [image, question]}]

# Arguments shared by both chat calls.
chat_kwargs = dict(image=None, msgs=msgs, tokenizer=tokenizer)

# First pass: a single answer capped at 2048 new tokens.
res = model.chat(max_new_tokens=2048, **chat_kwargs)
print(res)

# Second pass with sampling enabled.
# Streaming requires both sampling=True and stream=True, in which case
# model.chat returns a generator. This call passes stream=False.
# NOTE(review): with stream=False `res` is presumably a plain string, so the
# loop below would walk it character by character -- confirm against
# MiniCPMV.chat before relying on it.
res = model.chat(sampling=True, stream=False, **chat_kwargs)

generated_text = ""
for chunk in res:
    print(chunk, flush=True, end='')
    generated_text += chunk
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from .bert_padding import * | ||
from .configuration_minicpm import * | ||
from .modeling_minicpmv import * | ||
from .modeling_navit_siglip import * | ||
from .resampler import * | ||
from .tokenization_minicpmv_fast import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# reference from Dao-AILAB flash-attn | ||
# https://github.com/Dao-AILab/flash-attention/blob/74b0761ff7efc7b90d4e5aeb529c1b2a09a7458c/flash_attn/bert_padding.py#L38 | ||
import paddle | ||
import paddle.nn.functional as F | ||
from einops import rearrange, repeat | ||
from functools import reduce | ||
import operator | ||
|
||
|
||
class IndexFirstAxis(paddle.autograd.PyLayer):
    """Gather rows along axis 0 with a hand-written backward pass.

    Paddle port of ``IndexFirstAxis`` from flash-attn's ``bert_padding.py``.
    """

    @staticmethod
    def forward(ctx, input, indices):
        """Return the rows of ``input`` selected by ``indices``.

        Arguments:
            input: tensor with at least two dimensions.
            indices: 1-D index tensor into the first axis of ``input``.
        Return:
            Tensor of shape ``(len(indices), *input.shape[1:])``.
        """
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        # Remember the original leading size so backward can rebuild it.
        ctx.first_axis_dim = input.shape[0]
        trailing_shape = input.shape[1:]
        flat_width = reduce(operator.mul, trailing_shape, 1)

        # Collapse all trailing axes so the gather runs on a 2-D view, and
        # widen the row indices to address every column of that view.
        flat_input = rearrange(input, 'b ... -> b (...)')
        gather_index = repeat(indices, 'z -> z d', d=flat_width)
        gathered = paddle.take_along_axis(
            arr=flat_input, axis=0, indices=gather_index)
        return gathered.reshape([-1, *trailing_shape])

    @staticmethod
    def backward(ctx, grad_output):
        """Write ``grad_output`` rows back into a zero tensor of the input shape.

        Return:
            ``(grad_input, None)``; no gradient is produced for ``indices``.
        """
        (indices,) = ctx.saved_tensor()
        assert grad_output.ndim >= 2
        trailing_shape = grad_output.shape[1:]
        flat_grad = rearrange(grad_output, 'b ... -> b (...)')
        flat_width = tuple(flat_grad.shape)[1]

        # Rows that were never gathered keep a zero gradient.
        grad_input = paddle.zeros(
            shape=[ctx.first_axis_dim, flat_width], dtype=flat_grad.dtype)
        scatter_index = repeat(indices, 'z -> z d', d=flat_width)
        grad_input.put_along_axis_(
            axis=0, indices=scatter_index, values=flat_grad)
        return grad_input.reshape([ctx.first_axis_dim, *trailing_shape]), None


index_first_axis = IndexFirstAxis.apply
|
||
|
||
class IndexPutFirstAxis(paddle.autograd.PyLayer):
    """Place rows into a zero tensor along axis 0, with a hand-written backward.

    Paddle port of ``IndexPutFirstAxis`` from flash-attn's ``bert_padding.py``.
    """

    @staticmethod
    def forward(ctx, values, indices, first_axis_dim):
        """Return a zero tensor with ``values`` written at rows ``indices``.

        Arguments:
            values: tensor with at least two dimensions.
            indices: 1-D tensor of destination rows.
            first_axis_dim: size of the first axis of the result.
        Return:
            Tensor of shape ``(first_axis_dim, *values.shape[1:])``.
        """
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        padded_shape = [first_axis_dim, *tuple(values.shape)[1:]]
        output = paddle.zeros(shape=padded_shape, dtype=values.dtype)
        output[indices] = values
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Pick the gradient rows that correspond to the written positions.

        Return:
            ``(grad_values, None, None)``.
        """
        # NOTE(review): three values are returned for three forward arguments,
        # one of which (first_axis_dim) is presumably a plain int -- confirm
        # that the installed Paddle PyLayer accepts this arity.
        (indices,) = ctx.saved_tensor()
        return grad_output[indices], None, None


index_put_first_axis = IndexPutFirstAxis.apply
|
||
|
||
def unpad_input(hidden_states, attention_mask):
    """Drop masked-out positions and return the packed tokens with bookkeeping.

    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz is the number of tokens selected by attention_mask.
        indices: (total_nnz), positions of the kept tokens in the flattened (batch * seqlen) sequence.
        cu_seqlens: (batch + 1), cumulative sequence lengths with a leading 0, used to index into hidden_states.
        max_seqlen_in_batch: int
    """
    # Merge batch and sequence axes so one flat index addresses one token.
    flat_tokens = rearrange(hidden_states, "b s ... -> (b s) ...")
    indices = paddle.nonzero(attention_mask.flatten(), as_tuple=False).flatten()

    seqlens_in_batch = paddle.sum(attention_mask, axis=-1, dtype="int32")
    # Prepend a zero so consecutive entries delimit each sequence.
    cu_seqlens = F.pad(paddle.cumsum(seqlens_in_batch, axis=0), [1, 0])
    max_seqlen_in_batch = paddle.max(seqlens_in_batch).item()

    unpadded = index_first_axis(flat_tokens, indices)
    return unpadded, indices, cu_seqlens, max_seqlen_in_batch
|
||
|
||
def pad_input(hidden_states, indices, batch, seqlen):
    """Scatter packed tokens back into a zero-filled padded batch.

    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz is the number of tokens selected by attention_mask.
        indices: (total_nnz), positions of the non-masked tokens in the original flattened padded sequence.
        batch: int, batch size for the padded sequence.
        seqlen: int, maximum sequence length for the padded sequence.
    Return:
        hidden_states: (batch, seqlen, ...)
    """
    total_rows = batch * seqlen
    flat_padded = index_put_first_axis(hidden_states, indices, total_rows)
    # Split the flat token axis back into (batch, seqlen).
    return rearrange(flat_padded, "(b s) ... -> b s ...", b=batch)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import os | ||
""" MiniCPMV model configuration""" | ||
from typing import Union | ||
from .modeling_navit_siglip import SigLipVisionConfig | ||
from paddlenlp.transformers import PretrainedConfig | ||
from paddlenlp.transformers import Qwen2Config | ||
from paddlemix.utils.log import logger | ||
|
||
class MiniCPMVSliceConfig(PretrainedConfig):
    """Configuration holding the image-slicing settings of MiniCPM-V."""

    model_type = 'minicpmv'

    def __init__(self, patch_size=14, max_slice_nums=9, scale_resolution=448,
                 **kwargs):
        """Store the slicing settings; extra kwargs go to ``PretrainedConfig``."""
        super().__init__(**kwargs)
        self.scale_resolution = scale_resolution
        self.max_slice_nums = max_slice_nums
        self.patch_size = patch_size

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
                        **kwargs) -> 'PretrainedConfig':
        """Build a slice config from a saved configuration.

        When the loaded dictionary is a full ``minicpmv`` model config, only
        its ``slice_config`` entry is used.
        """
        # NOTE(review): `_set_token_in_kwargs` is assumed to be provided by the
        # PretrainedConfig base class -- confirm it exists in PaddleNLP.
        cls._set_token_in_kwargs(kwargs)
        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs)

        if config_dict.get('model_type') == 'minicpmv':
            config_dict = config_dict['slice_config']

        # Warn, but do not fail, when the stored model type differs from ours.
        if 'model_type' in config_dict and hasattr(cls, 'model_type'):
            if config_dict['model_type'] != cls.model_type:
                logger.warning(
                    f"You are using a model of type {config_dict['model_type']} to instantiate a model of type {cls.model_type}. This is not supported for all configurations of models and can yield errors."
                )

        return cls.from_dict(config_dict, **kwargs)
|
||
|
||
class MiniCPMVConfig(Qwen2Config):
    """Qwen2 configuration extended with MiniCPM-V vision and slicing settings."""

    model_type = 'minicpmv'
    keys_to_ignore_at_inference = ['past_key_values']
    # Used to build the vision config when the caller supplies none.
    default_vision_config = {'hidden_size': 1152, 'image_size': 980,
        'intermediate_size': 4304, 'model_type': 'siglip',
        'num_attention_heads': 16, 'num_hidden_layers': 27, 'patch_size': 14}

    def __init__(self, use_cache=True, query_num=64, image_size=448,
                 drop_vision_last_layer=True, batch_vision_input=True,
                 slice_config=None, vision_config=None, use_image_id=True,
                 **kwargs):
        """Store the MiniCPM-V options and build the nested configs.

        Arguments:
            use_cache, query_num, image_size, drop_vision_last_layer,
            batch_vision_input, use_image_id: stored unchanged as attributes
                of the same name.
            slice_config: None or a dict of ``MiniCPMVSliceConfig`` keyword
                arguments. None yields ``MiniCPMVSliceConfig(max_slice_nums=1)``.
            vision_config: None (use ``default_vision_config``), a dict of
                ``SigLipVisionConfig`` keyword arguments, or a ready
                ``SigLipVisionConfig`` instance.
            **kwargs: forwarded to ``Qwen2Config.__init__``.
        Raises:
            TypeError: if ``vision_config`` is none of the accepted types.
        """
        self.use_cache = use_cache
        self.query_num = query_num
        self.image_size = image_size
        self.drop_vision_last_layer = drop_vision_last_layer
        self.batch_vision_input = batch_vision_input
        self.use_image_id = use_image_id

        if slice_config is None:
            self.slice_config = MiniCPMVSliceConfig(max_slice_nums=1)
        else:
            self.slice_config = MiniCPMVSliceConfig(**slice_config)
        # Slice mode is switched on regardless of the slice config supplied.
        self.slice_mode = True

        if vision_config is None:
            self.vision_config = SigLipVisionConfig(**self.default_vision_config)
            logger.info('vision_config is None, using default vision config')
        elif isinstance(vision_config, dict):
            self.vision_config = SigLipVisionConfig(**vision_config)
        elif isinstance(vision_config, SigLipVisionConfig):
            self.vision_config = vision_config
        else:
            # Without this branch self.vision_config stays unset and the
            # patch_size lookup below fails with an unrelated AttributeError.
            raise TypeError(
                'vision_config must be None, a dict, or a SigLipVisionConfig, '
                f'got {type(vision_config).__name__}'
            )

        # Mirror the vision patch size on the top-level config.
        self.patch_size = self.vision_config.patch_size
        super().__init__(**kwargs)
Oops, something went wrong.