diff --git a/examples/mix_chord/mix_chord.yaml b/examples/mix_chord/mix_chord.yaml
index caa44c573b..26d04852e1 100644
--- a/examples/mix_chord/mix_chord.yaml
+++ b/examples/mix_chord/mix_chord.yaml
@@ -21,7 +21,7 @@ algorithm:
   ppo_micro_batch_size_per_gpu: 4
   ngpus_trainer: 4
   train_batch_size_expert: 64
-  train_batch_size_usual: 256  # (40 batchsize * (1 - 0.2 expert_data_ratio)) * 8 repeat times
+  train_batch_size_usual: 256  # 32 batchsize * 8 repeat times
 model:
   model_path: /PATH/TO/MODEL/
   max_response_tokens: 10240
@@ -31,7 +31,8 @@ cluster:
   gpu_per_node: 8
 buffer:
   total_epochs: 4
-  batch_size: 40
+  batch_size: 32
+  train_batch_size: 320
   max_retry_times: 3
   max_retry_interval: 1
   explorer_input:
diff --git a/examples/mix_math/mix_math.yaml b/examples/mix_math/mix_math.yaml
index f2bea04fe8..11eafb6cc9 100644
--- a/examples/mix_math/mix_math.yaml
+++ b/examples/mix_math/mix_math.yaml
@@ -25,7 +25,8 @@ cluster:
   gpu_per_node: 8
 buffer:
   total_epochs: 10
-  batch_size: 40
+  batch_size: 32
+  train_batch_size: 320
   max_retry_times: 3
   max_retry_interval: 1
   explorer_input:
diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py
index ae320dd709..cedf4c2983 100644
--- a/tests/trainer/trainer_test.py
+++ b/tests/trainer/trainer_test.py
@@ -456,3 +456,59 @@ def test_fully_async_mode(self, name, use_priority_queue):
     def tearDown(self):
         checkpoint_path = get_checkpoint_path()
         shutil.rmtree(os.path.join(checkpoint_path, "unittest"))
+
+
+class TestTrainerMIX(BaseTrainerCase):
+    def test_trainer(self):
+        """Test MIX algorithm."""
+        # gsm8k has 16 tasks, sft_for_gsm8k has 8 tasks
+        # total 4 steps, each step: read 4 tasks from gsm8k, 16 tasks from sft_for_gsm8k
+        self.config.algorithm.algorithm_type = "mix"
+        self.config.algorithm.repeat_times = 4
+        self.config.algorithm.sample_strategy = "mix"
+        self.config.algorithm.sample_strategy_args = {"expert_data_ratio": 0.5}  # rft=4*4 : sft=16
+        self.config.algorithm.policy_loss_fn = "mix"
+        self.config.buffer.batch_size = 4
+        self.config.buffer.train_batch_size = 32
+        self.config.buffer.total_epochs = 1
+        self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
+        self.config.synchronizer.sync_interval = 1
+        self.config.trainer.save_interval = 1
+        self.config.buffer.trainer_input.sft_warmup_dataset = get_unittest_dataset_config(
+            "sft_for_gsm8k"
+        )
+        self.config.buffer.trainer_input.sft_warmup_dataset.total_epochs = 8  # test this works
+        self.config.check_and_update()
+        self.config.buffer.trainer_input.experience_buffer.max_read_timeout = 20
+        self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
+        both(self.config)
+        parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard"))
+
+        # test rollout metrics
+        rollout_metrics = parser.metric_list("rollout")
+        self.assertTrue(len(rollout_metrics) > 0)
+        self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4)
+        self.assertEqual(
+            parser.metric_values("rollout/experience_count")[1], 16
+        )  # 16 rft experiences
+        # test actor metrics
+        actor_metrics = parser.metric_list("actor")
+        self.assertTrue(len(actor_metrics) > 0)
+        expert_metrics = parser.metric_list("actor/expert/")
+        self.assertEqual(parser.metric_max_step(expert_metrics[0]), 4)  # SFT
+        usual_metrics = parser.metric_list("actor/usual/")
+        self.assertEqual(parser.metric_max_step(usual_metrics[0]), 4)  # RFT
+        response_metrics = parser.metric_list("response_length")
+        self.assertTrue(len(response_metrics) > 0)
+        self.assertEqual(parser.metric_min_step(response_metrics[0]), 1)
+        self.assertEqual(parser.metric_max_step(response_metrics[0]), 4)
+        # test save checkpoint at last step
+        checkpoint_dir, step_num = get_checkpoint_dir_with_step_num(
+            checkpoint_root_path=self.config.checkpoint_job_dir,
+            trainer_type="verl",
+        )
+        self.assertEqual(step_num, 4)
+        self.assertTrue(len(os.listdir(os.path.join(checkpoint_dir, "actor"))) > 0)
+
+    def tearDown(self):
+        shutil.rmtree(self.config.checkpoint_job_dir)
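
Note (not part of the patch): a minimal Python sketch double-checking the batch arithmetic the updated config comments and test expectations rely on. Only the numbers come from the diff; the variable names below are illustrative, not identifiers from the repo.

# mix_chord.yaml: the usual (RFT) batch is batch_size * repeat_times, and
# train_batch_size covers expert (SFT) plus usual experiences per step.
batch_size = 32
repeat_times = 8
train_batch_size_expert = 64
train_batch_size_usual = batch_size * repeat_times                    # 32 * 8 = 256
train_batch_size = train_batch_size_usual + train_batch_size_expert   # 256 + 64 = 320
assert train_batch_size_usual == 256
assert train_batch_size == 320
assert train_batch_size_expert / train_batch_size == 0.2              # expert_data_ratio

# trainer_test.py: gsm8k has 16 tasks with batch_size=4 -> 4 steps; repeat_times=4
# -> 16 RFT experiences per step; expert_data_ratio=0.5 with train_batch_size=32
# -> 16 SFT experiences per step, i.e. sft_for_gsm8k's 8 tasks over total_epochs=8.
steps = 16 // 4                                   # 4 steps
rft_per_step = 4 * 4                              # batch_size * repeat_times = 16
sft_per_step = int(32 * 0.5)                      # train_batch_size * expert_data_ratio = 16
sft_total = 8 * 8                                 # 8 tasks * 8 sft epochs = 64
assert steps == 4 and rft_per_step == 16
assert sft_per_step == 16 and sft_total == sft_per_step * steps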