diff --git a/README.md b/README.md
index 595199aa..48d823aa 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,8 @@
+
+  English  | 中文
+
+
+
+
+
Building the Next Generation of Open-Source and Bilingual LLMs
@@ -126,7 +133,9 @@
-
+ [
+ Back to top ⬆️ ]
+
## 🎉 News
@@ -180,7 +189,9 @@ sequence length and can be extended to 32K during inference time.
-
+ [
+ Back to top ⬆️ ]
+
## 🎯 Models
@@ -241,7 +252,9 @@ Yi-6B-200K | • [🤗 Hugging Face](https://huggingface.co/01-ai/Yi-6B-200K)
-
+ [
+ Back to top ⬆️ ]
+
# 🟢 How to use Yi?
@@ -315,7 +328,7 @@ If you want to chat with Yi with more customizable options (e.g., system prompt,
This tutorial guides you through every step of running **Yi-34B-Chat locally on an A800 (80G)** and then performing inference.
-#### Step 0: Prerequistes
+#### Step 0: Prerequisites
- Make sure Python 3.10 or a later version is installed.
@@ -829,7 +842,9 @@ python eval_quantized_model.py --model /quantized_model --trust_remote_code
-
+ [
+ Back to top ⬆️ ]
+
### Deployment
@@ -903,6 +918,7 @@ With all these resources at your fingertips, you're ready to start your exciting
| Blog | [Running Yi-34B-Chat locally using LlamaEdge](https://www.secondstate.io/articles/yi-34b/) | 2023-11-30 | [Second State](https://github.com/second-state) |
| Blog | [零一万物模型折腾笔记:官方 Yi-34B 模型基础使用](https://zhuanlan.zhihu.com/p/671387298) | 2023-12-10 | [苏洋](https://github.com/soulteary) |
| Blog | [CPU 混合推理,非常见大模型量化方案:“二三五六” 位量化方案](https://zhuanlan.zhihu.com/p/671698216) | 2023-12-12 | [苏洋](https://github.com/soulteary) |
+| Blog | [零一万物开源Yi-VL多模态大模型,魔搭社区推理&微调最佳实践来啦!](https://zhuanlan.zhihu.com/p/680098411) | 2024-01-26 | [ModelScope](https://github.com/modelscope) |
| Video | [只需 24G 显存,用 vllm 跑起来 Yi-34B 中英双语大模型](https://www.bilibili.com/video/BV17t4y1f7Ee/) | 2023-12-28 | 漆妮妮 |
| Video | [Install Yi 34B Locally - Chinese English Bilingual LLM](https://www.youtube.com/watch?v=CVQvj4Wrh4w&t=476s) | 2023-11-05 | Fahd Mirza |
@@ -998,7 +1014,9 @@ If you're seeking to explore the diverse capabilities within Yi's thriving famil
- [amazing-openai-api](https://github.com/soulteary/amazing-openai-api): this tool converts Yi model APIs into the OpenAI API format out of the box.
- [LlamaEdge](https://www.secondstate.io/articles/yi-34b/#create-an-openai-compatible-api-service-for-the-yi-34b-chat-model): this tool builds an OpenAI-compatible API server for Yi-34B-Chat using a portable Wasm (WebAssembly) file, powered by Rust.
-
+ [
+ Back to top ⬆️ ]
+
## 📌 Benchmarks
@@ -1024,7 +1042,7 @@ Yi-34B-Chat model demonstrates exceptional performance, ranking first among all
### 📊 Base model performance
-The Yi-34B and Yi-34B-200K models stand out as the top performers among open-source models, especially excelling in MMLU, CMML, common-sense reasoning, reading comprehension, and more.
+The Yi-34B and Yi-34B-200K models stand out as the top performers among open-source models, especially excelling in MMLU, CMMLU, common-sense reasoning, reading comprehension, and more.
![Base model performance](https://github.com/01-ai/Yi/blob/main/assets/img/benchmark_base.png?raw=true)
@@ -1048,7 +1066,9 @@ Everyone! 🙌 ✅
- For free commercial use, you only need to [complete this form](https://www.lingyiwanwu.com/yi-license) to get a Yi Model Commercial License.
-
+ [
+ Back to top ⬆️ ]
+
# 🟢 Misc.
@@ -1058,7 +1078,9 @@ A heartfelt thank you to each of you who have made contributions to the Yi commu
[![yi contributors](https://contrib.rocks/image?repo=01-ai/yi&max=2000&columns=15)](https://github.com/01-ai/yi/graphs/contributors)
-
+ [
+ Back to top ⬆️ ]
+
### 📡 Disclaimer
@@ -1071,7 +1093,9 @@ problematic outputs. We will not be responsible for any risks and issues
resulting from misuse, misguidance, illegal usage, and related misinformation,
as well as any associated data security concerns.
-
+ [
+ Back to top ⬆️ ]
+
### 🪪 License
@@ -1079,4 +1103,6 @@ The source code in this repo is licensed under the [Apache 2.0
license](https://github.com/01-ai/Yi/blob/main/LICENSE). The Yi series models are fully open for academic research and free for commercial use, with automatic permission granted upon application. All usage must adhere to the [Yi Series Models Community License Agreement 2.1](https://github.com/01-ai/Yi/blob/main/MODEL_LICENSE_AGREEMENT.txt).
For free commercial use, you only need to send an email to [get official commercial permission](https://www.lingyiwanwu.com/yi-license).
-
+ [
+ Back to top ⬆️ ]
+
diff --git a/README_CN.md b/README_CN.md
index 3a9b6a65..eb346810 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -25,24 +25,26 @@
-
Building the Next Generation of Open-Source and Bilingual LLMs
+ 打造新一代开源双语大语言模型
-🤗 Hugging Face • 🤖 ModelScope • ✡️ WiseModel
+🤗 Hugging Face • 🤖 魔搭社区 ModelScope • ✡️ 始智AI WiseModel
- 👩🚀 Ask questions or discuss ideas on GitHub !
+ 👩🚀 欢迎你来 GitHub 提问讨论
- 👋 Join us on 💬 WeChat (Chinese) !
+ 👋 欢迎你加入我们的 💬 微信群 一起交流
- 📚 Grow at Yi Learning Hub!
-
+
+ 📚 欢迎你来 Yi 学习俱乐部 探索新知
+
+
@@ -895,6 +897,7 @@ Yi 8-bit quantized models | [GPTQ and CUDA](https://github.com/PanQiWei/AutoGPT
| 博客 | [Running Yi-34B-Chat locally using LlamaEdge](https://www.secondstate.io/articles/yi-34b/) | 2023-11-30 | [Second State](https://github.com/second-state) |
| 博客 | [零一万物模型折腾笔记:官方 Yi-34B 模型基础使用](https://zhuanlan.zhihu.com/p/671387298) | 2023-12-10 | [苏洋](https://github.com/soulteary) |
| 博客 | [CPU 混合推理,非常见大模型量化方案:“二三五六” 位量化方案](https://zhuanlan.zhihu.com/p/671698216) | 2023-12-12 | [苏洋](https://github.com/soulteary) |
+| 博客 | [零一万物开源Yi-VL多模态大模型,魔搭社区推理&微调最佳实践来啦!](https://zhuanlan.zhihu.com/p/680098411) | 2024-01-26 | [ModelScope](https://github.com/modelscope) |
| 视频 | [只需 24G 显存,用 vllm 跑起来 Yi-34B 中英双语大模型](https://www.bilibili.com/video/BV17t4y1f7Ee/) | 2023-12-28 | 漆妮妮 |
| 视频 | [Install Yi 34B Locally - Chinese English Bilingual LLM](https://www.youtube.com/watch?v=CVQvj4Wrh4w&t=476s) | 2023-11-05 | Fahd Mirza |
diff --git a/VL/README.md b/VL/README.md
index 6b5c79fb..8bfe9591 100644
--- a/VL/README.md
+++ b/VL/README.md
@@ -56,8 +56,42 @@ Human: what are they eating
Assistant: cat food
```
+## Finetuning
+1. Prepare data
+Prepare your own data into the following JSON format.
+```json
+[
+ {
+ "image": "images/cat.jpg",
+ "conversations": [
+ {
+ "from": "human",
+ "value": "\nDescribe the cats and what they are doing in detail."
+ },
+ {
+ "from": "assistant",
+ "value": "In the image, there are three cats situated on a stone floor. The cat on the left is a calico cat, its coat a mix of white, orange, and black. It's eating from a metal bowl. In the middle, there's a gray cat, also eating from a metal bowl. On the right, there's a black cat, eating from a plastic bowl. The cats are all facing away from the camera, engrossed in their meal. The stone floor they're on is gray, and a concrete wall forms the backdrop of the scene. The image captures a peaceful moment of these cats enjoying their food."
+ },
+ ]
+ },
+ ...
+]
+```
+
+2. Finetune Yi-VL
+
+Training scripts are provided in the `scripts` folder. You can use `scripts/finetune.sh`, `scripts/finetune_lora.sh` or `scripts/finetune_qlora.sh` to finetune Yi-VL with your own dataset.
+
+Before running the scrips, you should specify the following parameters.
+- `--model_name_or_path`: the path to Yi-VL model; you can use 6B or 34B model.
+- `--data_path`: the path to your own dataset.
+- `--image_folder`: the path to the image data folder.
+- `--vision_tower`: the path to the ViT model, usually found in the Yi-VL base model folder.
+
+3. Merge lora (Optional)
+If you use `lora` or `qlora` for finetuning, you need to merge the lora parameters into the Yi-VL model after finetuning. You can use `scripts/merge_lora.sh` to merge the lora parameters.
## Major difference with LLaVA
1. We change the image token from `````` to ``````. The system prompt is modified to:
diff --git a/VL/llava/mm_utils.py b/VL/llava/mm_utils.py
index 1bb61c73..a5e1ff43 100644
--- a/VL/llava/mm_utils.py
+++ b/VL/llava/mm_utils.py
@@ -1,4 +1,5 @@
import base64
+import os
from io import BytesIO
import torch
@@ -70,7 +71,7 @@ def get_model_name_from_path(model_path):
def load_pretrained_model(
- model_path, load_8bit=False, load_4bit=False, device_map="auto", multimodal="IMAGE"
+ model_path, lora_path=None, load_8bit=False, load_4bit=False, device_map="auto", multimodal="IMAGE"
):
kwargs = {"device_map": device_map}
kwargs["torch_dtype"] = torch.bfloat16
@@ -79,6 +80,18 @@ def load_pretrained_model(
model = LlavaLlamaForCausalLM.from_pretrained(
model_path, low_cpu_mem_usage=True, **kwargs
)
+ if lora_path is not None:
+ from peft import PeftModel
+ non_lora_trainables = torch.load(os.path.join(lora_path, 'non_lora_trainables.bin'), map_location='cpu')
+ non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in
+ non_lora_trainables.items()}
+ if any(k.startswith('model.model.') for k in non_lora_trainables):
+ non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
+ model.load_state_dict(non_lora_trainables, strict=False)
+
+ model = PeftModel.from_pretrained(model, lora_path)
+ model = model.merge_and_unload()
+
image_processor = None
model.resize_token_embeddings(len(tokenizer))
vision_tower = model.get_vision_tower()
diff --git a/VL/llava/model/llava_arch.py b/VL/llava/model/llava_arch.py
index 8815515c..9935c6f4 100644
--- a/VL/llava/model/llava_arch.py
+++ b/VL/llava/model/llava_arch.py
@@ -28,9 +28,7 @@ def __init__(self, config):
super(LlavaMetaModel, self).__init__(config)
if hasattr(config, "mm_vision_tower"):
- config.mm_vision_tower = os.path.join(
- key_info["model_path"], config.mm_vision_tower.replace("./", "")
- )
+ config.mm_vision_tower = config.mm_vision_tower
self.vision_tower = build_vision_tower(config, delay_load=True)
self.mm_projector = build_vision_projector(config)
diff --git a/VL/llava/train/llama_flash_attn_monkey_patch.py b/VL/llava/train/llama_flash_attn_monkey_patch.py
new file mode 100644
index 00000000..209e49a3
--- /dev/null
+++ b/VL/llava/train/llama_flash_attn_monkey_patch.py
@@ -0,0 +1,115 @@
+from typing import Optional, Tuple
+import warnings
+
+import torch
+
+import transformers
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+
+try:
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+except ImportError:
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
+from flash_attn.bert_padding import unpad_input, pad_input
+
+
+def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ warnings.warn(
+ "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = (
+ self.q_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ key_states = (
+ self.k_proj(hidden_states)
+ .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ value_states = (
+ self.v_proj(hidden_states)
+ .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+ .transpose(1, 2)
+ ) # shape: (b, num_heads, s, head_dim)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, position_ids
+ )
+
+ if past_key_value is not None:
+ # reuse k, v
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ # Transform the data into the format required by flash attention
+ qkv = torch.stack([query_states, key_states, value_states], dim=2)
+ qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
+ key_padding_mask = attention_mask
+
+ if key_padding_mask is None:
+ qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
+ cu_q_lens = torch.arange(
+ 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
+ )
+ max_s = q_len
+ output = flash_attn_unpadded_qkvpacked_func(
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+ )
+ output = output.view(bsz, q_len, -1)
+ else:
+ qkv = qkv.reshape(bsz, q_len, -1)
+ qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
+ qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
+ output_unpad = flash_attn_unpadded_qkvpacked_func(
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+ )
+ output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
+ output = pad_input(output_unpad, indices, bsz, q_len)
+
+ return self.o_proj(output), None, past_key_value
+
+
+# Disable the transformation of the attention mask in LlamaModel as the flash attention
+# requires the attention mask to be the same as the key_padding_mask
+def _prepare_decoder_attention_mask(
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+):
+ # [bsz, seq_len]
+ return attention_mask
+
+
+def replace_llama_attn_with_flash_attn():
+ cuda_major, cuda_minor = torch.cuda.get_device_capability()
+ if cuda_major < 8:
+ warnings.warn(
+ "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
+ "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
+ )
+ transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
+ _prepare_decoder_attention_mask
+ )
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
\ No newline at end of file
diff --git a/VL/llava/train/llama_xformers_attn_monkey_patch.py b/VL/llava/train/llama_xformers_attn_monkey_patch.py
new file mode 100644
index 00000000..a5c65da2
--- /dev/null
+++ b/VL/llava/train/llama_xformers_attn_monkey_patch.py
@@ -0,0 +1,129 @@
+"""
+Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
+"""
+
+import logging
+import math
+from typing import Optional, Tuple
+
+import torch
+import transformers.models.llama.modeling_llama
+from torch import nn
+
+try:
+ import xformers.ops
+except ImportError:
+ logging.error("xformers not found! Please install it before trying to use it.")
+
+
+def replace_llama_attn_with_xformers_attn():
+ transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
+
+
+def xformers_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # pylint: disable=duplicate-code
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = (
+ self.q_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ key_states = (
+ self.k_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ value_states = (
+ self.v_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ (
+ query_states,
+ key_states,
+ ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, position_ids
+ )
+ # [bsz, nh, t, hd]
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ # We only apply xformers optimizations if we don't need to output the whole attention matrix
+ if not output_attentions:
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
+ # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
+ if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
+ # input and output should be of form (bsz, q_len, num_heads, head_dim)
+ attn_output = xformers.ops.memory_efficient_attention(
+ query_states, key_states, value_states, attn_bias=None
+ )
+ else:
+ # input and output should be of form (bsz, q_len, num_heads, head_dim)
+ attn_output = xformers.ops.memory_efficient_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_bias=xformers.ops.LowerTriangularMask(),
+ )
+ attn_weights = None
+ else:
+ attn_weights = torch.matmul(
+ query_states, key_states.transpose(2, 3)
+ ) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+ attn_weights = torch.max(
+ attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
+ )
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(
+ attn_weights, dim=-1, dtype=torch.float32
+ ).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2)
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = self.o_proj(attn_output)
+ return attn_output, attn_weights, past_key_value
\ No newline at end of file
diff --git a/VL/llava/train/llava_trainer.py b/VL/llava/train/llava_trainer.py
new file mode 100644
index 00000000..4ee10ede
--- /dev/null
+++ b/VL/llava/train/llava_trainer.py
@@ -0,0 +1,264 @@
+import os
+import torch
+
+from torch.utils.data import Sampler
+
+from transformers import Trainer
+from transformers.trainer import (
+ is_sagemaker_mp_enabled,
+ get_parameter_names,
+ has_length,
+ ALL_LAYERNORM_LAYERS,
+ ShardedDDPOption,
+ logger,
+)
+from typing import List, Optional
+
+
+def maybe_zero_3(param, ignore_status=False, name=None):
+ from deepspeed import zero
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+ if hasattr(param, "ds_id"):
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
+ if not ignore_status:
+ print(name, 'no ignore status')
+ with zero.GatheredParameters([param]):
+ param = param.data.detach().cpu().clone()
+ else:
+ param = param.detach().cpu().clone()
+ return param
+
+
+def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
+ to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def split_to_even_chunks(indices, lengths, num_chunks):
+ """
+ Split a list of indices into `chunks` chunks of roughly equal lengths.
+ """
+
+ if len(indices) % num_chunks != 0:
+ return [indices[i::num_chunks] for i in range(num_chunks)]
+
+ num_indices_per_chunk = len(indices) // num_chunks
+
+ chunks = [[] for _ in range(num_chunks)]
+ chunks_lengths = [0 for _ in range(num_chunks)]
+ for index in indices:
+ shortest_chunk = chunks_lengths.index(min(chunks_lengths))
+ chunks[shortest_chunk].append(index)
+ chunks_lengths[shortest_chunk] += lengths[index]
+ if len(chunks[shortest_chunk]) == num_indices_per_chunk:
+ chunks_lengths[shortest_chunk] = float("inf")
+
+ return chunks
+
+
+def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+ assert all(l != 0 for l in lengths), "Should not have zero length."
+ if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
+ # all samples are in the same modality
+ return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
+ mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
+ lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
+
+ mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
+ lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
+ megabatch_size = world_size * batch_size
+ mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
+ lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
+
+ last_mm = mm_megabatches[-1]
+ last_lang = lang_megabatches[-1]
+ additional_batch = last_mm + last_lang
+ megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
+ megabatch_indices = torch.randperm(len(megabatches), generator=generator)
+ megabatches = [megabatches[i] for i in megabatch_indices]
+
+ if len(additional_batch) > 0:
+ megabatches.append(sorted(additional_batch))
+
+ return [i for megabatch in megabatches for i in megabatch]
+
+
+def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
+ indices = torch.randperm(len(lengths), generator=generator)
+ megabatch_size = world_size * batch_size
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
+ megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
+ megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
+
+ return [i for megabatch in megabatches for batch in megabatch for i in batch]
+
+
+class LengthGroupedSampler(Sampler):
+ r"""
+ Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
+ keeping a bit of randomness.
+ """
+
+ def __init__(
+ self,
+ batch_size: int,
+ world_size: int,
+ lengths: Optional[List[int]] = None,
+ generator=None,
+ group_by_modality: bool = False,
+ ):
+ if lengths is None:
+ raise ValueError("Lengths must be provided.")
+
+ self.batch_size = batch_size
+ self.world_size = world_size
+ self.lengths = lengths
+ self.generator = generator
+ self.group_by_modality = group_by_modality
+
+ def __len__(self):
+ return len(self.lengths)
+
+ def __iter__(self):
+ if self.group_by_modality:
+ indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
+ else:
+ indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
+ return iter(indices)
+
+
+class LLaVATrainer(Trainer):
+
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+ if self.train_dataset is None or not has_length(self.train_dataset):
+ return None
+
+ if self.args.group_by_modality_length:
+ lengths = self.train_dataset.modality_lengths
+ return LengthGroupedSampler(
+ self.args.train_batch_size,
+ world_size=self.args.world_size * self.args.gradient_accumulation_steps,
+ lengths=lengths,
+ group_by_modality=True,
+ )
+ else:
+ return super()._get_train_sampler()
+
+ def create_optimizer(self):
+ """
+ Setup the optimizer.
+
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+ Trainer's init through `optimizers`, or subclass and override this method in a subclass.
+ """
+ if is_sagemaker_mp_enabled():
+ return super().create_optimizer()
+ if self.sharded_ddp == ShardedDDPOption.SIMPLE:
+ return super().create_optimizer()
+
+ opt_model = self.model
+
+ if self.optimizer is None:
+ decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
+ if self.args.mm_projector_lr is not None:
+ projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name]
+ optimizer_grouped_parameters = [
+ {
+ "params": [
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
+ ],
+ "weight_decay": self.args.weight_decay,
+ },
+ {
+ "params": [
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
+ ],
+ "weight_decay": 0.0,
+ },
+ {
+ "params": [
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
+ ],
+ "weight_decay": self.args.weight_decay,
+ "lr": self.args.mm_projector_lr,
+ },
+ {
+ "params": [
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
+ ],
+ "weight_decay": 0.0,
+ "lr": self.args.mm_projector_lr,
+ },
+ ]
+ else:
+ optimizer_grouped_parameters = [
+ {
+ "params": [
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
+ ],
+ "weight_decay": self.args.weight_decay,
+ },
+ {
+ "params": [
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
+ ],
+ "weight_decay": 0.0,
+ },
+ ]
+
+ optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
+
+ if self.sharded_ddp == ShardedDDPOption.SIMPLE:
+ self.optimizer = OSS(
+ params=optimizer_grouped_parameters,
+ optim=optimizer_cls,
+ **optimizer_kwargs,
+ )
+ else:
+ self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+ if optimizer_cls.__name__ == "Adam8bit":
+ import bitsandbytes
+
+ manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+ skipped = 0
+ for module in opt_model.modules():
+ if isinstance(module, nn.Embedding):
+ skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
+ logger.info(f"skipped {module}: {skipped/2**20}M params")
+ manager.register_module_override(module, "weight", {"optim_bits": 32})
+ logger.debug(f"bitsandbytes: will optimize {module} in fp32")
+ logger.info(f"skipped: {skipped/2**20}M params")
+
+ return self.optimizer
+
+ def _save_checkpoint(self, model, trial, metrics=None):
+ if getattr(self.args, 'tune_mm_mlp_adapter', False):
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+
+ run_dir = self._get_output_dir(trial=trial)
+ output_dir = os.path.join(run_dir, checkpoint_folder)
+
+ # Only save Adapter
+ keys_to_match = ['mm_projector', 'vision_resampler']
+ if getattr(self.args, "use_im_start_end", False):
+ keys_to_match.extend(['embed_tokens', 'embed_in'])
+
+ weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
+
+ if self.args.local_rank == 0 or self.args.local_rank == -1:
+ self.model.config.save_pretrained(output_dir)
+ torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
+ else:
+ super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)
+
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
+ if getattr(self.args, 'tune_mm_mlp_adapter', False):
+ pass
+ else:
+ super(LLaVATrainer, self)._save(output_dir, state_dict)
\ No newline at end of file
diff --git a/VL/llava/train/train.py b/VL/llava/train/train.py
new file mode 100644
index 00000000..31bdc240
--- /dev/null
+++ b/VL/llava/train/train.py
@@ -0,0 +1,916 @@
+# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
+# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import copy
+from dataclasses import dataclass, field
+import json
+import logging
+import pathlib
+from typing import Dict, Optional, Sequence, List
+
+import torch
+
+import transformers
+
+from llava.model.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
+from torch.utils.data import Dataset
+from llava.train.llava_trainer import LLaVATrainer
+
+from llava import conversation as conversation_lib
+from llava.model import *
+from llava.mm_utils import tokenizer_image_token
+
+from PIL import Image
+
+local_rank = None
+
+
+def rank0_print(*args):
+ if local_rank == 0:
+ print(*args)
+
+
+@dataclass
+class ModelArguments:
+ model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
+ freeze_backbone: bool = field(default=False)
+ tune_mm_mlp_adapter: bool = field(default=False)
+ vision_tower: Optional[str] = field(default=None)
+ mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer
+ pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
+ mm_projector_type: Optional[str] = field(default='linear')
+ mm_use_im_start_end: bool = field(default=False)
+ mm_use_im_patch_token: bool = field(default=True)
+ mm_vision_select_feature: Optional[str] = field(default="patch")
+
+
+@dataclass
+class DataArguments:
+ data_path: str = field(default=None,
+ metadata={"help": "Path to the training data."})
+ lazy_preprocess: bool = False
+ is_multimodal: bool = False
+ image_folder: Optional[str] = field(default=None)
+ image_aspect_ratio: str = 'square'
+
+
+@dataclass
+class TrainingArguments(transformers.TrainingArguments):
+ cache_dir: Optional[str] = field(default=None)
+ optim: str = field(default="adamw_torch")
+ remove_unused_columns: bool = field(default=False)
+ freeze_mm_mlp_adapter: bool = field(default=False)
+ mpt_attn_impl: Optional[str] = field(default="triton")
+ model_max_length: int = field(
+ default=512,
+ metadata={
+ "help":
+ "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+ },
+ )
+ double_quant: bool = field(
+ default=True,
+ metadata={"help": "Compress the quantization statistics through double quantization."}
+ )
+ quant_type: str = field(
+ default="nf4",
+ metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
+ )
+ bits: int = field(
+ default=16,
+ metadata={"help": "How many bits to use."}
+ )
+ lora_enable: bool = False
+ lora_r: int = 64
+ lora_alpha: int = 16
+ lora_dropout: float = 0.05
+ lora_weight_path: str = ""
+ lora_bias: str = "none"
+ mm_projector_lr: Optional[float] = None
+ group_by_modality_length: bool = field(default=False)
+
+
+def maybe_zero_3(param, ignore_status=False, name=None):
+ from deepspeed import zero
+ from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
+ if hasattr(param, "ds_id"):
+ if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
+ if not ignore_status:
+ logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
+ with zero.GatheredParameters([param]):
+ param = param.data.detach().cpu().clone()
+ else:
+ param = param.detach().cpu().clone()
+ return param
+
+
+# Borrowed from peft.utils.get_peft_model_state_dict
+def get_peft_state_maybe_zero_3(named_params, bias):
+ if bias == "none":
+ to_return = {k: t for k, t in named_params if "lora_" in k}
+ elif bias == "all":
+ to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
+ elif bias == "lora_only":
+ to_return = {}
+ maybe_lora_bias = {}
+ lora_bias_names = set()
+ for k, t in named_params:
+ if "lora_" in k:
+ to_return[k] = t
+ bias_name = k.split("lora_")[0] + "bias"
+ lora_bias_names.add(bias_name)
+ elif "bias" in k:
+ maybe_lora_bias[k] = t
+ for k, t in maybe_lora_bias:
+ if bias_name in lora_bias_names:
+ to_return[bias_name] = t
+ else:
+ raise NotImplementedError
+ to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
+ return to_return
+
+
+def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
+ to_return = {k: t for k, t in named_params if "lora_" not in k}
+ if require_grad_only:
+ to_return = {k: t for k, t in to_return.items() if t.requires_grad}
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
+ to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
+ return to_return
+
+
+def find_all_linear_names(model):
+ cls = torch.nn.Linear
+ lora_module_names = set()
+ multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
+ for name, module in model.named_modules():
+ if any(mm_keyword in name for mm_keyword in multimodal_keywords):
+ continue
+ if isinstance(module, cls):
+ names = name.split('.')
+ lora_module_names.add(names[0] if len(names) == 1 else names[-1])
+
+ if 'lm_head' in lora_module_names: # needed for 16-bit
+ lora_module_names.remove('lm_head')
+ return list(lora_module_names)
+
+
+def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
+ output_dir: str):
+ """Collects the state dict and dump to disk."""
+
+ if getattr(trainer.args, "tune_mm_mlp_adapter", False):
+ # Only save Adapter
+ keys_to_match = ['mm_projector']
+ if getattr(trainer.args, "use_im_start_end", False):
+ keys_to_match.extend(['embed_tokens', 'embed_in'])
+
+ weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match)
+ trainer.model.config.save_pretrained(output_dir)
+
+ current_folder = output_dir.split('/')[-1]
+ parent_folder = os.path.dirname(output_dir)
+ if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
+ if current_folder.startswith('checkpoint-'):
+ mm_projector_folder = os.path.join(parent_folder, "mm_projector")
+ os.makedirs(mm_projector_folder, exist_ok=True)
+ torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
+ else:
+ torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
+ return
+
+ if trainer.deepspeed:
+ torch.cuda.synchronize()
+ trainer.save_model(output_dir)
+ return
+
+ state_dict = trainer.model.state_dict()
+ if trainer.args.should_save:
+ cpu_state_dict = {
+ key: value.cpu()
+ for key, value in state_dict.items()
+ }
+ del state_dict
+ trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
+
+
+def smart_tokenizer_and_embedding_resize(
+ special_tokens_dict: Dict,
+ tokenizer: transformers.PreTrainedTokenizer,
+ model: transformers.PreTrainedModel,
+):
+ """Resize tokenizer and embedding.
+
+ Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
+ """
+ num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
+ model.resize_token_embeddings(len(tokenizer))
+
+ if num_new_tokens > 0:
+ input_embeddings = model.get_input_embeddings().weight.data
+ output_embeddings = model.get_output_embeddings().weight.data
+
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
+ dim=0, keepdim=True)
+
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
+
+
+def _tokenize_fn(strings: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer) -> Dict:
+ """Tokenize a list of strings."""
+ tokenized_list = [
+ tokenizer(
+ text,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ) for text in strings
+ ]
+ input_ids = labels = [
+ tokenized.input_ids[0] for tokenized in tokenized_list
+ ]
+ input_ids_lens = labels_lens = [
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
+ for tokenized in tokenized_list
+ ]
+ return dict(
+ input_ids=input_ids,
+ labels=labels,
+ input_ids_lens=input_ids_lens,
+ labels_lens=labels_lens,
+ )
+
+
+def _mask_targets(target, tokenized_lens, speakers):
+ # cur_idx = 0
+ cur_idx = tokenized_lens[0]
+ tokenized_lens = tokenized_lens[1:]
+ target[:cur_idx] = IGNORE_INDEX
+ for tokenized_len, speaker in zip(tokenized_lens, speakers):
+ if speaker == "human":
+ target[cur_idx + 2:cur_idx + tokenized_len] = IGNORE_INDEX
+ cur_idx += tokenized_len
+
+
+def _add_speaker_and_signal(header, source, get_conversation=True):
+ """Add speaker and start/end signal on each round."""
+ BEGIN_SIGNAL = "### "
+ END_SIGNAL = "\n"
+ conversation = header
+ for sentence in source:
+ from_str = sentence["from"]
+ if from_str.lower() == "human":
+ from_str = conversation_lib.default_conversation.roles[0]
+ elif from_str.lower() == "gpt":
+ from_str = conversation_lib.default_conversation.roles[1]
+ else:
+ from_str = 'unknown'
+ sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
+ sentence["value"] + END_SIGNAL)
+ if get_conversation:
+ conversation += sentence["value"]
+ conversation += BEGIN_SIGNAL
+ return conversation
+
+
+def preprocess_multimodal(
+ sources: Sequence[str],
+ data_args: DataArguments
+) -> Dict:
+ is_multimodal = data_args.is_multimodal
+ if not is_multimodal:
+ return sources
+
+ for source in sources:
+ for sentence in source:
+ if DEFAULT_IMAGE_TOKEN in sentence['value']:
+ sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
+ sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
+ sentence['value'] = sentence['value'].strip()
+ if "mmtag" in conversation_lib.default_conversation.version:
+ sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN,
+ '' + DEFAULT_IMAGE_TOKEN + '')
+ replace_token = DEFAULT_IMAGE_TOKEN
+ sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
+
+ return sources
+
+
+def preprocess_llama_2(
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+ has_image: bool = False
+) -> Dict:
+ conv = conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+
+ if has_image:
+ input_ids = torch.stack(
+ [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
+ else:
+ input_ids = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+
+ targets = input_ids.clone()
+
+ assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2
+
+ # Mask targets
+ sep = "[/INST] "
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds = conversation.split(conv.sep2)
+ cur_len = 1
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+
+ if has_image:
+ round_len = len(tokenizer_image_token(rou, tokenizer))
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
+ else:
+ round_len = len(tokenizer(rou).input_ids)
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+
+ target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
+
+ cur_len += round_len
+ target[cur_len:] = IGNORE_INDEX
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_INDEX
+ print(
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
+ f" (ignored)"
+ )
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+
+
+def preprocess_v1(
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+ has_image: bool = False
+) -> Dict:
+ conv = conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+
+ if has_image:
+ input_ids = torch.stack(
+ [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
+ else:
+ input_ids = tokenizer(
+ conversations,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ ).input_ids
+
+ targets = input_ids.clone()
+
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
+
+ # Mask targets
+ sep = conv.sep + conv.roles[1] + ": "
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds = conversation.split(conv.sep2)
+ cur_len = 1
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+
+ if has_image:
+ round_len = len(tokenizer_image_token(rou, tokenizer))
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
+ else:
+ round_len = len(tokenizer(rou).input_ids)
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+
+ target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
+
+ cur_len += round_len
+ target[cur_len:] = IGNORE_INDEX
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_INDEX
+ print(
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
+ f" (ignored)"
+ )
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+
+
+def preprocess_mpt(
+ sources,
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ conv = conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ # Apply prompt templates
+ conversations = []
+ for i, source in enumerate(sources):
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+ assert role == conv.roles[j % 2], f"{i}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ # Tokenize conversations
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations],
+ dim=0)
+ targets = input_ids.clone()
+ assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
+
+ # Mask targets
+ sep = conv.sep + conv.roles[1]
+ for conversation, target in zip(conversations, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds = conversation.split(conv.sep)
+ re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
+ for conv_idx in range(3, len(rounds), 2):
+ re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx + 2])) # user + gpt
+ cur_len = 0
+ target[:cur_len] = IGNORE_INDEX
+ for i, rou in enumerate(re_rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ if len(parts) != 2:
+ break
+ parts[0] += sep
+ round_len = len(tokenizer_image_token(rou, tokenizer)) + len(tokenizer_image_token(conv.sep, tokenizer))
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
+ target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
+
+ cur_len += round_len
+ target[cur_len:] = IGNORE_INDEX
+
+ if cur_len < tokenizer.model_max_length:
+ if cur_len != total_len:
+ target[:] = IGNORE_INDEX
+ print(
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
+ f" (ignored)"
+ )
+
+ return dict(
+ input_ids=input_ids,
+ labels=targets,
+ )
+
+
+def preprocess_plain(
+ sources: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+ # add end signal and concatenate together
+ conversations = []
+ for source in sources:
+ assert len(source) == 2
+ assert DEFAULT_IMAGE_TOKEN in source[0]['value']
+ source[0]['value'] = DEFAULT_IMAGE_TOKEN
+ conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
+ conversations.append(conversation)
+ # tokenize conversations
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
+ targets = copy.deepcopy(input_ids)
+ for target, source in zip(targets, sources):
+ tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
+ target[:tokenized_len] = IGNORE_INDEX
+
+ return dict(input_ids=input_ids, labels=targets)
+
+
+def preprocess(
+ sources: Sequence[str],
+ tokenizer: transformers.PreTrainedTokenizer,
+ has_image: bool = False
+) -> Dict:
+ """
+ Given a list of sources, each is a conversation list. This transform:
+ 1. Add signal '### ' at the beginning each sentence, with end signal '\n';
+ 2. Concatenate conversations together;
+ 3. Tokenize the concatenated conversation;
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
+ """
+ # if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
+ # return preprocess_plain(sources, tokenizer)
+ # if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.LLAMA_2:
+ # return preprocess_llama_2(sources, tokenizer, has_image=has_image)
+ # if conversation_lib.default_conversation.version.startswith("v1"):
+ # return preprocess_v1(sources, tokenizer, has_image=has_image)
+ # if conversation_lib.default_conversation.version == "mpt":
+ # return preprocess_mpt(sources, tokenizer)
+ # add end signal and concatenate together
+ conversations = []
+ for source in sources:
+ header = f"{conversation_lib.default_conversation.system}\n\n"
+ conversation = _add_speaker_and_signal(header, source)
+ conversations.append(conversation)
+
+ # tokenize conversations
+ def get_tokenize_len(prompts):
+ return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
+
+ if has_image:
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
+ else:
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer)
+ input_ids = conversations_tokenized["input_ids"]
+
+ targets = copy.deepcopy(input_ids)
+ for target, source in zip(targets, sources):
+ if has_image:
+ tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
+ else:
+ tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
+ speakers = [sentence["from"] for sentence in source]
+ _mask_targets(target, tokenized_lens, speakers)
+
+ return dict(input_ids=input_ids, labels=targets)
+
+
+class LazySupervisedDataset(Dataset):
+ """Dataset for supervised fine-tuning."""
+
+ def __init__(self, data_path: str,
+ tokenizer: transformers.PreTrainedTokenizer,
+ data_args: DataArguments):
+ super(LazySupervisedDataset, self).__init__()
+ list_data_dict = json.load(open(data_path, "r"))
+
+ rank0_print("Formatting inputs...Skip in lazy mode")
+ self.tokenizer = tokenizer
+ self.list_data_dict = list_data_dict
+ self.data_args = data_args
+
+ def __len__(self):
+ return len(self.list_data_dict)
+
+ @property
+ def lengths(self):
+ length_list = []
+ for sample in self.list_data_dict:
+ img_tokens = 128 if 'image' in sample else 0
+ length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
+ return length_list
+
+ @property
+ def modality_lengths(self):
+ length_list = []
+ for sample in self.list_data_dict:
+ cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
+ cur_len = cur_len if 'image' in sample else -cur_len
+ length_list.append(cur_len)
+ return length_list
+
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+ sources = self.list_data_dict[i]
+ if isinstance(i, int):
+ sources = [sources]
+ assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
+ if 'image' in sources[0]:
+ image_file = self.list_data_dict[i]['image']
+ image_folder = self.data_args.image_folder
+ processor = self.data_args.image_processor
+ image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
+ if self.data_args.image_aspect_ratio == 'pad':
+ def expand2square(pil_img, background_color):
+ width, height = pil_img.size
+ if width == height:
+ return pil_img
+ elif width > height:
+ result = Image.new(pil_img.mode, (width, width), background_color)
+ result.paste(pil_img, (0, (width - height) // 2))
+ return result
+ else:
+ result = Image.new(pil_img.mode, (height, height), background_color)
+ result.paste(pil_img, ((height - width) // 2, 0))
+ return result
+
+ image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
+ image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+ else:
+ image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+ sources = preprocess_multimodal(
+ copy.deepcopy([e["conversations"] for e in sources]),
+ self.data_args)
+ else:
+ sources = copy.deepcopy([e["conversations"] for e in sources])
+ data_dict = preprocess(
+ sources,
+ self.tokenizer,
+ has_image=('image' in self.list_data_dict[i]))
+ if isinstance(i, int):
+ data_dict = dict(input_ids=data_dict["input_ids"][0],
+ labels=data_dict["labels"][0])
+
+ # image exist in the data
+ if 'image' in self.list_data_dict[i]:
+ data_dict['image'] = image
+ elif self.data_args.is_multimodal:
+ # image does not exist in the data, but the model is multimodal
+ crop_size = self.data_args.image_processor.crop_size
+ data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
+ return data_dict
+
+
+@dataclass
+class DataCollatorForSupervisedDataset(object):
+ """Collate examples for supervised fine-tuning."""
+
+ tokenizer: transformers.PreTrainedTokenizer
+
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+ input_ids, labels = tuple([instance[key] for instance in instances]
+ for key in ("input_ids", "labels"))
+ input_ids = torch.nn.utils.rnn.pad_sequence(
+ input_ids,
+ batch_first=True,
+ padding_value=self.tokenizer.pad_token_id)
+ labels = torch.nn.utils.rnn.pad_sequence(labels,
+ batch_first=True,
+ padding_value=IGNORE_INDEX)
+ input_ids = input_ids[:, :self.tokenizer.model_max_length]
+ labels = labels[:, :self.tokenizer.model_max_length]
+ batch = dict(
+ input_ids=input_ids,
+ labels=labels,
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+ )
+
+ if 'image' in instances[0]:
+ images = [instance['image'] for instance in instances]
+ if all(x is not None and x.shape == images[0].shape for x in images):
+ batch['images'] = torch.stack(images)
+ else:
+ batch['images'] = images
+
+ return batch
+
+
+def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
+ data_args) -> Dict:
+ """Make dataset and collator for supervised fine-tuning."""
+ train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
+ data_path=data_args.data_path,
+ data_args=data_args)
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
+ return dict(train_dataset=train_dataset,
+ eval_dataset=None,
+ data_collator=data_collator)
+
+
+def train():
+ global local_rank
+
+ parser = transformers.HfArgumentParser(
+ (ModelArguments, DataArguments, TrainingArguments))
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+ local_rank = training_args.local_rank
+ compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
+
+ bnb_model_from_pretrained_args = {}
+ if training_args.bits in [4, 8]:
+ from transformers import BitsAndBytesConfig
+ bnb_model_from_pretrained_args.update(dict(
+ device_map={"": training_args.device},
+ load_in_4bit=training_args.bits == 4,
+ load_in_8bit=training_args.bits == 8,
+ quantization_config=BitsAndBytesConfig(
+ load_in_4bit=training_args.bits == 4,
+ load_in_8bit=training_args.bits == 8,
+ llm_int8_skip_modules=["mm_projector"],
+ llm_int8_threshold=6.0,
+ llm_int8_has_fp16_weight=False,
+ bnb_4bit_compute_dtype=compute_dtype,
+ bnb_4bit_use_double_quant=training_args.double_quant,
+ bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
+ )
+ ))
+
+ model = LlavaLlamaForCausalLM.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ **bnb_model_from_pretrained_args
+ )
+
+ model.config.use_cache = False
+
+ if model_args.freeze_backbone:
+ model.model.requires_grad_(False)
+
+ if training_args.bits in [4, 8]:
+ from peft import prepare_model_for_kbit_training
+ model.config.torch_dtype = (
+ torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
+ model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
+
+ if training_args.gradient_checkpointing:
+ if hasattr(model, "enable_input_require_grads"):
+ model.enable_input_require_grads()
+ else:
+ def make_inputs_require_grad(module, input, output):
+ output.requires_grad_(True)
+
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+ if training_args.lora_enable:
+ from peft import LoraConfig, get_peft_model
+ lora_config = LoraConfig(
+ r=training_args.lora_r,
+ lora_alpha=training_args.lora_alpha,
+ target_modules=find_all_linear_names(model),
+ lora_dropout=training_args.lora_dropout,
+ bias=training_args.lora_bias,
+ task_type="CAUSAL_LM",
+ )
+ if training_args.bits == 16:
+ if training_args.bf16:
+ model.to(torch.bfloat16)
+ if training_args.fp16:
+ model.to(torch.float16)
+ rank0_print("Adding LoRA adapters...")
+ model = get_peft_model(model, lora_config)
+
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ cache_dir=training_args.cache_dir,
+ model_max_length=training_args.model_max_length,
+ padding_side="right",
+ use_fast=False,
+ )
+
+ if model_args.vision_tower is not None:
+ model.get_model().initialize_vision_modules(
+ model_args=model_args
+ )
+
+ vision_tower = model.get_vision_tower()
+ vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
+
+ data_args.image_processor = vision_tower.image_processor
+ data_args.is_multimodal = True
+
+ model.config.image_aspect_ratio = data_args.image_aspect_ratio
+ model.config.tokenizer_padding_side = tokenizer.padding_side
+ model.config.tokenizer_model_max_length = tokenizer.model_max_length
+
+ model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
+ if model_args.tune_mm_mlp_adapter:
+ model.requires_grad_(False)
+ for p in model.get_model().mm_projector.parameters():
+ p.requires_grad = True
+
+ model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter
+ if training_args.freeze_mm_mlp_adapter:
+ for p in model.get_model().mm_projector.parameters():
+ p.requires_grad = False
+
+ if training_args.bits in [4, 8]:
+ model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device)
+
+ model.config.mm_projector_lr = training_args.mm_projector_lr
+ model.config.mm_use_im_patch_token = model_args.mm_use_im_patch_token
+ model.initialize_vision_tokenizer(model_args, tokenizer=tokenizer)
+
+ if training_args.bits in [4, 8]:
+ from peft.tuners.lora import LoraLayer
+ for name, module in model.named_modules():
+ if isinstance(module, LoraLayer):
+ if training_args.bf16:
+ module = module.to(torch.bfloat16)
+ if 'norm' in name:
+ module = module.to(torch.float32)
+ if 'lm_head' in name or 'embed_tokens' in name:
+ if hasattr(module, 'weight'):
+ if training_args.bf16 and module.weight.dtype == torch.float32:
+ module = module.to(torch.bfloat16)
+
+ data_module = make_supervised_data_module(tokenizer=tokenizer,
+ data_args=data_args)
+ trainer = LLaVATrainer(model=model,
+ tokenizer=tokenizer,
+ args=training_args,
+ **data_module)
+
+ if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
+ trainer.train(resume_from_checkpoint=True)
+ else:
+ trainer.train()
+ trainer.save_state()
+
+ model.config.use_cache = True
+
+ if training_args.lora_enable:
+ state_dict = get_peft_state_maybe_zero_3(
+ model.named_parameters(), training_args.lora_bias
+ )
+ non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
+ model.named_parameters()
+ )
+ if training_args.local_rank == 0 or training_args.local_rank == -1:
+ model.config.save_pretrained(training_args.output_dir)
+ model.save_pretrained(training_args.output_dir, state_dict=state_dict)
+ torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
+ else:
+ safe_save_model_for_hf_trainer(trainer=trainer,
+ output_dir=training_args.output_dir)
+
+
+if __name__ == "__main__":
+ train()
diff --git a/VL/llava/train/train_mem.py b/VL/llava/train/train_mem.py
new file mode 100644
index 00000000..2487d317
--- /dev/null
+++ b/VL/llava/train/train_mem.py
@@ -0,0 +1,13 @@
+# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
+# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
+# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
+
+# Need to call this before importing transformers.
+from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
+
+replace_llama_attn_with_flash_attn()
+
+from llava.train.train import train
+
+if __name__ == "__main__":
+ train()
diff --git a/VL/llava/train/train_xformers.py b/VL/llava/train/train_xformers.py
new file mode 100644
index 00000000..c0ef8212
--- /dev/null
+++ b/VL/llava/train/train_xformers.py
@@ -0,0 +1,13 @@
+# Make it more memory efficient by monkey patching the LLaMA model with xformers attention.
+
+# Need to call this before importing transformers.
+from llava.train.llama_xformers_attn_monkey_patch import (
+ replace_llama_attn_with_xformers_attn,
+)
+
+replace_llama_attn_with_xformers_attn()
+
+from llava.train.train import train
+
+if __name__ == "__main__":
+ train()
\ No newline at end of file
diff --git a/VL/requirements.txt b/VL/requirements.txt
index 33601b7e..0aa33f1c 100644
--- a/VL/requirements.txt
+++ b/VL/requirements.txt
@@ -1,9 +1,13 @@
-transformers>=4.36.2
-gradio>=4.13.0
+transformers==4.34.0
+gradio
protobuf>=4.25.1
torch>=2.0.1
torchvision
accelerate
sentencepiece
deepspeed
-datasets
\ No newline at end of file
+datasets
+flash-attn
+bitsandbytes
+peft
+wandb
\ No newline at end of file
diff --git a/VL/scripts/finetune.sh b/VL/scripts/finetune.sh
new file mode 100644
index 00000000..bdd84019
--- /dev/null
+++ b/VL/scripts/finetune.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+
+PYTHONPATH=../../:$PYTHONPATH \
+deepspeed --include localhost:6,7 --master_port 1234 llava/train/train_mem.py \
+ --deepspeed ./scripts/zero2.json \
+ --model_name_or_path /path/to/Yi-VL-model \
+ --data_path /path/to/dataset \
+ --image_folder /path/to/image/folder \
+ --vision_tower /path/to/vit/model \
+ --output_dir /path/to/output/model \
+ --mm_vision_select_layer -2 \
+ --mm_use_im_start_end False \
+ --mm_use_im_patch_token False \
+ --bf16 True \
+ --num_train_epochs 10 \
+ --per_device_train_batch_size 1 \
+ --per_device_eval_batch_size 1 \
+ --gradient_accumulation_steps 8 \
+ --evaluation_strategy "no" \
+ --save_strategy "steps" \
+ --save_steps 200 \
+ --save_total_limit 3 \
+ --learning_rate 2e-5 \
+ --weight_decay 0. \
+ --warmup_ratio 0.03 \
+ --lr_scheduler_type "cosine" \
+ --logging_steps 1 \
+ --tf32 True \
+ --model_max_length 2048 \
+ --gradient_checkpointing True \
+ --lazy_preprocess True \
+ --dataloader_num_workers 4 \
+ --report_to wandb
\ No newline at end of file
diff --git a/VL/scripts/finetune_lora.sh b/VL/scripts/finetune_lora.sh
new file mode 100644
index 00000000..16a3bed5
--- /dev/null
+++ b/VL/scripts/finetune_lora.sh
@@ -0,0 +1,34 @@
+#!/bin/bash
+
+PYTHONPATH=../../:$PYTHONPATH \
+deepspeed --include localhost:6,7 --master_port 1234 llava/train/train_mem.py \
+ --deepspeed ./scripts/zero2.json \
+ --lora_enable True \
+ --model_name_or_path /path/to/Yi-VL-model \
+ --data_path /path/to/dataset \
+ --image_folder /path/to/image/folder \
+ --vision_tower /path/to/vit/model \
+ --output_dir /path/to/output \
+ --mm_vision_select_layer -2 \
+ --mm_use_im_start_end False \
+ --mm_use_im_patch_token False \
+ --bf16 True \
+ --num_train_epochs 10 \
+ --per_device_train_batch_size 4 \
+ --per_device_eval_batch_size 4 \
+ --gradient_accumulation_steps 8 \
+ --evaluation_strategy "no" \
+ --save_strategy "steps" \
+ --save_steps 200 \
+ --save_total_limit 3 \
+ --learning_rate 2e-5 \
+ --weight_decay 0. \
+ --warmup_ratio 0.03 \
+ --lr_scheduler_type "cosine" \
+ --logging_steps 1 \
+ --tf32 True \
+ --model_max_length 2048 \
+ --gradient_checkpointing True \
+ --lazy_preprocess True \
+ --dataloader_num_workers 4 \
+ --report_to wandb
\ No newline at end of file
diff --git a/VL/scripts/finetune_qlora.sh b/VL/scripts/finetune_qlora.sh
new file mode 100644
index 00000000..95ae40c8
--- /dev/null
+++ b/VL/scripts/finetune_qlora.sh
@@ -0,0 +1,35 @@
+#!/bin/bash
+
+PYTHONPATH=../../:$PYTHONPATH \
+deepspeed --include localhost:0,1,2,3 --master_port 1234 llava/train/train_mem.py \
+ --deepspeed ./scripts/zero2.json \
+ --lora_enable True \
+ --bits 4 \
+ --model_name_or_path /path/to/Yi-VL-model \
+ --data_path /path/to/dataset \
+ --image_folder /path/to/image/folder \
+ --vision_tower /path/to/vit/model \
+ --output_dir /path/to/output \
+ --mm_vision_select_layer -2 \
+ --mm_use_im_start_end False \
+ --mm_use_im_patch_token False \
+ --bf16 True \
+ --num_train_epochs 10 \
+ --per_device_train_batch_size 4 \
+ --per_device_eval_batch_size 4 \
+ --gradient_accumulation_steps 8 \
+ --evaluation_strategy "no" \
+ --save_strategy "steps" \
+ --save_steps 200 \
+ --save_total_limit 3 \
+ --learning_rate 2e-5 \
+ --weight_decay 0. \
+ --warmup_ratio 0.03 \
+ --lr_scheduler_type "cosine" \
+ --logging_steps 1 \
+ --tf32 True \
+ --model_max_length 2048 \
+ --gradient_checkpointing True \
+ --lazy_preprocess True \
+ --dataloader_num_workers 4 \
+ --report_to wandb
diff --git a/VL/scripts/merge_lora_weights.py b/VL/scripts/merge_lora_weights.py
new file mode 100644
index 00000000..594ba558
--- /dev/null
+++ b/VL/scripts/merge_lora_weights.py
@@ -0,0 +1,22 @@
+import argparse
+from llava.mm_utils import load_pretrained_model
+
+
+def merge_lora(args):
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path,
+ lora_path=args.lora_path,
+ device_map='cpu')
+
+ model.save_pretrained(args.save_model_path)
+ tokenizer.save_pretrained(args.save_model_path)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model-path", type=str, required=True)
+ parser.add_argument("--lora-path", type=str, required=True)
+ parser.add_argument("--save-model-path", type=str, required=True)
+
+ args = parser.parse_args()
+
+ merge_lora(args)
diff --git a/VL/scripts/zero2.json b/VL/scripts/zero2.json
new file mode 100644
index 00000000..c95ebefe
--- /dev/null
+++ b/VL/scripts/zero2.json
@@ -0,0 +1,23 @@
+{
+ "fp16": {
+ "enabled": "auto",
+ "loss_scale": 0,
+ "loss_scale_window": 1000,
+ "initial_scale_power": 16,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "bf16": {
+ "enabled": "auto"
+ },
+ "train_micro_batch_size_per_gpu": "auto",
+ "train_batch_size": "auto",
+ "gradient_accumulation_steps": "auto",
+ "zero_optimization": {
+ "stage": 2,
+ "overlap_comm": true,
+ "contiguous_gradients": true,
+ "sub_group_size": 1e9,
+ "reduce_bucket_size": "auto"
+ }
+}
\ No newline at end of file
diff --git a/VL/scripts/zero3.json b/VL/scripts/zero3.json
new file mode 100644
index 00000000..6917317a
--- /dev/null
+++ b/VL/scripts/zero3.json
@@ -0,0 +1,28 @@
+{
+ "fp16": {
+ "enabled": "auto",
+ "loss_scale": 0,
+ "loss_scale_window": 1000,
+ "initial_scale_power": 16,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "bf16": {
+ "enabled": "auto"
+ },
+ "train_micro_batch_size_per_gpu": "auto",
+ "train_batch_size": "auto",
+ "gradient_accumulation_steps": "auto",
+ "zero_optimization": {
+ "stage": 3,
+ "overlap_comm": true,
+ "contiguous_gradients": true,
+ "sub_group_size": 1e9,
+ "reduce_bucket_size": "auto",
+ "stage3_prefetch_bucket_size": "auto",
+ "stage3_param_persistence_threshold": "auto",
+ "stage3_max_live_parameters": 1e9,
+ "stage3_max_reuse_distance": 1e9,
+ "stage3_gather_16bit_weights_on_model_save": true
+ }
+}
\ No newline at end of file
diff --git a/VL/scripts/zero3_offload.json b/VL/scripts/zero3_offload.json
new file mode 100644
index 00000000..e0a54c2c
--- /dev/null
+++ b/VL/scripts/zero3_offload.json
@@ -0,0 +1,56 @@
+{
+ "fp16": {
+ "enabled": "auto",
+ "loss_scale": 0,
+ "loss_scale_window": 1000,
+ "initial_scale_power": 16,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "bf16": {
+ "enabled": "auto"
+ },
+ "optimizer": {
+ "type": "AdamW",
+ "params": {
+ "lr": "auto",
+ "betas": "auto",
+ "eps": "auto",
+ "weight_decay": "auto"
+ }
+ },
+ "scheduler": {
+ "type": "WarmupLR",
+ "params": {
+ "warmup_min_lr": "auto",
+ "warmup_max_lr": "auto",
+ "warmup_num_steps": "auto"
+ }
+ },
+ "zero_optimization": {
+ "stage": 3,
+ "offload_optimizer": {
+ "device": "cpu",
+ "pin_memory": true
+ },
+ "offload_param": {
+ "device": "cpu",
+ "pin_memory": true
+ },
+ "overlap_comm": true,
+ "contiguous_gradients": true,
+ "sub_group_size": 1e9,
+ "reduce_bucket_size": "auto",
+ "stage3_prefetch_bucket_size": "auto",
+ "stage3_param_persistence_threshold": "auto",
+ "stage3_max_live_parameters": 1e9,
+ "stage3_max_reuse_distance": 1e9,
+ "gather_16bit_weights_on_model_save": true
+ },
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "steps_per_print": 1e5,
+ "wall_clock_breakdown": false
+}
\ No newline at end of file