- 🔥2024.5.30: We release the code!
Research and development on preference learning for Vision-Language Models (VLM, LVLM or MLLM) is difficult because there is currently no unified model architecture in the VLM community. The implementations of state-of-the-art VLMs such as LLaVA, Qwen-VL and InternLM-XComposer also follow different styles, which makes it hard to include them in a single training framework. VL-RLHF abstracts these VLMs behind a common framework. Its features include:
- Supports popular VLMs: LLaVA-1.5, InstructBLIP, LLaVA-Next, Qwen-VL, InternLM-XComposer2, etc.
- Supports popular fine-tuning methods: SFT, DPO, KTO, etc.
- Evaluation on popular benchmarks: MME, MMBench, SEEDBench, MMVet, MMMU, etc.
- Easy to extend: customize your own dataset and model with little code.

Training algorithms (supported or planned):

- Supervised Fine-tuning (SFT)
- Direct Preference Optimization (DPO)
- Dense Direct Preference Optimization (DDPO)
- Proximal Policy Optimization (PPO) (coming soon)
- Kahneman-Tversky Optimization (KTO) (coming soon)
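
For orientation on the DPO entry above, the sketch below computes the standard per-pair DPO loss: the policy's log-probability margin between the chosen and rejected responses, measured relative to a frozen reference model and scaled by β. In the training command later in this README, `--beta 0.1` presumably sets this β. This is a plain-Python illustration with made-up log-probabilities, not VL-RLHF's trainer code; `neg_log_sigmoid` and `dpo_loss` are hypothetical names.

```python
import math


def neg_log_sigmoid(z: float) -> float:
    """Numerically stable -log(sigmoid(z)) = log(1 + exp(-z))."""
    if z >= 0:
        return math.log1p(math.exp(-z))
    # For negative z, factor out exp(-z) so math.exp never overflows.
    return -z + math.log1p(math.exp(z))


def dpo_loss(policy_chosen_logp: float, policy_rejected_logp: float,
             ref_chosen_logp: float, ref_rejected_logp: float,
             beta: float = 0.1) -> float:
    """Standard DPO loss for one preference pair.

    Each argument is the summed log-probability of a full response given the
    prompt (and image), under either the trained policy or the frozen reference.
    """
    # Log-ratio of policy to reference: the implicit reward of each response.
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    # Loss is -log sigmoid(beta * (reward margin of chosen over rejected)).
    return neg_log_sigmoid(beta * (chosen_logratio - rejected_logratio))


# Policy identical to the reference: the loss is log 2 (about 0.693).
print(dpo_loss(-10.0, -10.0, -10.0, -10.0))
# Policy has shifted probability toward the chosen answer: the loss drops below log 2.
print(dpo_loss(-5.0, -15.0, -10.0, -10.0, beta=0.1))
```

A larger β makes the loss react more sharply to the same log-ratio margin, which is why it acts as a knob on how far the policy is pushed away from the reference model.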
To install from source code (convenient for running the training and evaluation scripts), please run the following commands:
```bash
git clone https://github.com/TideDra/VL-RLHF.git
cd VL-RLHF
pip install -e .
```
We recommend installing FlashAttention for faster and more memory-efficient training and inference:
```bash
pip install flash-attn==2.5.8 --no-build-isolation
```
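
The example DPO command in this README passes `--use_flash_attention_2 False`. Before switching that flag to `True`, a quick check that the package is importable can help. The sketch below assumes only that the `flash-attn` wheel installs a top-level module named `flash_attn` (its import name); it does not verify that your GPU supports FlashAttention. `flash_attn_available` is a hypothetical helper.

```python
import importlib.util


def flash_attn_available() -> bool:
    """Return True if the `flash_attn` package can be imported in this environment."""
    # find_spec only locates the package; it does not import it or touch the GPU.
    return importlib.util.find_spec("flash_attn") is not None


if __name__ == "__main__":
    # Suggest a value for the training scripts' --use_flash_attention_2 flag.
    print(f"--use_flash_attention_2 {flash_attn_available()}")
```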
You can run the following command to launch DPO training of Qwen-VL-Chat on the VLFeedback dataset:
```bash
# model weights should exist in ckpts/Qwen-VL-Chat
bash scripts/dpo_qwenvl.sh
```
Or run the Python script directly:
```bash
accelerate launch --config_file accelerate_config/zero2.yaml --num_processes 8 \
src/vlrlhf/dpo.py \
--model_name_or_path ckpts/Qwen-VL-Chat \
--output_dir ckpts/Qwen-VL-Chat-dpo/ \
--dataset_name VLFeedback \
--data_ratio 1.0 \
--freeze_vision_tower True \
--use_flash_attention_2 False \
--use_lora True \
--lora_r 64 \
--lora_alpha 16 \
--lora_dropout 0.05 \
--lora_target_modules auto \
--lora_bias "none" \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 8 \
--num_train_epochs 1 \
--adam_beta1 0.9 \
--adam_beta2 0.98 \
--adam_epsilon 1e-6 \
--learning_rate 1e-5 \
--weight_decay 0.05 \
--warmup_ratio 0.1 \
--lr_scheduler_type "cosine" \
--gradient_checkpointing True \
--bf16 True \
--tf32 True \
--remove_unused_columns False \
--beta 0.1 \
--max_length 1024 \
--max_prompt_length 512 \
--max_target_length 512 \
--eval_strategy "steps" \
--eval_steps 200 \
--save_strategy "steps" \
--save_steps 100 \
--save_total_limit 10 \
--logging_first_step False \
--logging_steps 10 \
--report_to wandb \
--run_name "bs256_lr1e-5" \
--project_name "VL-RLHF" \
--group_name "Qwen-VL-Chat-dpo"
```
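
The run name `bs256_lr1e-5` follows from the flags: with data-parallel training, each optimizer step sees `--num_processes` × `--per_device_train_batch_size` × `--gradient_accumulation_steps` = 8 × 4 × 8 = 256 samples. The small sketch below makes that arithmetic explicit, along with the LoRA scaling factor `lora_alpha / lora_r`, which is the standard LoRA convention and the default in Hugging Face PEFT. Whether VL-RLHF alters either convention is an assumption not confirmed here; the function names are hypothetical.

```python
def effective_batch_size(num_processes: int, per_device_batch: int,
                         grad_accum_steps: int) -> int:
    """Samples contributing to each optimizer step under data parallelism."""
    return num_processes * per_device_batch * grad_accum_steps


def lora_scaling(lora_alpha: float, lora_r: int) -> float:
    """Scale applied to the low-rank update B @ A in standard LoRA."""
    return lora_alpha / lora_r


# Values taken from the DPO command above.
bs = effective_batch_size(num_processes=8, per_device_batch=4, grad_accum_steps=8)
scale = lora_scaling(lora_alpha=16, lora_r=64)
print(bs, scale)  # 256 0.25
```

If you train on fewer GPUs, raising `--gradient_accumulation_steps` proportionally (e.g. 4 processes with 16 accumulation steps) keeps the effective batch size at 256.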
To train other models or use other methods, refer to the corresponding scripts in the `scripts/` directory.
Please refer to arguments.md for a detailed explanation of each argument used in the scripts.
VL-RLHF uses three arguments to process the given dataset; they appear in all the example training scripts. Please make sure they are set properly in the script before running it:
- `--dataset_name`: The name of the dataset. It can be `vlfeedback_paired` for the VLFeedback dataset, `rlhfv` for the RLHF-V dataset, `vlquery_json` for customized multimodal conversation data stored in JSON format, or `plain_dpo` for customized multimodal comparison data stored in JSON format.
- `--data_path`: Path to the JSON file. Only needed for customized datasets. If you use VLFeedback or RLHF-V, these datasets will be automatically downloaded from Hugging Face and loaded via the `datasets` package.
- `--image_root`: Root directory of the images. Only needed for customized datasets. It will be joined with the image path of each sample in the JSON file.
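
To illustrate how `--data_path` and `--image_root` fit together, the sketch below loads a customized JSON dataset and joins each sample's relative `image` path onto the image root, then lists any images that do not exist on disk. It mirrors the behaviour described above but is not VL-RLHF's actual loader; `resolve_image_paths`, `load_custom_dataset` and `missing_images` are hypothetical helpers, and the missing-file check is an extra sanity step of our own.

```python
import json
import os
from typing import Any


def resolve_image_paths(samples: list[dict[str, Any]],
                        image_root: str) -> list[dict[str, Any]]:
    """Return copies of `samples` whose "image" field is joined onto `image_root`."""
    # Copy each record so the caller's list is left unchanged.
    return [{**s, "image": os.path.join(image_root, s["image"])} for s in samples]


def load_custom_dataset(data_path: str, image_root: str) -> list[dict[str, Any]]:
    """Read a JSON dataset (--data_path) and resolve its image paths (--image_root)."""
    with open(data_path, encoding="utf-8") as f:
        samples = json.load(f)
    return resolve_image_paths(samples, image_root)


def missing_images(samples: list[dict[str, Any]]) -> list[str]:
    """List resolved image paths that are not existing files."""
    return [s["image"] for s in samples if not os.path.isfile(s["image"])]
```

Running `missing_images(load_custom_dataset("data.json", "images/"))` before launching a job is a cheap way to catch a wrong `--image_root`.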
For methods that need comparison data, e.g. DPO, DDPO and KTO (paired), please prepare your JSON data in the following format:
```json
[
  {
    "image": "example.jpg",
    "prompt": "Describe this image in detail.",
    "chosen": "This is a cat.",
    "rejected": "This is a dog."
  },
  ...
]
```
Then set `--dataset_name` to `plain_dpo` in the training command.
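
A minimal sketch for producing a `plain_dpo` file from Python, assuming your pairs are already in memory: it checks that each record has the four fields shown above and serializes the list to JSON text you can save as `--data_path`. `validate_plain_dpo` is a hypothetical helper, and rejecting identical `chosen`/`rejected` answers is our own sanity rule (such a pair carries no preference signal), not a documented VL-RLHF requirement.

```python
import json

REQUIRED_KEYS = ("image", "prompt", "chosen", "rejected")


def validate_plain_dpo(records: list) -> None:
    """Raise ValueError if a record lacks a plain_dpo field or has no preference."""
    for i, rec in enumerate(records):
        missing = [k for k in REQUIRED_KEYS if k not in rec]
        if missing:
            raise ValueError(f"record {i} is missing {missing}")
        if rec["chosen"] == rec["rejected"]:
            raise ValueError(f"record {i}: chosen and rejected are identical")


records = [
    {"image": "example.jpg", "prompt": "Describe this image in detail.",
     "chosen": "This is a cat.", "rejected": "This is a dog."},
]
validate_plain_dpo(records)
# Write `text` to the file you pass as --data_path.
text = json.dumps(records, ensure_ascii=False, indent=2)
print(text)
```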
For SFT, please prepare your conversation data in the following format:
```json
[
  {
    "image": "example.jpg",
    "conversations": [
      {
        "from": "user",
        "value": "<prompt>"
      },
      {
        "from": "assistant",
        "value": "<answer>"
      },
      ...
    ]
  },
  ...
]
```
Then set `--dataset_name` to `vlquery_json` in the training command.
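
If your SFT data starts out as (prompt, answer) pairs, the sketch below turns them into one record of the conversation format above, alternating `user` and `assistant` turns. `to_vlquery_record` is a hypothetical helper; it does not insert any image placeholder token into the prompt, since the format shown above does not specify one.

```python
import json


def to_vlquery_record(image: str, turns: list[tuple[str, str]]) -> dict:
    """Build one SFT record in the vlquery_json conversation format.

    `turns` is a list of (prompt, answer) pairs; each becomes a user message
    followed by an assistant message.
    """
    conversations = []
    for prompt, answer in turns:
        conversations.append({"from": "user", "value": prompt})
        conversations.append({"from": "assistant", "value": answer})
    return {"image": image, "conversations": conversations}


record = to_vlquery_record(
    "example.jpg",
    [("Describe this image in detail.", "This is a cat."),
     ("What color is the cat?", "It is orange.")],
)
print(json.dumps([record], ensure_ascii=False, indent=2))
```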
You can add your own model to the VL-RLHF framework by implementing a few APIs. Please refer to CustomizedModel.md for details.
VL-RLHF supports evaluating VLMs on popular multimodal benchmarks such as MME, MMVet, SEEDBench and MMBench. Please refer to the Evaluation Guide for details.
For reference, we report the performance of several models before and after DPO training on VLFeedback.
Model | MMBench | MMVet | SEEDBench-Img | MMMU | MathVista |
---|---|---|---|---|---|
InternLM-Xcomposer2-VL-7b | 76.37 | 46.5 | 74.19 | 40.33 | 56.7 |
InternLM-Xcomposer2-VL-7b-DPO | 78.18 | 49.7 | 75.18 | 39.67 | 56.6 |
Qwen-VL-Chat | 56.53 | 48.5 | 59.63 | 35.67 | 35.6 |
Qwen-VL-Chat-DPO | 57.56 | 49.1 | 60.67 | 37.89 | 35.6 |
LLaVA-Next-Mistral-7b | 67.70 | 43.8 | 71.7 | 37.00 | 35.1 |
LLaVA-Next-Mistral-7b-DPO | 68.30 | 44.2 | 71.7 | 36.89 | 36.2 |
LLaVA-Next-Vicuna-7b | 62.71 | 38.2 | 68.17 | 34.00 | 31.3 |
LLaVA-Next-Vicuna-7b-DPO | 64.52 | 44.1 | 69.75 | 33.11 | 32.0 |
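
To read the table as changes rather than absolute scores, the short script below subtracts each base score from its DPO counterpart, using values copied verbatim from the table; positive numbers mean DPO improved that benchmark. It computes nothing beyond the table itself.

```python
# Scores copied from the table above: (base, DPO) per model.
BENCHMARKS = ("MMBench", "MMVet", "SEEDBench-Img", "MMMU", "MathVista")
RESULTS = {
    "InternLM-Xcomposer2-VL-7b": ((76.37, 46.5, 74.19, 40.33, 56.7),
                                  (78.18, 49.7, 75.18, 39.67, 56.6)),
    "Qwen-VL-Chat": ((56.53, 48.5, 59.63, 35.67, 35.6),
                     (57.56, 49.1, 60.67, 37.89, 35.6)),
    "LLaVA-Next-Mistral-7b": ((67.70, 43.8, 71.7, 37.00, 35.1),
                              (68.30, 44.2, 71.7, 36.89, 36.2)),
    "LLaVA-Next-Vicuna-7b": ((62.71, 38.2, 68.17, 34.00, 31.3),
                             (64.52, 44.1, 69.75, 33.11, 32.0)),
}


def dpo_deltas(base, dpo):
    """Per-benchmark score change after DPO, rounded to two decimals."""
    return {name: round(after - before, 2)
            for name, before, after in zip(BENCHMARKS, base, dpo)}


for model, (base, dpo) in RESULTS.items():
    print(model, dpo_deltas(base, dpo))
```

Viewed this way, MMVet improves for all four models, while MMMU dips slightly for three of them.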
We would like to express our gratitude to the following projects:
If you find this work helpful, please consider starring 🌟 this repo. Thanks for your support!
If you use VL-RLHF in your research, please use the following BibTeX entry.
```bibtex
@misc{vlrlhf,
  title = {VL-RLHF: A RLHF Infrastructure for Vision-Language Model},
  author = {Gongrui Zhang},
  howpublished = {\url{https://github.com/TideDra/VL-RLHF}},
  year = {2024}
}
```