[feat] Add finetune code for Yi-VL model #368

Open · wants to merge 11 commits into base: main
48 changes: 37 additions & 11 deletions README.md
@@ -1,3 +1,8 @@
<p align="left">
&nbsp;English&nbsp; | &nbsp;<a href="README_CN.md">中文</a>
</p>
<br><br>

<div align="center">

<picture>
@@ -24,6 +29,8 @@

</div>

<div id="top"></div>

<div align="center">
<h3 align="center">Building the Next Generation of Open-Source and Bilingual LLMs</h3>
</div>
@@ -126,7 +133,9 @@
</ul>
</details>

<div align="right"> [ <a href="#building-the-next-generation-of-open-source-and-bilingual-llms">Back to top ⬆️ </a> ] </div>
<p align="right"> [
<a href="#top">Back to top ⬆️ </a> ]
</p>

## 🎉 News

@@ -180,7 +189,9 @@ sequence length and can be extended to 32K during inference time.

</details>

<div align="right"> [ <a href="#building-the-next-generation-of-open-source-and-bilingual-llms">Back to top ⬆️ </a> ] </div>
<p align="right"> [
<a href="#top">Back to top ⬆️ </a> ]
</p>

## 🎯 Models

@@ -241,7 +252,9 @@ Yi-6B-200K | • [🤗 Hugging Face](https://huggingface.co/01-ai/Yi-6B-200K)
</ul>
</details>

<div align="right"> [ <a href="#building-the-next-generation-of-open-source-and-bilingual-llms">Back to top ⬆️ </a> ] </div>
<p align="right"> [
<a href="#top">Back to top ⬆️ </a> ]
</p>


# 🟢 How to use Yi?
@@ -315,7 +328,7 @@ If you want to chat with Yi with more customizable options (e.g., system prompt,

This tutorial guides you through every step of running **Yi-34B-Chat locally on an A800 (80G)** and then performing inference.

#### Step 0: Prerequistes
#### Step 0: Prerequisites

- Make sure Python 3.10 or a later version is installed.

@@ -829,7 +842,9 @@ python eval_quantized_model.py --model /quantized_model --trust_remote_code

</ul>
</details>
<div align="right"> [ <a href="#building-the-next-generation-of-open-source-and-bilingual-llms">Back to top ⬆️ </a> ] </div>
<p align="right"> [
<a href="#top">Back to top ⬆️ </a> ]
</p>

### Deployment

@@ -903,6 +918,7 @@ With all these resources at your fingertips, you're ready to start your exciting
| Blog | [Running Yi-34B-Chat locally using LlamaEdge](https://www.secondstate.io/articles/yi-34b/) | 2023-11-30 | [Second State](https://github.com/second-state) |
| Blog | [零一万物模型折腾笔记:官方 Yi-34B 模型基础使用](https://zhuanlan.zhihu.com/p/671387298) | 2023-12-10 | [苏洋](https://github.com/soulteary) |
| Blog | [CPU 混合推理,非常见大模型量化方案:“二三五六” 位量化方案](https://zhuanlan.zhihu.com/p/671698216) | 2023-12-12 | [苏洋](https://github.com/soulteary) |
| Blog | [零一万物开源Yi-VL多模态大模型,魔搭社区推理&微调最佳实践来啦!](https://zhuanlan.zhihu.com/p/680098411) | 2024-01-26 | [ModelScope](https://github.com/modelscope) |
| Video | [只需 24G 显存,用 vllm 跑起来 Yi-34B 中英双语大模型](https://www.bilibili.com/video/BV17t4y1f7Ee/) | 2023-12-28 | 漆妮妮 |
| Video | [Install Yi 34B Locally - Chinese English Bilingual LLM](https://www.youtube.com/watch?v=CVQvj4Wrh4w&t=476s) | 2023-11-05 | Fahd Mirza |
</details>
@@ -998,7 +1014,9 @@ If you're seeking to explore the diverse capabilities within Yi's thriving famil
- [amazing-openai-api](https://github.com/soulteary/amazing-openai-api): this tool converts Yi model APIs into the OpenAI API format out of the box.
- [LlamaEdge](https://www.secondstate.io/articles/yi-34b/#create-an-openai-compatible-api-service-for-the-yi-34b-chat-model): this tool builds an OpenAI-compatible API server for Yi-34B-Chat using a portable Wasm (WebAssembly) file, powered by Rust.

<div align="right"> [ <a href="#building-the-next-generation-of-open-source-and-bilingual-llms">Back to top ⬆️ </a> ] </div>
<p align="right"> [
<a href="#top">Back to top ⬆️ </a> ]
</p>

## 📌 Benchmarks

@@ -1024,7 +1042,7 @@ Yi-34B-Chat model demonstrates exceptional performance, ranking first among all

### 📊 Base model performance

The Yi-34B and Yi-34B-200K models stand out as the top performers among open-source models, especially excelling in MMLU, CMML, common-sense reasoning, reading comprehension, and more.
The Yi-34B and Yi-34B-200K models stand out as the top performers among open-source models, especially excelling in MMLU, CMMLU, common-sense reasoning, reading comprehension, and more.

![Base model performance](https://github.com/01-ai/Yi/blob/main/assets/img/benchmark_base.png?raw=true)

@@ -1048,7 +1066,9 @@ Everyone! 🙌 ✅

- For free commercial use, you only need to [complete this form](https://www.lingyiwanwu.com/yi-license) to get a Yi Model Commercial License.

<div align="right"> [ <a href="#building-the-next-generation-of-open-source-and-bilingual-llms">Back to top ⬆️ </a> ] </div>
<p align="right"> [
<a href="#top">Back to top ⬆️ </a> ]
</p>

# 🟢 Misc.

@@ -1058,7 +1078,9 @@ A heartfelt thank you to each of you who have made contributions to the Yi commu

[![yi contributors](https://contrib.rocks/image?repo=01-ai/yi&max=2000&columns=15)](https://github.com/01-ai/yi/graphs/contributors)

<div align="right"> [ <a href="#building-the-next-generation-of-open-source-and-bilingual-llms">Back to top ⬆️ </a> ] </div>
<p align="right"> [
<a href="#top">Back to top ⬆️ </a> ]
</p>

### 📡 Disclaimer

@@ -1071,12 +1093,16 @@ problematic outputs. We will not be responsible for any risks and issues
resulting from misuse, misguidance, illegal usage, and related misinformation,
as well as any associated data security concerns.

<div align="right"> [ <a href="#building-the-next-generation-of-open-source-and-bilingual-llms">Back to top ⬆️ </a> ] </div>
<p align="right"> [
<a href="#top">Back to top ⬆️ </a> ]
</p>

### 🪪 License

The source code in this repo is licensed under the [Apache 2.0
license](https://github.com/01-ai/Yi/blob/main/LICENSE). The Yi series models are fully open for academic research and free for commercial use, with automatic permission granted upon application. All usage must adhere to the [Yi Series Models Community License Agreement 2.1](https://github.com/01-ai/Yi/blob/main/MODEL_LICENSE_AGREEMENT.txt).
For free commercial use, you only need to send an email to [get official commercial permission](https://www.lingyiwanwu.com/yi-license).

<div align="right"> [ <a href="#building-the-next-generation-of-open-source-and-bilingual-llms">Back to top ⬆️ </a> ] </div>
<p align="right"> [
<a href="#top">Back to top ⬆️ </a> ]
</p>
15 changes: 9 additions & 6 deletions README_CN.md
@@ -25,24 +25,26 @@
</div>

<div align="center">
<h3 align="center">Building the Next Generation of Open-Source and Bilingual LLMs</h3>
<h3 align="center">打造新一代开源双语大语言模型</h3>
</div>

<p align="center">
🤗 <a href="https://huggingface.co/01-ai" target="_blank">Hugging Face</a> • 🤖 <a href="https://www.modelscope.cn/organization/01ai/" target="_blank">ModelScope</a> • ✡️ <a href="https://wisemodel.cn/organization/01.AI" target="_blank">WiseModel</a>
🤗 <a href="https://huggingface.co/01-ai" target="_blank">Hugging Face</a> • 🤖 <a href="https://www.modelscope.cn/organization/01ai/" target="_blank">魔搭社区 ModelScope</a> • ✡️ <a href="https://wisemodel.cn/organization/01.AI" target="_blank">始智AI WiseModel</a>
</p>

<p align="center">
👩‍🚀 Ask questions or discuss ideas on <a href="https://github.com/01-ai/Yi/discussions" target="_blank"> GitHub </a>!
👩‍🚀 欢迎你来 <a href="https://github.com/01-ai/Yi/discussions" target="_blank"> GitHub </a> 提问讨论
</p>

<p align="center">
👋 Join us on 💬 <a href="https://github.com/01-ai/Yi/issues/43#issuecomment-1827285245" target="_blank"> WeChat (Chinese) </a>!
👋 欢迎你加入我们的 💬 <a href="https://github.com/01-ai/Yi/issues/43#issuecomment-1827285245" target="_blank"> 微信群 </a>一起交流
</p>

<p align="center">
📚 Grow at <a href="#learning-hub">Yi Learning Hub</a>!
</p>

📚 欢迎你来 <a href="#learning-hub"> Yi 学习俱乐部 </a>探索新知
</p>

<hr>

<ul>
@@ -895,6 +897,7 @@ Yi 8-bit quantized models | [GPTQ and CUDA](https://github.com/PanQiWei/AutoGPT
| 博客 | [Running Yi-34B-Chat locally using LlamaEdge](https://www.secondstate.io/articles/yi-34b/) | 2023-11-30 | [Second State](https://github.com/second-state) |
| 博客 | [零一万物模型折腾笔记:官方 Yi-34B 模型基础使用](https://zhuanlan.zhihu.com/p/671387298) | 2023-12-10 | [苏洋](https://github.com/soulteary) |
| 博客 | [CPU 混合推理,非常见大模型量化方案:“二三五六” 位量化方案](https://zhuanlan.zhihu.com/p/671698216) | 2023-12-12 | [苏洋](https://github.com/soulteary) |
| 博客 | [零一万物开源Yi-VL多模态大模型,魔搭社区推理&微调最佳实践来啦!](https://zhuanlan.zhihu.com/p/680098411) | 2024-01-26 | [ModelScope](https://github.com/modelscope) |
| 视频 | [只需 24G 显存,用 vllm 跑起来 Yi-34B 中英双语大模型](https://www.bilibili.com/video/BV17t4y1f7Ee/) | 2023-12-28 | 漆妮妮 |
| 视频 | [Install Yi 34B Locally - Chinese English Bilingual LLM](https://www.youtube.com/watch?v=CVQvj4Wrh4w&t=476s) | 2023-11-05 | Fahd Mirza |
</details>
34 changes: 34 additions & 0 deletions VL/README.md
@@ -56,8 +56,42 @@ Human: what are they eating
Assistant: cat food
```

## Finetuning
1. Prepare data

Prepare your own data in the following JSON format; a minimal loading/validation sketch follows the example.
```json
[
  {
    "image": "images/cat.jpg",
    "conversations": [
      {
        "from": "human",
        "value": "<image_placeholder>\nDescribe the cats and what they are doing in detail."
      },
      {
        "from": "assistant",
        "value": "In the image, there are three cats situated on a stone floor. The cat on the left is a calico cat, its coat a mix of white, orange, and black. It's eating from a metal bowl. In the middle, there's a gray cat, also eating from a metal bowl. On the right, there's a black cat, eating from a plastic bowl. The cats are all facing away from the camera, engrossed in their meal. The stone floor they're on is gray, and a concrete wall forms the backdrop of the scene. The image captures a peaceful moment of these cats enjoying their food."
      }
    ]
  },
  ...
]
```
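
A small Python sketch of loading and sanity-checking a dataset in this format is shown below. It is illustrative only: the file name `finetune_data.json` and the specific checks are assumptions, not part of the training scripts.

```python
import json

# Minimal sanity check for a Yi-VL finetuning dataset (illustrative only).
# Assumes the layout shown above: a list of samples, each with an "image" path
# and a "conversations" list of {"from": ..., "value": ...} turns.
with open("finetune_data.json", "r", encoding="utf-8") as f:  # hypothetical file name
    samples = json.load(f)

for i, sample in enumerate(samples):
    assert "image" in sample and "conversations" in sample, f"sample {i}: missing keys"
    turns = sample["conversations"]
    assert turns and turns[0]["from"] == "human", f"sample {i}: should start with a human turn"
    # The example above places the image token in the first human turn.
    assert "<image_placeholder>" in turns[0]["value"], f"sample {i}: missing <image_placeholder>"
    for turn in turns:
        assert turn["from"] in ("human", "assistant"), f"sample {i}: unexpected role {turn['from']!r}"

print(f"OK: {len(samples)} samples")
```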

2. Finetune Yi-VL

Training scripts are provided in the `scripts` folder. You can use `scripts/finetune.sh`, `scripts/finetune_lora.sh`, or `scripts/finetune_qlora.sh` to finetune Yi-VL with your own dataset.

Before running the scripts, specify the following parameters.
- `--model_name_or_path`: the path to the Yi-VL model; you can use either the 6B or the 34B model.
- `--data_path`: the path to your own dataset.
- `--image_folder`: the path to the image data folder.
- `--vision_tower`: the path to the ViT model, usually found in the Yi-VL base model folder.

3. Merge LoRA (optional)

If you finetune with LoRA or QLoRA, you need to merge the LoRA weights into the Yi-VL base model afterwards. You can use `scripts/merge_lora.sh` to do this; a rough sketch of what the merge step amounts to is shown below.
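
The sketch below uses the same PEFT calls that this PR adds to `VL/llava/mm_utils.py`. The paths are placeholders, and the import of `LlavaLlamaForCausalLM` from `llava.model` (with `VL/` on `PYTHONPATH`) is assumed to match upstream LLaVA. The actual `scripts/merge_lora.sh` may additionally restore non-LoRA trainables (such as the projector) from `non_lora_trainables.bin`, as `load_pretrained_model` does.

```python
import torch
from peft import PeftModel
from llava.model import LlavaLlamaForCausalLM  # assumed import path, as in upstream LLaVA

base_path = "path/to/Yi-VL-6B"         # placeholder: base model used for finetuning
lora_path = "path/to/lora_checkpoint"  # placeholder: output dir of finetune_lora.sh
out_path = "path/to/Yi-VL-6B-merged"   # placeholder: where to save the merged model

# Load the base model, attach the LoRA adapters, fold them into the weights,
# and save a standalone checkpoint that no longer needs peft at inference time.
model = LlavaLlamaForCausalLM.from_pretrained(
    base_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True
)
model = PeftModel.from_pretrained(model, lora_path)
model = model.merge_and_unload()
model.save_pretrained(out_path)
```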

## Major differences from LLaVA
1. We change the image token from ```<image>``` to ```<image_placeholder>```. The system prompt is modified to:
15 changes: 14 additions & 1 deletion VL/llava/mm_utils.py
@@ -1,4 +1,5 @@
import base64
import os
from io import BytesIO

import torch
@@ -70,7 +71,7 @@ def get_model_name_from_path(model_path):


def load_pretrained_model(
model_path, load_8bit=False, load_4bit=False, device_map="auto", multimodal="IMAGE"
model_path, lora_path=None, load_8bit=False, load_4bit=False, device_map="auto", multimodal="IMAGE"
):
kwargs = {"device_map": device_map}
kwargs["torch_dtype"] = torch.bfloat16
@@ -79,6 +80,18 @@ model = LlavaLlamaForCausalLM.from_pretrained(
model = LlavaLlamaForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
if lora_path is not None:
from peft import PeftModel
# Load the non-LoRA trainables saved alongside the adapters (typically the
# multimodal projector) and strip the 'base_model.' / 'model.' prefixes that
# the PEFT wrapper adds to parameter names.
non_lora_trainables = torch.load(os.path.join(lora_path, 'non_lora_trainables.bin'), map_location='cpu')
non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in
non_lora_trainables.items()}
if any(k.startswith('model.model.') for k in non_lora_trainables):
non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
model.load_state_dict(non_lora_trainables, strict=False)

# Attach the LoRA adapters and merge them into the base weights for inference.
model = PeftModel.from_pretrained(model, lora_path)
model = model.merge_and_unload()

image_processor = None
model.resize_token_embeddings(len(tokenizer))
vision_tower = model.get_vision_tower()
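
For context, the new `lora_path` argument would be used roughly as below. This is a sketch: the return values are assumed to match upstream LLaVA's `(tokenizer, model, image_processor, context_len)`, the import assumes `VL/` is on `PYTHONPATH`, and the paths are placeholders.

```python
from llava.mm_utils import load_pretrained_model  # module added/extended in this PR

# Load the base Yi-VL model and merge an un-merged LoRA checkpoint on the fly.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="path/to/Yi-VL-6B",        # placeholder
    lora_path="path/to/lora_checkpoint",  # placeholder: output dir of finetune_lora.sh
)
```
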
4 changes: 1 addition & 3 deletions VL/llava/model/llava_arch.py
@@ -28,9 +28,7 @@ def __init__(self, config):
super(LlavaMetaModel, self).__init__(config)

if hasattr(config, "mm_vision_tower"):
config.mm_vision_tower = os.path.join(
key_info["model_path"], config.mm_vision_tower.replace("./", "")
)
# Keep the vision tower path from the config as-is instead of remapping it
# under key_info["model_path"].
config.mm_vision_tower = config.mm_vision_tower
self.vision_tower = build_vision_tower(config, delay_load=True)
self.mm_projector = build_vision_projector(config)

115 changes: 115 additions & 0 deletions VL/llava/train/llama_flash_attn_monkey_patch.py
@@ -0,0 +1,115 @@
from typing import Optional, Tuple
import warnings

import torch

import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv

try:
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
except ImportError:
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input


def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)

bsz, q_len, _ = hidden_states.size()

query_states = (
self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states)
.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states)
.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
.transpose(1, 2)
) # shape: (b, num_heads, s, head_dim)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]

cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)

if past_key_value is not None:
# reuse k, v
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None

# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

# Transform the data into the format required by flash attention
qkv = torch.stack([query_states, key_states, value_states], dim=2)
qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
key_padding_mask = attention_mask

if key_padding_mask is None:
qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
cu_q_lens = torch.arange(
0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
)
max_s = q_len
output = flash_attn_unpadded_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = output.view(bsz, q_len, -1)
else:
qkv = qkv.reshape(bsz, q_len, -1)
qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
output_unpad = flash_attn_unpadded_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
output = pad_input(output_unpad, indices, bsz, q_len)

return self.o_proj(output), None, past_key_value


# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
# [bsz, seq_len]
return attention_mask


def replace_llama_attn_with_flash_attn():
cuda_major, cuda_minor = torch.cuda.get_device_capability()
if cuda_major < 8:
warnings.warn(
"Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
"ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
)
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
_prepare_decoder_attention_mask
)
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
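
As a usage note: like similar monkey patches, `replace_llama_attn_with_flash_attn()` must be called before the LLaMA model is constructed so that every `LlamaAttention` layer picks up the patched `forward`. The finetuning scripts presumably do this themselves; the sketch below only illustrates the ordering, with a placeholder model path and the assumption that `VL/` is on `PYTHONPATH`.

```python
import torch
import transformers

from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

# Patch first: this swaps LlamaAttention.forward for the flash-attn version and
# makes LlamaModel pass the raw key-padding mask through unchanged.
replace_llama_attn_with_flash_attn()

# Only construct the model after patching so every layer uses the new forward.
model = transformers.LlamaForCausalLM.from_pretrained(
    "path/to/yi-vl-base-llm",    # placeholder path, not taken from this PR
    torch_dtype=torch.bfloat16,  # flash-attn expects fp16/bf16 activations
)
```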