Skip to content

Commit

Permalink
Support new features of MiniCPM-V (#6626)
Browse files Browse the repository at this point in the history
* fix template name

* tiny fix

* support minicpm-o-2.6
  • Loading branch information
BUAADreamer authored Jan 13, 2025
1 parent e3e2c8c commit c3fda50
Show file tree
Hide file tree
Showing 8 changed files with 189 additions and 168 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
| [MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | cpm_v |
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_v |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
Expand Down
2 changes: 1 addition & 1 deletion README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
| [MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | cpm_v |
| [MiniCPM-o-2.6/MiniCPM-V-2.6](https://huggingface.co/openbmb) | 8B | minicpm_v |
| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
| [PaliGemma/PaliGemma2](https://huggingface.co/google) | 3B/10B/28B | paligemma |
Expand Down
10 changes: 10 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,16 @@ def get_console_scripts() -> List[str]:
"badam": ["badam>=1.2.1"],
"adam-mini": ["adam-mini"],
"qwen": ["transformers_stream_generator"],
"minicpm_v": [
"soundfile",
"torchvision",
"torchaudio",
"vector_quantize_pytorch",
"vocos",
"msgpack",
"referencing",
"jsonschema_specifications",
],
"modelscope": ["modelscope"],
"openmind": ["openmind"],
"swanlab": ["swanlab"],
Expand Down
5 changes: 2 additions & 3 deletions src/llamafactory/data/collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,8 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tenso
features = features.data # use default_collate() instead of BatchEncoding.to()

if "image_bound" in features: # for minicpmv inputs
features["position_ids"] = (
torch.arange(features["input_ids"].size(1)).long().unsqueeze(0).expand_as(features["input_ids"])
)
bsz, seq_length = features["input_ids"].shape
features["position_ids"] = torch.arange(seq_length).long().repeat(bsz, 1)
return {"data": features, "labels": features["labels"]}

return features
Expand Down
302 changes: 151 additions & 151 deletions src/llamafactory/data/mm_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,156 +254,6 @@ def get_mm_inputs(
return {}


class CpmVPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
num_video_tokens = 0
messages = deepcopy(messages)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
mm_inputs = {}
if len(images) != 0 and len(videos) != 0:
raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")

if len(videos) != 0:
max_slice_nums = 2
use_image_id = False
mm_inputs = self._get_mm_inputs([], videos, processor)
else:
max_slice_nums = image_processor.max_slice_nums
use_image_id = image_processor.use_image_id

for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
num_image_tokens += 1
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)

while VIDEO_PLACEHOLDER in content:
video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
num_video_tokens += 1

message["content"] = content.replace("{{image}}", "(<image>./</image>)")

if num_image_tokens > 0:
mm_inputs = self._get_mm_inputs(images, [], processor)

if mm_inputs:
pattern = "(<image>./</image>)"
image_sizes = mm_inputs["image_sizes"]

for index, message in enumerate(messages):
text = message["content"]
image_tags = re.findall(pattern, text)
text_chunks = text.split(pattern)
final_text = ""
for i in range(len(image_tags)):
final_text = (
final_text
+ text_chunks[i]
+ image_processor.get_slice_image_placeholder(
image_sizes[0][i], i, max_slice_nums, use_image_id
)
)
final_text += text_chunks[-1]
messages[index]["content"] = final_text

if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

if len(videos) != num_video_tokens:
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

return messages

@override
def _get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: "ProcessorMixin",
**kwargs,
) -> Dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")

mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
images,
image_resolution=getattr(processor, "image_resolution", 512 * 512),
)
if "valid_image_nums_ls" in kwargs:
valid_image_nums_ls = kwargs["valid_image_nums_ls"]
new_images = []
idx = 0
for valid_image_nums in valid_image_nums_ls:
new_images.append(images[idx : idx + valid_image_nums])
idx += valid_image_nums

images = new_images

image_inputs = image_processor(
images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
)
mm_inputs.update(image_inputs)

if len(videos) != 0:
videos = self._regularize_videos(
videos,
image_resolution=getattr(processor, "video_resolution", 128 * 128),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 64),
)
video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
mm_inputs.update(video_inputs)

return mm_inputs

@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
image_bounds_list = []
valid_image_nums_ls = []
for input_ids in batch_ids:
input_ids_ = torch.tensor(input_ids)
start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
input_ids_ == processor.tokenizer.slice_start_id
)
end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id)
image_start_tokens = torch.where(start_cond)[0]
image_start_tokens += 1
image_end_tokens = torch.where(end_cond)[0]
valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
valid_image_nums_ls.append(valid_image_nums)
image_bounds = torch.hstack(
[
image_start_tokens[:valid_image_nums].unsqueeze(-1),
image_end_tokens[:valid_image_nums].unsqueeze(-1),
]
)
image_bounds_list.append(image_bounds)

mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
mm_inputs.update({"image_bound": image_bounds_list})
return mm_inputs


class LlavaPlugin(BasePlugin):
@override
def process_messages(
Expand Down Expand Up @@ -567,6 +417,156 @@ def get_mm_inputs(
return self._get_mm_inputs(images, videos, processor)


class MiniCPMVPlugin(BasePlugin):
@override
def process_messages(
self,
messages: Sequence[Dict[str, str]],
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: Optional["ProcessorMixin"],
) -> List[Dict[str, str]]:
self._validate_input(images, videos)
num_image_tokens = 0
num_video_tokens = 0
messages = deepcopy(messages)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
mm_inputs = {}
if len(images) != 0 and len(videos) != 0:
raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")

if len(videos) != 0:
max_slice_nums = 2
use_image_id = False
mm_inputs = self._get_mm_inputs([], videos, processor)
else:
max_slice_nums = image_processor.max_slice_nums
use_image_id = image_processor.use_image_id

for message in messages:
content = message["content"]
while IMAGE_PLACEHOLDER in content:
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
num_image_tokens += 1

while VIDEO_PLACEHOLDER in content:
video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
num_video_tokens += 1

message["content"] = content.replace("{{image}}", "(<image>./</image>)")

if num_image_tokens > 0:
mm_inputs = self._get_mm_inputs(images, [], processor)

if mm_inputs:
pattern = "(<image>./</image>)"
image_sizes = mm_inputs["image_sizes"]

for index, message in enumerate(messages):
text = message["content"]
image_tags = re.findall(pattern, text)
text_chunks = text.split(pattern)
final_text = ""
for i in range(len(image_tags)):
final_text = (
final_text
+ text_chunks[i]
+ image_processor.get_slice_image_placeholder(
image_sizes[0][i], i, max_slice_nums, use_image_id
)
)

final_text += text_chunks[-1]
messages[index]["content"] = final_text

if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

if len(videos) != num_video_tokens:
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

return messages

@override
def _get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
processor: "ProcessorMixin",
**kwargs,
) -> Dict[str, "torch.Tensor"]:
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
mm_inputs = {}
if len(images) != 0:
images = self._regularize_images(
images,
image_resolution=getattr(processor, "image_resolution", 512 * 512),
)
if "valid_image_nums_ls" in kwargs:
valid_image_nums_ls = kwargs["valid_image_nums_ls"]
new_images = []
idx = 0
for valid_image_nums in valid_image_nums_ls:
new_images.append(images[idx : idx + valid_image_nums])
idx += valid_image_nums

images = new_images

image_inputs = image_processor(
images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
)
mm_inputs.update(image_inputs)

if len(videos) != 0:
videos = self._regularize_videos(
videos,
image_resolution=getattr(processor, "video_resolution", 128 * 128),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 64),
)
video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
mm_inputs.update(video_inputs)

return mm_inputs

@override
def get_mm_inputs(
self,
images: Sequence["ImageInput"],
videos: Sequence["VideoInput"],
imglens: Sequence[int],
vidlens: Sequence[int],
batch_ids: Sequence[List[int]],
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
image_bounds_list = []
valid_image_nums_ls = []
for input_ids in batch_ids:
input_ids_ = torch.tensor(input_ids)
start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
input_ids_ == processor.tokenizer.slice_start_id
)
end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id)
image_start_tokens = torch.where(start_cond)[0]
image_start_tokens += 1
image_end_tokens = torch.where(end_cond)[0]
valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
valid_image_nums_ls.append(valid_image_nums)
image_bounds = torch.hstack(
[
image_start_tokens[:valid_image_nums].unsqueeze(-1),
image_end_tokens[:valid_image_nums].unsqueeze(-1),
]
)
image_bounds_list.append(image_bounds)

mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
mm_inputs.update({"image_bound": image_bounds_list})
return mm_inputs


class PaliGemmaPlugin(BasePlugin):
@override
def process_messages(
Expand Down Expand Up @@ -945,10 +945,10 @@ def get_mm_inputs(

PLUGINS = {
"base": BasePlugin,
"cpm_v": CpmVPlugin,
"llava": LlavaPlugin,
"llava_next": LlavaNextPlugin,
"llava_next_video": LlavaNextVideoPlugin,
"minicpm_v": MiniCPMVPlugin,
"paligemma": PaliGemmaPlugin,
"pixtral": PixtralPlugin,
"qwen2_vl": Qwen2vlPlugin,
Expand Down
Loading

0 comments on commit c3fda50

Please sign in to comment.