Binary file added docs/sphinx_doc/assets/mix_vlm_reward.png
2 changes: 1 addition & 1 deletion examples/grpo_vlm/vlm.yaml
@@ -21,7 +21,7 @@ buffer:
     taskset:
       name: geometry3k
       storage_type: file
-      path: hiyouga/geometry3k
+      path: ${oc.env:TRINITY_TASKSET_PATH,hiyouga/geometry3k}
       subset_name: 'default'
       split: 'train'
       format:
2 changes: 1 addition & 1 deletion examples/mix_chord/mix_chord.yaml
@@ -62,7 +62,7 @@ buffer:
       name: SFT_data
       storage_type: file
       schema_type: sft
-      path: ${oc.env:TRINITY_SFT_DATASET_PATH,open-r1/Mixture-of-Thoughts}
+      path: ${oc.env:TRINITY_SFT_DATASET_PATH}
       split: 'train'
       format:
         prompt_type: messages
2 changes: 1 addition & 1 deletion examples/mix_math/mix_math.yaml
@@ -61,7 +61,7 @@ buffer:
       name: math_sft
       storage_type: file
       schema_type: sft
-      path: ${oc.env:TRINITY_SFT_DATASET_PATH,open-r1/Mixture-of-Thoughts}
+      path: ${oc.env:TRINITY_SFT_DATASET_PATH}
       split: 'train'
       format:
         prompt_type: messages
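All three path changes rely on OmegaConf's built-in `oc.env` resolver: `${oc.env:VAR,default}` falls back to the default when the environment variable is unset, while `${oc.env:VAR}` requires the variable to be set. The sketch below is illustrative only (the local paths are made up) and is not part of the changed files.

```python
# Minimal sketch (not part of the PR) of the OmegaConf `oc.env` resolver
# pattern used by these config changes. Paths here are placeholders.
import os
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    # With a default: falls back to the HF dataset id when the variable is unset.
    "taskset_path": "${oc.env:TRINITY_TASKSET_PATH,hiyouga/geometry3k}",
    # Without a default: accessing this key fails unless the variable is set.
    "sft_path": "${oc.env:TRINITY_SFT_DATASET_PATH}",
})

print(cfg.taskset_path)  # -> hiyouga/geometry3k (fallback default)

os.environ["TRINITY_TASKSET_PATH"] = "/data/local/geometry3k"
print(cfg.taskset_path)  # -> /data/local/geometry3k (oc.env is resolved on access, not cached)

os.environ["TRINITY_SFT_DATASET_PATH"] = "/data/local/sft_dataset"
print(cfg.sft_path)      # -> /data/local/sft_dataset
```

In effect, `TRINITY_SFT_DATASET_PATH` becomes mandatory for the mix_chord and mix_math examples after this change, while the geometry3k taskset path keeps its Hugging Face default.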
36 changes: 36 additions & 0 deletions examples/mix_vlm/README.md
@@ -0,0 +1,36 @@
# MIX algorithm with VLM

This is an example of using the [MIX](../../docs/sphinx_doc/source/tutorial/example_mix_algo.md) algorithm with the Qwen2.5-VL-3B-Instruct model.

> [!NOTE]
> This feature is experimental and may change in future releases.

This example requires the following package versions:

```text
vllm>=0.9.1,<0.10.0
transformers<4.53.0
qwen_vl_utils
```
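Before launching, it can be worth verifying these pins in the active environment. The snippet below is a convenience check, not part of the example; it assumes only the standard-library `importlib.metadata` and the `packaging` package.

```python
# Optional sanity check (not part of the example): verify the version pins
# listed above in the current environment before starting a run.
from importlib.metadata import version
from packaging.version import Version

assert Version("0.9.1") <= Version(version("vllm")) < Version("0.10.0"), "need vllm>=0.9.1,<0.10.0"
assert Version(version("transformers")) < Version("4.53.0"), "need transformers<4.53.0"
# qwen_vl_utils carries no version pin; `pip show qwen-vl-utils` is enough to confirm it is installed.
```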

## Prepare the SFT Dataset
We use the [geometry3k](https://huggingface.co/datasets/hiyouga/geometry3k) dataset for training, and we generate the [SFT dataset](https://huggingface.co/datasets/datajuicer/geometry_sft) by prompting the Qwen2.5-VL-32B-Instruct model on its validation set. Note that this dataset is only meant to showcase the SFT data format used in this example, shown below:
```json
{
"problem": "<image>Find $x$ so that $m || n$.",
"response": "To determine the value of $ x $ ... Answer:\n\\[\n\\boxed{63}\n\\]",
"images": [<image>]
}
```
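To confirm these fields on the published dataset, one possible way (assuming the Hugging Face `datasets` library; this snippet is not part of the example) is:

```python
# Sketch for inspecting the SFT data format described above
# (assumes the Hugging Face `datasets` library is installed).
from datasets import load_dataset

sft = load_dataset("datajuicer/geometry_sft", split="train")
sample = sft[0]

print(sample["problem"])         # prompt text containing an <image> placeholder
print(sample["response"][:200])  # response ending in a \boxed{...} answer
print(type(sample["images"]))    # list with one image per <image> tag
```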

The config file is located at [`mix_vlm.yaml`](mix_vlm.yaml). Feel free to tune the algorithm hyperparameters for better performance!

## Run the Example

Run the following command to start training:
```bash
trinity run --config examples/mix_vlm/mix_vlm.yaml
```

The reward curve is shown below:
![Reward curve of the mix_vlm example](../../docs/sphinx_doc/assets/mix_vlm_reward.png)
94 changes: 94 additions & 0 deletions examples/mix_vlm/mix_vlm.yaml
@@ -0,0 +1,94 @@
project: "Trinity-RFT"
name: "mix_vlm"
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
  algorithm_type: mix_chord
  repeat_times: 8
  optimizer:
    lr: 1e-6
  kl_loss_fn_args:
    kl_coef: 0.0
  entropy_loss_fn: mix
  sample_strategy_args:
    expert_data_ratio: 0.20
  policy_loss_fn_args:
    mu_warmup_steps: 200
    mu_decay_steps: 400
    mu_peak: 0.1
    mu_valley: 0.1
    enable_phi_function: false
    clip_range: 0.2
    sft_loss_agg_mode: "token-mean"
    use_dynamic_bsz: true
    ppo_mini_batch_size: 320 # 320 = 256 + 64
    ppo_micro_batch_size_per_gpu: 4
    ngpus_trainer: 4
    train_batch_size_expert: 64
    train_batch_size_usual: 256 # 32 batchsize * 8 repeat times
model:
  model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
  max_response_tokens: 10240
  max_model_len: 11264
cluster:
  node_num: 1
  gpu_per_node: 8
buffer:
  total_epochs: 4
  batch_size: 32
  train_batch_size: 320
  explorer_input:
    taskset:
      name: geometry3k
      storage_type: file
      path: ${oc.env:TRINITY_TASKSET_PATH,hiyouga/geometry3k}
      subset_name: 'default'
      split: 'train'
      format:
        prompt_key: 'problem'
        response_key: 'answer'
        image_key: 'images'
      rollout_args:
        temperature: 1.0
        logprobs: 0
      workflow_args:
        with_think: true
    eval_tasksets: [] # you can add your own eval tasksets here
    default_workflow_type: 'simple_mm_workflow'
    default_reward_fn_type: 'math_boxed_reward'
  trainer_input:
    experience_buffer:
      name: experience_buffer
      storage_type: queue
    auxiliary_buffers:
      sft_dataset:
        total_epochs: 25
        name: geometry_sft
        storage_type: file
        schema_type: sft
        path: datajuicer/geometry_sft
        split: 'train'
        format:
          prompt_type: plaintext
          prompt_key: 'problem'
          response_key: 'response'
          image_key: 'images'
explorer:
  eval_interval: 10
  runner_per_model: 8
  rollout_model:
    engine_num: 4
    tensor_parallel_size: 1
    enable_prefix_caching: false
    enforce_eager: true
    dtype: bfloat16
    seed: 42
synchronizer:
  sync_method: 'nccl'
  sync_interval: 1
  sync_timeout: 1200
trainer:
  save_interval: 50
  grad_clip: 1.0
  use_dynamic_bsz: true
  max_token_len_per_gpu: 11264
  ulysses_sequence_parallel_size: 2
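As a worked check of the batch-size comments in this config (a sketch, not part of the PR): each step collects `batch_size * repeat_times = 32 * 8 = 256` on-policy experiences, mixes in 64 expert (SFT) experiences, and trains on the combined 320, matching `train_batch_size` and `ppo_mini_batch_size`, with an expert share of 64 / 320 = 0.20 consistent with `expert_data_ratio`.

```python
# Arithmetic behind the batch-size comments in mix_vlm.yaml (sketch only).
batch_size = 32                 # explorer tasks per step (buffer.batch_size)
repeat_times = 8                # rollouts per task (algorithm.repeat_times)
train_batch_size_expert = 64    # expert (SFT) experiences mixed in per step

train_batch_size_usual = batch_size * repeat_times   # 256 on-policy experiences
assert train_batch_size_usual == 256
assert train_batch_size_usual + train_batch_size_expert == 320  # train_batch_size / ppo_mini_batch_size
assert train_batch_size_expert / 320 == 0.20                    # expert_data_ratio
```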