examples/mix_chord/mix_chord.yaml (3 additions, 2 deletions)
@@ -21,7 +21,7 @@ algorithm:
   ppo_micro_batch_size_per_gpu: 4
   ngpus_trainer: 4
   train_batch_size_expert: 64
-  train_batch_size_usual: 256 # (40 batchsize * (1 - 0.2 expert_data_ratio)) * 8 repeat times
+  train_batch_size_usual: 256 # 32 batchsize * 8 repeat times
 model:
   model_path: /PATH/TO/MODEL/
   max_response_tokens: 10240
@@ -31,7 +31,8 @@ cluster:
   gpu_per_node: 8
 buffer:
   total_epochs: 4
-  batch_size: 40
+  batch_size: 32
+  train_batch_size: 320
   max_retry_times: 3
   max_retry_interval: 1
   explorer_input:
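Spelled out, the bookkeeping behind these numbers looks as follows; this is a minimal sketch in Python, where the expert_data_ratio of 0.2 is inferred from the deleted comment rather than visible in the new config. The same batch_size/train_batch_size change recurs in examples/mix_math/mix_math.yaml below.

batch_size = 32          # explorer tasks sampled per step
repeat_times = 8         # rollouts per task, per the "8 repeat times" comment
train_batch_size = 320   # experiences consumed per train step

usual = batch_size * repeat_times   # 256 usual (on-policy) experiences
expert = train_batch_size - usual   # 64 expert experiences, matching train_batch_size_expert
assert usual == 256 and expert == 64
assert expert / train_batch_size == 0.2   # the old comment's expert_data_ratio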
examples/mix_math/mix_math.yaml (2 additions, 1 deletion)
@@ -25,7 +25,8 @@ cluster:
   gpu_per_node: 8
 buffer:
   total_epochs: 10
-  batch_size: 40
+  batch_size: 32
+  train_batch_size: 320
   max_retry_times: 3
   max_retry_interval: 1
   explorer_input:
tests/trainer/trainer_test.py (56 additions, 0 deletions)
@@ -456,3 +456,59 @@ def test_fully_async_mode(self, name, use_priority_queue):
    def tearDown(self):
        checkpoint_path = get_checkpoint_path()
        shutil.rmtree(os.path.join(checkpoint_path, "unittest"))


class TestTrainerMIX(BaseTrainerCase):
    def test_trainer(self):
        """Test the MIX algorithm."""
        # gsm8k has 16 tasks, sft_for_gsm8k has 8 tasks;
        # 4 steps in total, each step reads 4 tasks from gsm8k and 16 from sft_for_gsm8k
        self.config.algorithm.algorithm_type = "mix"
        self.config.algorithm.repeat_times = 4
        self.config.algorithm.sample_strategy = "mix"
        self.config.algorithm.sample_strategy_args = {"expert_data_ratio": 0.5}  # rft=4*4 : sft=16
        self.config.algorithm.policy_loss_fn = "mix"
        self.config.buffer.batch_size = 4
        self.config.buffer.train_batch_size = 32
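        # Per-step bookkeeping: 4 tasks * 4 repeats = 16 rft experiences, and
        # expert_data_ratio 0.5 of train_batch_size 32 = 16 sft experiences,
        # so each train step consumes exactly 16 + 16 = 32 experiences.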
        self.config.buffer.total_epochs = 1
        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
        self.config.synchronizer.sync_interval = 1
        self.config.trainer.save_interval = 1
        self.config.buffer.trainer_input.sft_warmup_dataset = get_unittest_dataset_config(
            "sft_for_gsm8k"
        )
        self.config.buffer.trainer_input.sft_warmup_dataset.total_epochs = 8  # test that this override works
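        # 8 sft tasks * 8 epochs = 64 sft reads = 16 per step * 4 steps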
        self.config.check_and_update()
        self.config.buffer.trainer_input.experience_buffer.max_read_timeout = 20
        self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
        both(self.config)
        parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))

        # test rollout metrics
        rollout_metrics = parser.metric_list("rollout")
        self.assertTrue(len(rollout_metrics) > 0)
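        # 16 gsm8k tasks / batch_size 4 = 4 explorer steps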
        self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4)
        self.assertEqual(
            parser.metric_values("rollout/experience_count")[1], 16
        )  # 16 rft experiences
        # test actor metrics
        actor_metrics = parser.metric_list("actor")
        self.assertTrue(len(actor_metrics) > 0)
        expert_metrics = parser.metric_list("actor/expert/")
        self.assertEqual(parser.metric_max_step(expert_metrics[0]), 4)  # SFT
        usual_metrics = parser.metric_list("actor/usual/")
        self.assertEqual(parser.metric_max_step(usual_metrics[0]), 4)  # RFT
        response_metrics = parser.metric_list("response_length")
        self.assertTrue(len(response_metrics) > 0)
        self.assertEqual(parser.metric_min_step(response_metrics[0]), 1)
        self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
        # test save checkpoint at last step
        checkpoint_dir, step_num = get_checkpoint_dir_with_step_num(
            checkpoint_root_path=self.config.checkpoint_job_dir,
            trainer_type="verl",
        )
        self.assertEqual(step_num, 4)
        self.assertTrue(len(os.listdir(os.path.join(checkpoint_dir, "actor"))) > 0)

    def tearDown(self):
        shutil.rmtree(self.config.checkpoint_job_dir)