From 976413879ac6021e94d640faa35ae7321f7d32bc Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Tue, 19 Aug 2025 11:02:44 +0800 Subject: [PATCH 1/3] add unittest for mix --- examples/mix_chord/mix_chord.yaml | 5 +-- examples/mix_math/mix_math.yaml | 3 +- tests/template/config.yaml | 4 +-- tests/trainer/trainer_test.py | 56 +++++++++++++++++++++++++++++++ 4 files changed, 63 insertions(+), 5 deletions(-) diff --git a/examples/mix_chord/mix_chord.yaml b/examples/mix_chord/mix_chord.yaml index caa44c573b..26d04852e1 100644 --- a/examples/mix_chord/mix_chord.yaml +++ b/examples/mix_chord/mix_chord.yaml @@ -21,7 +21,7 @@ algorithm: ppo_micro_batch_size_per_gpu: 4 ngpus_trainer: 4 train_batch_size_expert: 64 - train_batch_size_usual: 256 # (40 batchsize * (1 - 0.2 expert_data_ratio)) * 8 repeat times + train_batch_size_usual: 256 # 32 batchsize * 8 repeat times model: model_path: /PATH/TO/MODEL/ max_response_tokens: 10240 @@ -31,7 +31,8 @@ cluster: gpu_per_node: 8 buffer: total_epochs: 4 - batch_size: 40 + batch_size: 32 + train_batch_size: 320 max_retry_times: 3 max_retry_interval: 1 explorer_input: diff --git a/examples/mix_math/mix_math.yaml b/examples/mix_math/mix_math.yaml index f2bea04fe8..11eafb6cc9 100644 --- a/examples/mix_math/mix_math.yaml +++ b/examples/mix_math/mix_math.yaml @@ -25,7 +25,8 @@ cluster: gpu_per_node: 8 buffer: total_epochs: 10 - batch_size: 40 + batch_size: 32 + train_batch_size: 320 max_retry_times: 3 max_retry_interval: 1 explorer_input: diff --git a/tests/template/config.yaml b/tests/template/config.yaml index aa903fd667..ff349e6d58 100644 --- a/tests/template/config.yaml +++ b/tests/template/config.yaml @@ -19,8 +19,8 @@ model: max_response_tokens: 2048 max_model_len: 4096 cluster: # 2 for explorer, 2 for trainer - node_num: 2 - gpu_per_node: 2 + node_num: 1 + gpu_per_node: 4 buffer: total_epochs: 1 batch_size: 4 diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index ae320dd709..9e4f344617 100644 --- 
a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -456,3 +456,59 @@ def test_fully_async_mode(self, name, use_priority_queue): def tearDown(self): checkpoint_path = get_checkpoint_path() shutil.rmtree(os.path.join(checkpoint_path, "unittest")) + + +class TestTrainerMIX(BaseTrainerCase): + def test_trainer(self): + """Test MIX algorithm.""" + # gsm8k has 16 tasks, sft_for_gsm8k has 8 tasks + # total 8 steps, each step: read 2 tasks from gsm8k, 8 tasks from sft_for_gsm8k + self.config.algorithm.algorithm_type = "mix" + self.config.algorithm.repeat_times = 4 + self.config.algorithm.sample_strategy = "mix" + self.config.algorithm.sample_strategy_args = {"expert_data_ratio": 0.5} # rft=2*4 : sft=8 + self.config.algorithm.policy_loss_fn = "mix" + self.config.buffer.batch_size = 2 + self.config.buffer.train_batch_size = 16 + self.config.buffer.total_epochs = 1 + self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k") + self.config.synchronizer.sync_interval = 1 + self.config.trainer.save_interval = 1 + self.config.buffer.trainer_input.sft_warmup_dataset = get_unittest_dataset_config( + "sft_for_gsm8k" + ) + self.config.buffer.trainer_input.sft_warmup_dataset.total_epochs = 8 + self.config.check_and_update() + self.config.buffer.trainer_input.experience_buffer.max_read_timeout = 20 + self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2 + both(self.config) + parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) + + # test rollout metrics + rollout_metrics = parser.metric_list("rollout") + self.assertTrue(len(rollout_metrics) > 0) + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8) + self.assertEqual( + parser.metric_values("rollout/experience_count")[1], 8 + ) # 8 rft experiences + # test actor metrics + actor_metrics = parser.metric_list("actor") + self.assertTrue(len(actor_metrics) > 0) + expert_metrics = parser.metric_list("actor/expert/") + 
self.assertEqual(parser.metric_max_step(expert_metrics[0]), 8) # SFT + usual_metrics = parser.metric_list("actor/usual/") + self.assertEqual(parser.metric_max_step(usual_metrics[0]), 8) # RFT + response_metrics = parser.metric_list("response_length") + self.assertTrue(len(response_metrics) > 0) + self.assertEqual(parser.metric_min_step(response_metrics[0]), 1) + self.assertEqual(parser.metric_max_step(response_metrics[0]), 8) + # test save checkpoint at last step + checkpoint_dir, step_num = get_checkpoint_dir_with_step_num( + checkpoint_root_path=self.config.checkpoint_job_dir, + trainer_type="verl", + ) + self.assertEqual(step_num, 8) + self.assertTrue(len(os.listdir(os.path.join(checkpoint_dir, "actor"))) > 0) + + def tearDown(self): + shutil.rmtree(self.config.checkpoint_job_dir) From c674fae2b038f89c04de46b9d3ac47a68257cf77 Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Tue, 19 Aug 2025 12:26:25 +0800 Subject: [PATCH 2/3] fix config --- tests/template/config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/template/config.yaml b/tests/template/config.yaml index ff349e6d58..aa903fd667 100644 --- a/tests/template/config.yaml +++ b/tests/template/config.yaml @@ -19,8 +19,8 @@ model: max_response_tokens: 2048 max_model_len: 4096 cluster: # 2 for explorer, 2 for trainer - node_num: 1 - gpu_per_node: 4 + node_num: 2 + gpu_per_node: 2 buffer: total_epochs: 1 batch_size: 4 From f2e1e4d2c21ee0997791e36a8ec1adca9c6eeab4 Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Wed, 20 Aug 2025 10:22:17 +0800 Subject: [PATCH 3/3] shorten to 4 steps --- tests/trainer/trainer_test.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 9e4f344617..cedf4c2983 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -462,14 +462,14 @@ class TestTrainerMIX(BaseTrainerCase): def test_trainer(self): """Test MIX algorithm.""" # 
gsm8k has 16 tasks, sft_for_gsm8k has 8 tasks - # total 8 steps, each step: read 2 tasks from gsm8k, 8 tasks from sft_for_gsm8k + # total 4 steps, each step: read 4 tasks from gsm8k, 16 tasks from sft_for_gsm8k self.config.algorithm.algorithm_type = "mix" self.config.algorithm.repeat_times = 4 self.config.algorithm.sample_strategy = "mix" - self.config.algorithm.sample_strategy_args = {"expert_data_ratio": 0.5} # rft=2*4 : sft=8 + self.config.algorithm.sample_strategy_args = {"expert_data_ratio": 0.5} # rft=4*4 : sft=16 self.config.algorithm.policy_loss_fn = "mix" - self.config.buffer.batch_size = 2 - self.config.buffer.train_batch_size = 16 + self.config.buffer.batch_size = 4 + self.config.buffer.train_batch_size = 32 self.config.buffer.total_epochs = 1 self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k") self.config.synchronizer.sync_interval = 1 @@ -477,7 +477,7 @@ def test_trainer(self): self.config.buffer.trainer_input.sft_warmup_dataset = get_unittest_dataset_config( "sft_for_gsm8k" ) - self.config.buffer.trainer_input.sft_warmup_dataset.total_epochs = 8 + self.config.buffer.trainer_input.sft_warmup_dataset.total_epochs = 8 # test multi-epoch sft self.config.check_and_update() self.config.buffer.trainer_input.experience_buffer.max_read_timeout = 20 self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2 @@ -487,27 +487,27 @@ def test_trainer(self): # test rollout metrics rollout_metrics = parser.metric_list("rollout") self.assertTrue(len(rollout_metrics) > 0) - self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 8) + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4) self.assertEqual( - parser.metric_values("rollout/experience_count")[1], 8 - ) # 8 rft experiences + parser.metric_values("rollout/experience_count")[1], 16 + ) # 16 rft experiences # test actor metrics actor_metrics = parser.metric_list("actor") self.assertTrue(len(actor_metrics) > 0) expert_metrics = parser.metric_list("actor/expert/") 
- self.assertEqual(parser.metric_max_step(expert_metrics[0]), 8) # SFT + self.assertEqual(parser.metric_max_step(expert_metrics[0]), 4) # SFT usual_metrics = parser.metric_list("actor/usual/") - self.assertEqual(parser.metric_max_step(usual_metrics[0]), 8) # RFT + self.assertEqual(parser.metric_max_step(usual_metrics[0]), 4) # RFT response_metrics = parser.metric_list("response_length") self.assertTrue(len(response_metrics) > 0) self.assertEqual(parser.metric_min_step(response_metrics[0]), 1) - self.assertEqual(parser.metric_max_step(response_metrics[0]), 8) + self.assertEqual(parser.metric_max_step(response_metrics[0]), 4) # test save checkpoint at last step checkpoint_dir, step_num = get_checkpoint_dir_with_step_num( checkpoint_root_path=self.config.checkpoint_job_dir, trainer_type="verl", ) - self.assertEqual(step_num, 8) + self.assertEqual(step_num, 4) self.assertTrue(len(os.listdir(os.path.join(checkpoint_dir, "actor"))) > 0) def tearDown(self):