Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions examples/grpo_lora_gsm8k/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# GRPO with LoRA

This example shows the usage of LoRA on the GSM8K dataset.

## GRPO training
Compared with full-model fine-tuning, Trinity-RFT enables LoRA when you provide the `lora_configs` field, as follows:

```yaml
project: "Trinity-RFT-gsm8k"
name: "qwen2.5-1.5B-gsm8k"
model:
lora_configs:
- name: lora
lora_rank: 32
lora_alpha: 32
synchronizer:
sync_method: 'checkpoint'
```

Note that the `lora_rank` and `lora_alpha` are hyperparameters that need to be tuned. For `lora_rank`, a very small value can lead to slower convergence or worse training performance, while a very large value can lead to memory and performance issues.

For now, we only support single-LoRA training, with weight synchronization via the `checkpoint` method.

## Benchmark with LoRA
After training, we can evaluate the performance of checkpoints via the `bench` mode. Some key configurations are shown below:

```yaml
mode: bench
project: "Trinity-RFT-gsm8k" # same as training
name: "qwen2.5-1.5B-gsm8k" # same as training
model:
lora_configs: # same as training
- name: lora
lora_rank: 32
lora_alpha: 32
explorer:
rollout_model:
engine_num: 2 # ensure all gpus are used for benchmarking
tensor_parallel_size: 4
synchronizer:
sync_method: 'checkpoint'
```
82 changes: 82 additions & 0 deletions examples/grpo_lora_gsm8k/gsm8k.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
project: "Trinity-RFT-gsm8k"
name: "qwen2.5-1.5B-gsm8k"
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
algorithm_type: grpo
repeat_times: 8
model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-1.5B-Instruct}
max_response_tokens: 1024
max_model_len: 1280
lora_configs:
- name: lora
lora_rank: 32
lora_alpha: 32
cluster:
node_num: 1
gpu_per_node: 8
buffer:
total_epochs: 10
batch_size: 96
explorer_input:
taskset:
name: gsm8k
storage_type: file
path: 'openai/gsm8k'
subset_name: 'main'
split: 'train'
format:
prompt_key: 'question'
response_key: 'answer'
rollout_args:
temperature: 1.0
eval_tasksets:
- name: gsm8k-eval
storage_type: file
path: 'openai/gsm8k'
subset_name: 'main'
split: 'test'
format:
prompt_key: 'question'
response_key: 'answer'
default_workflow_type: 'math_workflow'
trainer_input:
experience_buffer:
name: gsm8k_buffer
storage_type: queue
explorer:
eval_interval: 10
runner_per_model: 16
rollout_model:
engine_num: 1
tensor_parallel_size: 4
enable_prefix_caching: false
enforce_eager: true
dtype: bfloat16
seed: 42
synchronizer:
sync_method: 'checkpoint'
sync_interval: 1
sync_timeout: 1200
trainer:
trainer_type: 'verl'
save_interval: 100
trainer_config:
actor_rollout_ref:
model:
use_remove_padding: true
actor:
use_dynamic_bsz: true
ppo_max_token_len_per_gpu: 16384
ulysses_sequence_parallel_size: 1
optim:
lr: 1e-5
checkpoint:
load_contents:
- model
save_contents:
- model
ref:
log_prob_use_dynamic_bsz: ${trainer.trainer_config.actor_rollout_ref.actor.use_dynamic_bsz}
log_prob_max_token_len_per_gpu: ${trainer.trainer_config.actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
ulysses_sequence_parallel_size: ${trainer.trainer_config.actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
20 changes: 18 additions & 2 deletions tests/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@
import ray
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

from trinity.common.config import Config, FormatConfig, StorageConfig, load_config
from trinity.common.config import (
Config,
FormatConfig,
LoRAConfig,
StorageConfig,
load_config,
)
from trinity.common.constants import (
CHECKPOINT_ROOT_DIR_ENV_VAR,
MODEL_PATH_ENV_VAR,
Expand Down Expand Up @@ -64,11 +70,15 @@ def get_vision_languge_model_path() -> str:
return path


def get_lora_config() -> LoRAConfig:
return LoRAConfig(name="lora", lora_rank=16, lora_alpha=16)


def get_unittest_dataset_config(
dataset_name: str = "countdown", split: str = "train"
) -> StorageConfig:
"""Countdown dataset with 17 samples."""
if dataset_name == "countdown" or dataset_name == "copy_countdown":
# Countdown dataset with 17 samples
return StorageConfig(
name=dataset_name,
path=os.path.join(os.path.dirname(__file__), "template", "data", "countdown"),
Expand All @@ -82,6 +92,7 @@ def get_unittest_dataset_config(
default_reward_fn_type="countdown_reward",
)
elif dataset_name in {"eval_short", "eval_long"}:
# Eval_short dataset with 2 samples, eval_long dataset with 8 samples
return StorageConfig(
name=dataset_name,
path=os.path.join(os.path.dirname(__file__), "template", "data", dataset_name),
Expand All @@ -94,6 +105,7 @@ def get_unittest_dataset_config(
default_reward_fn_type="math_reward",
)
elif dataset_name == "gsm8k":
# GSM8K dataset with 16 samples
return StorageConfig(
name=dataset_name,
path=os.path.join(os.path.dirname(__file__), "template", "data", "gsm8k"),
Expand All @@ -106,6 +118,7 @@ def get_unittest_dataset_config(
default_reward_fn_type="math_reward",
)
elif dataset_name == "sft_for_gsm8k":
# SFT dataset with 8 samples
return StorageConfig(
name=dataset_name,
path=os.path.join(os.path.dirname(__file__), "template", "data", "sft_for_gsm8k"),
Expand All @@ -118,6 +131,7 @@ def get_unittest_dataset_config(
),
)
elif dataset_name == "sft_with_tools":
# SFT_with_tools dataset with 4 samples
return StorageConfig(
name=dataset_name,
path=os.path.join(os.path.dirname(__file__), "template", "data", "sft_with_tools"),
Expand All @@ -130,6 +144,7 @@ def get_unittest_dataset_config(
),
)
elif dataset_name == "dpo":
# HumanLike DPO dataset with 17 samples
return StorageConfig(
name=dataset_name,
path=os.path.join(os.path.dirname(__file__), "template", "data", "human_like"),
Expand All @@ -142,6 +157,7 @@ def get_unittest_dataset_config(
),
)
elif dataset_name == "geometry":
# Multi-modal geometry dataset with 8 samples
return StorageConfig(
name=dataset_name,
path=os.path.join(os.path.dirname(__file__), "template", "data", "geometry"),
Expand Down
65 changes: 65 additions & 0 deletions tests/trainer/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
RayUnittestBase,
TensorBoardParser,
get_checkpoint_path,
get_lora_config,
get_model_path,
get_template_config,
get_unittest_dataset_config,
Expand Down Expand Up @@ -724,3 +725,67 @@ def test_trainer(self):
def tearDown(self):
# remove dir only when the test passed
shutil.rmtree(self.config.checkpoint_job_dir)


class TestTrainerLoRA(BaseTrainerCase):
def test_trainer(self):
"""Test both mode with LoRA request."""
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
self.config.buffer.explorer_input.eval_tasksets.append(
get_unittest_dataset_config("gsm8k", "test")
)
self.config.model.model_path = get_model_path()
self.config.algorithm.algorithm_type = "grpo"
self.config.algorithm.advantage_fn = "grpo"
self.config.algorithm.kl_loss_fn = "none"
self.config.algorithm.repeat_times = 4
self.config.buffer.batch_size = 4
self.config.buffer.total_steps = 2
self.config.cluster.node_num = 1
self.config.cluster.gpu_per_node = 4
self.config.explorer.eval_interval = 2
self.config.model.lora_configs = [get_lora_config()]
self.config.synchronizer.sync_method = SyncMethod.CHECKPOINT
self.config.synchronizer.sync_interval = 2
self.config.trainer.save_interval = 2
self.config.check_and_update()
both(self.config)
# check metrics are available
parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
rollout_metrics = parser.metric_list("rollout")
self.assertTrue(len(rollout_metrics) > 0)
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2)
actor_metrics = parser.metric_list("actor")
self.assertTrue(len(actor_metrics) > 0)
self.assertEqual(parser.metric_max_step(actor_metrics[0]), 2)
response_metrics = parser.metric_list("response_length")
self.assertTrue(len(response_metrics) > 0)
self.assertEqual(parser.metric_max_step(response_metrics[0]), 2)
ray.shutdown(_exiting_interpreter=True)
# check save lastest checkpoint
checkpoint_step_2, step_num = get_checkpoint_dir_with_step_num(
checkpoint_root_path=self.config.checkpoint_job_dir,
trainer_type=self.config.trainer.trainer_type,
)
self.assertTrue(len(os.listdir(os.path.join(checkpoint_step_2, "actor"))) > 0)
self.assertTrue(
len(os.listdir(os.path.join(checkpoint_step_2, "actor", "lora_adapter"))) > 0
)
self.assertEqual(step_num, 2)

# test bench mode
ray.init(ignore_reinit_error=True, namespace=self.config.ray_namespace)
self.config.mode = "bench"
self.config.synchronizer.sync_method = SyncMethod.CHECKPOINT
self.config.explorer.bench_on_latest_checkpoint = False
self.config.check_and_update()
bench(self.config)
parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
for prefix in ["eval", "bench"]:
gsm8k_metrics = parser.metric_list(f"{prefix}/gsm8k")
self.assertTrue(len(gsm8k_metrics) > 0)
gsm8k_metric_steps = parser.metric_steps(gsm8k_metrics[0])
self.assertEqual([0, 2], gsm8k_metric_steps)

def tearDown(self):
shutil.rmtree(self.config.checkpoint_job_dir)
Loading