From 952c2d364bbcf4ac32d24dba8dc3a3befd96e7f8 Mon Sep 17 00:00:00 2001 From: Yuchang Sun <52027540+hiyuchang@users.noreply.github.com> Date: Tue, 12 Aug 2025 14:11:18 +0800 Subject: [PATCH] [Feat] Allow user to set `train_batch_size` (#177) --- .github/workflows/docker/docker-compose.yaml | 1 + .../source/tutorial/example_async_mode.md | 2 +- .../sphinx_doc/source/tutorial/example_dpo.md | 4 +- .../source/tutorial/example_mix_algo.md | 41 +++++++++---------- docs/sphinx_doc/source/tutorial/faq.md | 4 +- .../source/tutorial/trinity_configs.md | 4 +- examples/async_gsm8k/trainer.yaml | 2 +- examples/dpo_humanlike/dpo.yaml | 2 +- examples/mix_math/mix_math.yaml | 6 +-- examples/sft_mot/sft.yaml | 2 +- tests/buffer/queue_test.py | 14 +++---- tests/buffer/sql_test.py | 2 +- tests/common/vllm_test.py | 17 ++++++-- tests/explorer/scheduler_test.py | 14 ++++--- tests/manager/synchronizer_test.py | 2 + tests/template/config.yaml | 1 - tests/template/verl_config.yaml | 2 - tests/trainer/trainer_test.py | 6 ++- .../policy_loss_fn/mix_policy_loss.py | 23 +++++------ .../sample_strategy/mix_sample_strategy.py | 6 +-- trinity/buffer/queue.py | 2 +- trinity/buffer/ray_wrapper.py | 4 +- trinity/buffer/reader/file_reader.py | 4 +- trinity/buffer/reader/queue_reader.py | 2 +- trinity/common/config.py | 17 ++++++-- trinity/common/verl_config.py | 20 +++------ trinity/manager/config_manager.py | 9 ++-- .../config_registry/buffer_config_manager.py | 28 ++++++++++++- trinity/trainer/verl/fsdp_workers.py | 4 +- 29 files changed, 142 insertions(+), 103 deletions(-) diff --git a/.github/workflows/docker/docker-compose.yaml b/.github/workflows/docker/docker-compose.yaml index 21d28464f5..b150738bfd 100644 --- a/.github/workflows/docker/docker-compose.yaml +++ b/.github/workflows/docker/docker-compose.yaml @@ -9,6 +9,7 @@ services: - CHECKPOINT_ROOT_DIR=/mnt/checkpoints - DATA_ROOT_DIR=/mnt/data - MODEL_PATH=/mnt/models/Qwen3-0.6B + - API_MODEL_PATH=/mnt/models/Qwen3-1.7B - CHECKPOINT_PATH=/mnt/checkpoints working_dir: /workspace networks: diff --git a/docs/sphinx_doc/source/tutorial/example_async_mode.md b/docs/sphinx_doc/source/tutorial/example_async_mode.md index 70ca66e2b2..6b84df34d4 100644 --- a/docs/sphinx_doc/source/tutorial/example_async_mode.md +++ b/docs/sphinx_doc/source/tutorial/example_async_mode.md @@ -74,7 +74,7 @@ cluster: gpu_per_node: 4 buffer: total_epochs: 1 - batch_size: 64 + train_batch_size: 512 explorer_input: taskset: name: gsm8k diff --git a/docs/sphinx_doc/source/tutorial/example_dpo.md b/docs/sphinx_doc/source/tutorial/example_dpo.md index f457c6e888..040ff1ad9f 100644 --- a/docs/sphinx_doc/source/tutorial/example_dpo.md +++ b/docs/sphinx_doc/source/tutorial/example_dpo.md @@ -61,7 +61,7 @@ cluster: gpu_per_node: 8 buffer: total_epochs: 2 - batch_size: 64 + train_batch_size: 64 trainer_input: experience_buffer: name: human_like_dpo @@ -95,7 +95,7 @@ cluster: gpu_per_node: 2 buffer: total_epochs: 5 - batch_size: 64 + train_batch_size: 64 trainer_input: experience_buffer: name: diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md index 473cf2be1f..632ebcea7b 100644 --- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md +++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md @@ -80,12 +80,12 @@ class MixSampleStrategy(SampleStrategy): def __init__(self, buffer_config: BufferConfig, **kwargs): super().__init__(buffer_config) self.expert_data_ratio = kwargs.get("expert_data_ratio", 0.5) - tot_batch_size = 
buffer_config.read_batch_size + tot_batch_size = buffer_config.train_batch_size expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size) # experience buffer usual_buffer_config = copy.deepcopy(buffer_config) - usual_buffer_config.read_batch_size = tot_batch_size - expert_batch_size + usual_buffer_config.train_batch_size = tot_batch_size - expert_batch_size self.usual_exp_buffer = get_buffer_reader( buffer_config.trainer_input.experience_buffer, usual_buffer_config # type: ignore ) @@ -97,7 +97,7 @@ class MixSampleStrategy(SampleStrategy): # expert experience buffer expert_buffer_config = copy.deepcopy(buffer_config) - expert_buffer_config.read_batch_size = expert_batch_size + expert_buffer_config.train_batch_size = expert_batch_size self.expert_exp_buffer = get_buffer_reader( buffer_config.trainer_input.sft_warmup_dataset, expert_buffer_config ) @@ -157,23 +157,20 @@ class MIXPolicyLossFn(PolicyLossFn): clip_range_low: Optional[float] = None, clip_range_high: Optional[float] = None, use_dynamic_bsz: Optional[bool] = None, - repeat_times: Optional[int] = None, - ppo_mini_batch_size: Optional[int] = None, - ppo_micro_batch_size_per_gpu: Optional[int] = None, - ngpus_trainer: Optional[int] = None, - read_batch_size_usual: Optional[int] = None, - read_batch_size_expert: Optional[int] = None, + ppo_mini_batch_size: int = 1, + ppo_micro_batch_size_per_gpu: int = 1, + ngpus_trainer: int = 1, + train_batch_size_usual: int = 1, + train_batch_size_expert: int = 1, use_token_level_loss_in_sft: bool = True, ) -> None: super().__init__(backend=backend) self.mu = mu self.use_dynamic_bsz = use_dynamic_bsz - self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer # type: ignore - self.gradient_accumulation = ( - ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu # type: ignore - ) - self.read_batch_size_usual = read_batch_size_usual - self.read_batch_size_expert = read_batch_size_expert + self.experience_per_gpu = ppo_mini_batch_size // ngpus_trainer + self.gradient_accumulation = ppo_mini_batch_size // ppo_micro_batch_size_per_gpu + self.train_batch_size_usual = train_batch_size_usual + self.train_batch_size_expert = train_batch_size_expert self.grpo_loss_fn = PPOPolicyLossFn( clip_range=clip_range, clip_range_low=clip_range_low, @@ -199,14 +196,14 @@ class MIXPolicyLossFn(PolicyLossFn): if self.use_dynamic_bsz: per_micro_batch_weight_usual = self.experience_per_gpu / ( - logprob.shape[0] * self.read_batch_size_usual + logprob.shape[0] * self.train_batch_size_usual ) per_micro_batch_weight_expert = self.experience_per_gpu / ( - logprob.shape[0] * self.read_batch_size_expert + logprob.shape[0] * self.train_batch_size_expert ) else: - per_micro_batch_weight_usual = self.gradient_accumulation / self.read_batch_size_usual # type: ignore - per_micro_batch_weight_expert = self.gradient_accumulation / self.read_batch_size_expert # type: ignore + per_micro_batch_weight_usual = self.gradient_accumulation / self.train_batch_size_usual # type: ignore + per_micro_batch_weight_expert = self.gradient_accumulation / self.train_batch_size_expert # type: ignore if n_usual_exp > 0: grpo_loss, grpo_metrics = self.grpo_loss_fn( @@ -272,11 +269,11 @@ algorithm: use_token_level_loss_in_sft: False use_dynamic_bsz: False repeat_times: 8 - ppo_mini_batch_size: 32 + ppo_mini_batch_size: 256 ppo_micro_batch_size_per_gpu: 4 ngpus_trainer: 4 - read_batch_size_expert: 64 - read_batch_size_usual: 192 + train_batch_size_expert: 64 + train_batch_size_usual: 192 ``` With the above 
configurations, the experiment can be run with the following command:
diff --git a/docs/sphinx_doc/source/tutorial/faq.md b/docs/sphinx_doc/source/tutorial/faq.md
index cc6c3c461b..63b0fe5b1f 100644
--- a/docs/sphinx_doc/source/tutorial/faq.md
+++ b/docs/sphinx_doc/source/tutorial/faq.md
@@ -18,14 +18,14 @@ For users' convenience, future versions will gradually reduce parameters in `tra
 **A:** The following parameters are closely related:
 
 - `buffer.batch_size`: The number of tasks in a batch, effective for both the explorer and the trainer.
-- `actor_rollout_ref.actor.ppo_mini_batch_size`: In the configuration, this value represents the number of tasks in a mini-batch, overridden by `buffer.batch_size`; but in the `update_policy` function, its value becomes the number of experiences in a mini-batch per GPU, i.e., `buffer.batch_size * algorithm.repeat_times (/ ngpus_trainer)`. The expression of dividing `ngpus_trainer` is caused by implict data allocation to GPUs, but this do not affects the result after gradient accumulation.
+- `actor_rollout_ref.actor.ppo_mini_batch_size`: The number of experiences in a mini-batch, overridden by `buffer.train_batch_size`; but in the `update_policy` function, its value becomes the number of experiences in a mini-batch per GPU, i.e., `buffer.train_batch_size (/ ngpus_trainer)`. The division by `ngpus_trainer` comes from the implicit allocation of data across GPUs, but it does not affect the result after gradient accumulation.
 - `actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu`: The number of experiences in a micro-batch per GPU.
 
 A minimal example showing their usage is as follows:
 
 ```python
 def update_policy(batch_exps):
-    dataloader = batch_epxs.split(ppo_mini_batch_size)  # here `ppo_mini_batch_size` is in terms of experiences
+    dataloader = batch_exps.split(ppo_mini_batch_size)
     for _ in range(ppo_epochs):
         for batch_idx, data in enumerate(dataloader):
             # Split data
diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index 877a3df717..73af4ab254 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -162,6 +162,7 @@ Configures the data buffers used by the explorer and trainer.
 ```yaml
 buffer:
   batch_size: 32
+  train_batch_size: 256
   total_epochs: 100
 
   explorer_input:
@@ -184,6 +185,7 @@ buffer:
 ```
 
 - `batch_size`: Number of tasks used per training step. *Please do not multiply this value by the `algorithm.repeat_times` manually*.
+- `train_batch_size`: Number of experiences used per training step. Defaults to `batch_size` * `algorithm.repeat_times`.
 - `total_epochs`: Total number of training epochs.
 - `total_steps`: Optional. The total number of training steps. If specified, `total_epochs` will be ignored.
@@ -440,7 +442,6 @@ actor_rollout_ref: impl_backend: None actor: strategy: fsdp # This is for backward-compatibility - ppo_mini_batch_size: 128 # ppo_micro_batch_size: 8 # will be deprecated, use ppo_micro_batch_size_per_gpu ppo_micro_batch_size_per_gpu: 4 use_dynamic_bsz: True @@ -505,7 +506,6 @@ critic: min_num_params: 0 fsdp_size: -1 forward_prefetch: False - ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} ppo_micro_batch_size_per_gpu: 8 forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu} use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} diff --git a/examples/async_gsm8k/trainer.yaml b/examples/async_gsm8k/trainer.yaml index 27eb891559..f1f6b08b86 100644 --- a/examples/async_gsm8k/trainer.yaml +++ b/examples/async_gsm8k/trainer.yaml @@ -14,7 +14,7 @@ cluster: gpu_per_node: 4 buffer: total_epochs: 1 - batch_size: 96 + train_batch_size: 768 max_retry_times: 3 max_retry_interval: 1 explorer_input: diff --git a/examples/dpo_humanlike/dpo.yaml b/examples/dpo_humanlike/dpo.yaml index 650e8394fd..6f6dd06fd8 100644 --- a/examples/dpo_humanlike/dpo.yaml +++ b/examples/dpo_humanlike/dpo.yaml @@ -16,7 +16,7 @@ cluster: gpu_per_node: 8 buffer: total_epochs: 2 - batch_size: 32 + train_batch_size: 64 max_retry_times: 3 max_retry_interval: 1 trainer_input: diff --git a/examples/mix_math/mix_math.yaml b/examples/mix_math/mix_math.yaml index 36d7ff911e..e4c4dd9172 100644 --- a/examples/mix_math/mix_math.yaml +++ b/examples/mix_math/mix_math.yaml @@ -12,11 +12,11 @@ algorithm: use_token_level_loss_in_sft: False use_dynamic_bsz: False repeat_times: 8 - ppo_mini_batch_size: 32 + ppo_mini_batch_size: 256 ppo_micro_batch_size_per_gpu: 4 ngpus_trainer: 4 - read_batch_size_expert: 64 - read_batch_size_usual: 192 + train_batch_size_expert: 64 + train_batch_size_usual: 192 model: model_path: /PATH/TO/MODEL/ max_response_tokens: 10240 diff --git a/examples/sft_mot/sft.yaml b/examples/sft_mot/sft.yaml index 13a0183edd..a309e75c8b 100644 --- a/examples/sft_mot/sft.yaml +++ b/examples/sft_mot/sft.yaml @@ -13,7 +13,7 @@ cluster: gpu_per_node: 8 buffer: total_epochs: 1 - batch_size: 32 + train_batch_size: 64 max_retry_times: 3 max_retry_interval: 1 trainer_input: diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py index 74886ec761..27ee3cb0de 100644 --- a/tests/buffer/queue_test.py +++ b/tests/buffer/queue_test.py @@ -55,10 +55,10 @@ async def test_queue_buffer(self, name, use_priority_queue): exp.info = {"model_version": 0, "use_count": 0} for _ in range(self.total_num // self.put_batch_size): await writer.write_async(exps) - for _ in range(self.total_num // self.read_batch_size): + for _ in range(self.total_num // self.train_batch_size): exps = reader.read() - self.assertEqual(len(exps), self.read_batch_size) - print(f"finish read {self.read_batch_size} experience") + self.assertEqual(len(exps), self.train_batch_size) + print(f"finish read {self.train_batch_size} experience") exps = [ Experience( tokens=torch.tensor([float(j) for j in range(i + 1)]), @@ -94,13 +94,13 @@ def thread_read(reader, result_queue): async def test_priority_queue_capacity(self): # test queue capacity - self.config.read_batch_size = 4 + self.config.train_batch_size = 4 meta = StorageConfig( name="test_buffer_small", algorithm_type="ppo", storage_type=StorageType.QUEUE, max_read_timeout=1, - capacity=100, # priority will use 2 * read_batch_size as capacity (8) + capacity=100, # priority will use 2 * train_batch_size as capacity (8) path=BUFFER_FILE_PATH, 
            use_priority_queue=True,
            replay_buffer_kwargs={"priority_fn": "linear_decay", "decay": 0.6},
@@ -303,12 +303,12 @@ def replace_call():
     def setUp(self):
         self.total_num = 8
         self.put_batch_size = 2
-        self.read_batch_size = 4
+        self.train_batch_size = 4
 
         self.config = BufferConfig(
             max_retry_times=3,
             max_retry_interval=1,
-            read_batch_size=self.read_batch_size,
+            train_batch_size=self.train_batch_size,
         )
         if os.path.exists(BUFFER_FILE_PATH):
             os.remove(BUFFER_FILE_PATH)
diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py
index 7d43d04168..33e8a24462 100644
--- a/tests/buffer/sql_test.py
+++ b/tests/buffer/sql_test.py
@@ -28,7 +28,7 @@ async def test_create_sql_buffer(self) -> None:
         config = BufferConfig(
             max_retry_times=3,
             max_retry_interval=1,
-            read_batch_size=read_batch_size,
+            train_batch_size=read_batch_size,
         )
         sql_writer = SQLWriter(meta, config)
         sql_reader = SQLReader(meta, config)
diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py
index a1d47a48ff..b35c0ecb31 100644
--- a/tests/common/vllm_test.py
+++ b/tests/common/vllm_test.py
@@ -23,6 +23,15 @@ def get_model_path() -> str:
     return path
 
 
+def get_api_model_path() -> str:
+    path = os.environ.get("API_MODEL_PATH")
+    if not path:
+        raise EnvironmentError(
+            "Please set `export API_MODEL_PATH=` before running this test."
+        )
+    return path
+
+
 DEBUG = False
 
 
@@ -322,7 +331,7 @@ class TestAPIServerToolCall(RayUnittestBase):
     def setUp(self):
         self.config = get_template_config()
         self.config.mode = "explore"
-        self.config.model.model_path = get_model_path()
+        self.config.model.model_path = get_api_model_path()
        self.config.explorer.rollout_model.engine_type = "vllm_async"
         self.config.explorer.rollout_model.engine_num = 1
         self.config.explorer.rollout_model.tensor_parallel_size = 1
@@ -345,11 +354,13 @@ def setUp(self):
         )
 
     def test_api_tool_calls(self):
-        """Tests the full conversation flow of a tool call via the OpenAI API."""
+        """Tests the full conversation flow of a tool call via the OpenAI API.
+        Note: This test requires a model that supports tool calls and thinking mode, e.g. Qwen3-1.7B.
+ """ import json import time - tokenizer = AutoTokenizer.from_pretrained(get_model_path()) + tokenizer = AutoTokenizer.from_pretrained(get_api_model_path()) print_debug("\n\n" + "=" * 30 + " Running test_api_tool_calls " + "=" * 30) start_time = time.time() diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 82dafe7c1a..b1e7528b5e 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -207,7 +207,7 @@ def setUp(self): self.config.explorer.max_retry_times = 1 self.config.explorer.max_timeout = 5 self.config.explorer.runner_per_model = 2 - self.config.buffer.read_batch_size = 2 + self.config.buffer.train_batch_size = 2 self.config.buffer.pad_token_id = 0 self.config.buffer.explorer_output = ( self.config.buffer.trainer_input.experience_buffer @@ -568,11 +568,13 @@ async def test_non_repeatable_workflow(self): ) async def test_stepwise_experience_eid(self): + task_num, repeat_times, step_num = 2, 4, 3 + self.config.buffer.batch_size = task_num + self.config.buffer.train_batch_size = task_num * repeat_times * step_num self.config.explorer.max_repeat_times_per_runner = 2 self.config.check_and_update() scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()]) await scheduler.start() - task_num, repeat_times, step_num = 2, 4, 3 batch_num = 2 # repeatable stepwise workflow @@ -584,8 +586,8 @@ async def test_stepwise_experience_eid(self): scheduler.schedule(tasks, batch_id=i) statuses, _ = await scheduler.get_results(batch_id=i) self.assertEqual(len(statuses), task_num * repeat_times / 2) - exps = self.queue.read(batch_size=task_num * repeat_times * step_num) - self.assertEqual(len(exps), task_num * repeat_times * step_num) + exps = self.queue.read(batch_size=self.config.buffer.train_batch_size) + self.assertEqual(len(exps), self.config.buffer.train_batch_size) exp_list.extend(exps) # test task_id, run_id and unique_id @@ -605,8 +607,8 @@ async def test_stepwise_experience_eid(self): scheduler.schedule(tasks, batch_id=i) statuses, _ = await scheduler.get_results(batch_id=i) self.assertEqual(len(statuses), task_num * repeat_times / 2) - exps = self.queue.read(batch_size=task_num * repeat_times * step_num) - self.assertEqual(len(exps), task_num * repeat_times * step_num) + exps = self.queue.read(batch_size=self.config.buffer.train_batch_size) + self.assertEqual(len(exps), self.config.buffer.train_batch_size) exp_list.extend(exps) # test task_id, run_id and unique_id diff --git a/tests/manager/synchronizer_test.py b/tests/manager/synchronizer_test.py index ab1ed103b9..a77bc62097 100644 --- a/tests/manager/synchronizer_test.py +++ b/tests/manager/synchronizer_test.py @@ -129,6 +129,7 @@ def test_synchronizer(self): config.monitor.monitor_type = "tensorboard" trainer_config = deepcopy(config) trainer_config.mode = "train" + trainer_config.buffer.train_batch_size = 4 trainer_config.check_and_update() explorer1_config = deepcopy(config) @@ -253,6 +254,7 @@ def test_synchronizer(self): config.monitor.monitor_type = "tensorboard" trainer_config = deepcopy(config) trainer_config.mode = "train" + trainer_config.buffer.train_batch_size = 4 trainer_config.check_and_update() explorer1_config = deepcopy(config) diff --git a/tests/template/config.yaml b/tests/template/config.yaml index ec954829f0..aa903fd667 100644 --- a/tests/template/config.yaml +++ b/tests/template/config.yaml @@ -14,7 +14,6 @@ algorithm: lam: 1.0 kl_penalty_fn: k3 kl_loss_fn: k2 - model: model_path: '' max_response_tokens: 2048 diff --git 
a/tests/template/verl_config.yaml b/tests/template/verl_config.yaml index 96522887a7..1027ef31d0 100644 --- a/tests/template/verl_config.yaml +++ b/tests/template/verl_config.yaml @@ -7,7 +7,6 @@ actor_rollout_ref: use_remove_padding: True actor: strategy: fsdp # This is for backward-compatibility - ppo_mini_batch_size: 4 ppo_micro_batch_size_per_gpu: 1 use_dynamic_bsz: True ppo_max_token_len_per_gpu: 16384 @@ -70,7 +69,6 @@ critic: min_num_params: 0 fsdp_size: -1 forward_prefetch: False - ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} ppo_micro_batch_size_per_gpu: 1 forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu} use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 7c11a67b4f..ae320dd709 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -124,7 +124,7 @@ def tearDown(self): class TestStepAheadAsyncRL(BaseTrainerCase): def test_trainer(self): - """Test the explore step ahead trainer""" + """Test the explore step ahead trainer.""" # train 4 step, sync_offset=1, sync_interval=2 # Explorer: # | 1 | 2 | 3 |sync| 4 | @@ -274,7 +274,7 @@ def test_trainer(self): self.config.buffer.total_epochs = 2 self.config.buffer.total_steps = 4 # step has higher priority than epoch self.config.synchronizer.sync_interval = 4 - # self.config.buffer.batch_size = 32 + self.config.buffer.train_batch_size = 8 self.config.buffer.trainer_input.experience_buffer = get_unittest_dataset_config("dpo") self.config.check_and_update() self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2 @@ -301,6 +301,7 @@ def test_trainer(self): self.config.algorithm.kl_loss_fn = "none" self.config.algorithm.entropy_loss_fn = "none" self.config.synchronizer.sync_interval = 4 + self.config.buffer.train_batch_size = 4 self.config.buffer.total_epochs = 2 self.config.buffer.trainer_input.experience_buffer = get_unittest_dataset_config( "sft_for_gsm8k" @@ -367,6 +368,7 @@ def test_fully_async_mode(self, name, use_priority_queue): config.monitor.monitor_type = "tensorboard" trainer_config = deepcopy(config) trainer_config.mode = "train" + trainer_config.buffer.train_batch_size = 4 trainer_config.check_and_update() explorer1_config = deepcopy(config) diff --git a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py index 37f20f0236..b8c6a54a50 100644 --- a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py @@ -32,23 +32,20 @@ def __init__( clip_range_low: Optional[float] = None, clip_range_high: Optional[float] = None, use_dynamic_bsz: Optional[bool] = None, - repeat_times: int = 1, ppo_mini_batch_size: int = 1, ppo_micro_batch_size_per_gpu: int = 1, ngpus_trainer: int = 1, - read_batch_size_usual: int = 1, - read_batch_size_expert: int = 1, + train_batch_size_usual: int = 1, + train_batch_size_expert: int = 1, use_token_level_loss_in_sft: bool = True, ) -> None: super().__init__(backend=backend) self.mu = mu self.use_dynamic_bsz = use_dynamic_bsz - self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer - self.gradient_accumulation = ( - ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu - ) - self.read_batch_size_usual = read_batch_size_usual // ngpus_trainer - self.read_batch_size_expert = read_batch_size_expert // ngpus_trainer + self.experience_per_gpu = ppo_mini_batch_size // ngpus_trainer + 
self.gradient_accumulation = ppo_mini_batch_size // ppo_micro_batch_size_per_gpu + self.train_batch_size_usual = train_batch_size_usual // ngpus_trainer + self.train_batch_size_expert = train_batch_size_expert // ngpus_trainer self.grpo_loss_fn = PPOPolicyLossFn( clip_range=clip_range, clip_range_low=clip_range_low, @@ -74,14 +71,14 @@ def __call__( # type: ignore if self.use_dynamic_bsz: per_micro_batch_weight_usual = self.experience_per_gpu / ( - logprob.shape[0] * self.read_batch_size_usual + logprob.shape[0] * self.train_batch_size_usual ) per_micro_batch_weight_expert = self.experience_per_gpu / ( - logprob.shape[0] * self.read_batch_size_expert + logprob.shape[0] * self.train_batch_size_expert ) else: - per_micro_batch_weight_usual = self.gradient_accumulation / self.read_batch_size_usual # type: ignore - per_micro_batch_weight_expert = self.gradient_accumulation / self.read_batch_size_expert # type: ignore + per_micro_batch_weight_usual = self.gradient_accumulation / self.train_batch_size_usual # type: ignore + per_micro_batch_weight_expert = self.gradient_accumulation / self.train_batch_size_expert # type: ignore if n_usual_exp > 0: grpo_loss, grpo_metrics = self.grpo_loss_fn( diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py index 9508fd1dd8..32a34834bf 100644 --- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -22,12 +22,12 @@ class MixSampleStrategy(SampleStrategy): def __init__(self, buffer_config: BufferConfig, **kwargs): super().__init__(buffer_config) self.expert_data_ratio = kwargs.get("expert_data_ratio", 0.5) - tot_batch_size = buffer_config.read_batch_size + tot_batch_size = buffer_config.train_batch_size expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size) # experience buffer usual_buffer_config = copy.deepcopy(buffer_config) - usual_buffer_config.read_batch_size = tot_batch_size - expert_batch_size + usual_buffer_config.train_batch_size = tot_batch_size - expert_batch_size self.usual_exp_buffer = get_buffer_reader( buffer_config.trainer_input.experience_buffer, usual_buffer_config # type: ignore ) @@ -39,7 +39,7 @@ def __init__(self, buffer_config: BufferConfig, **kwargs): # expert experience buffer expert_buffer_config = copy.deepcopy(buffer_config) - expert_buffer_config.read_batch_size = expert_batch_size + expert_buffer_config.train_batch_size = expert_batch_size self.expert_exp_buffer = get_buffer_reader( buffer_config.trainer_input.sft_warmup_dataset, expert_buffer_config ) diff --git a/trinity/buffer/queue.py b/trinity/buffer/queue.py index e28af726a8..5a046acc81 100644 --- a/trinity/buffer/queue.py +++ b/trinity/buffer/queue.py @@ -48,7 +48,7 @@ def get_queue(cls, storage_config: StorageConfig, config: BufferConfig) -> "Queu if storage_config.use_priority_queue: reuse_cooldown_time = storage_config.reuse_cooldown_time replay_buffer_kwargs = storage_config.replay_buffer_kwargs - capacity = min(storage_config.capacity, config.read_batch_size * 2) + capacity = min(storage_config.capacity, config.train_batch_size * 2) logger.info( f"Using AsyncPriorityQueue with capacity {capacity}, reuse_cooldown_time {reuse_cooldown_time}." 
) diff --git a/trinity/buffer/ray_wrapper.py b/trinity/buffer/ray_wrapper.py index 4222a023b1..01271dad19 100644 --- a/trinity/buffer/ray_wrapper.py +++ b/trinity/buffer/ray_wrapper.py @@ -49,7 +49,7 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: self.logger.warning("Failed to create database, assuming it already exists.") self.session = sessionmaker(bind=self.engine) - self.batch_size = config.read_batch_size + self.batch_size = config.train_batch_size self.max_retry_times = config.max_retry_times self.max_retry_interval = config.max_retry_interval self.ref_count = 0 @@ -97,7 +97,7 @@ def read( raise NotImplementedError(f"Unsupported strategy {strategy} by SQLStorage") exp_list = [] - batch_size = batch_size or self.batch_size + batch_size = batch_size or self.batch_size # type: ignore while len(exp_list) < batch_size: if len(exp_list): self.logger.info("waiting for experiences...") diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index 44785205bd..4de2e03fc6 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -120,7 +120,7 @@ def __init__(self, meta: StorageConfig, config: BufferConfig): self.messages_key = meta.format.messages_key self.prompt_key = meta.format.prompt_key self.response_key = meta.format.response_key - self.read_batch_size = config.batch_size + self.read_batch_size = config.train_batch_size self.dataset = _HFBatchReader( load_dataset(meta.path, name=subset_name, split=self.split), name=meta.name, @@ -202,7 +202,7 @@ def __init__(self, meta: StorageConfig, config: BufferConfig): self.prompt_key = meta.format.prompt_key self.chosen_key = meta.format.chosen_key self.rejected_key = meta.format.rejected_key - self.read_batch_size = config.batch_size + self.read_batch_size = config.train_batch_size self.dataset = _HFBatchReader( load_dataset(meta.path, name=subset_name, split=self.split), name=meta.name, diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py index adecc1b170..3745730f22 100644 --- a/trinity/buffer/reader/queue_reader.py +++ b/trinity/buffer/reader/queue_reader.py @@ -19,7 +19,7 @@ class QueueReader(BufferReader): def __init__(self, storage_config: StorageConfig, config: BufferConfig): assert storage_config.storage_type == StorageType.QUEUE self.timeout = storage_config.max_read_timeout - self.read_batch_size = config.read_batch_size + self.read_batch_size = config.train_batch_size self.queue = QueueWrapper.get_wrapper(storage_config, config) def read( diff --git a/trinity/common/config.py b/trinity/common/config.py index 3b735bf4f3..982d6049b3 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -308,6 +308,7 @@ class BufferConfig: """Config for buffer.""" batch_size: int = 1 + train_batch_size: int = 0 # default to `batch_size` * `algorithm.n` total_epochs: int = 1 total_steps: Optional[int] = None @@ -323,7 +324,6 @@ class BufferConfig: max_retry_interval: int = 1 # ! 
DO NOT SET FOLLOWING FIELDS
-    read_batch_size: int = 1  # automatically set
     tokenizer_path: Optional[str] = None  # automatically set
     pad_token_id: Optional[int] = None  # automatically set
     cache_dir: Optional[str] = None  # automatically set
@@ -651,8 +651,19 @@ def _check_buffer(self) -> None:  # noqa: C901
                     exp_pipeline_output_buffers.name
                 ]
 
-        # set read_batch_size / pad_token_id / tokenizer_path
-        self.buffer.read_batch_size = self.buffer.batch_size * self.algorithm.repeat_times
+        # check train_batch_size
+        if not self.buffer.train_batch_size:
+            if self.mode == "train" or self.algorithm.algorithm_type in ["sft", "dpo"]:
+                raise ValueError(
+                    "`buffer.train_batch_size` is required when `mode` is 'train' or `algorithm.algorithm_type` is "
+                    "'sft' or 'dpo'"
+                )
+            logger.info(
+                "`buffer.train_batch_size` is set to `buffer.batch_size` * `algorithm.repeat_times`"
+            )
+            self.buffer.train_batch_size = self.buffer.batch_size * self.algorithm.repeat_times
+
+        # set pad_token_id / tokenizer_path
         if self.buffer.pad_token_id is None:
             from transformers import AutoTokenizer
 
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index d03bbf3b74..e203378987 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -13,7 +13,7 @@
 
 @dataclass
 class Data:
-    train_batch_size: int = 1024
+    train_batch_size: int = 1024  # kept for RayPPOTrainer._validate_config
 
 
 @dataclass
@@ -72,9 +72,7 @@ class Actor:
     ppo_micro_batch_size: Optional[int] = None
     ppo_micro_batch_size_per_gpu: int = 1
     use_dynamic_bsz: bool = False
-    ppo_max_token_len_per_gpu: int = (
-        16384  # n * ${data.max_prompt_length} + ${data.max_response_length}
-    )
+    ppo_max_token_len_per_gpu: int = 16384
     grad_clip: float = 1.0
     ppo_epochs: int = 1
     shuffle: bool = False
@@ -299,9 +297,9 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
         self.trainer.n_gpus_per_node = config.cluster.gpu_per_node
 
         world_size = self.trainer.nnodes * self.trainer.n_gpus_per_node
-        if config.buffer.batch_size % world_size != 0:
+        if config.buffer.train_batch_size % world_size != 0:
             raise ValueError(
-                f"batch_size ({config.buffer.batch_size}) must be divisible by ({world_size})"
+                f"train_batch_size ({config.buffer.train_batch_size}) must be divisible by the trainer world size ({world_size})"
             )
 
         self.trainer.sync_freq = config.synchronizer.sync_interval
@@ -317,9 +315,6 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
             self.trainer.resume_mode = "auto"
 
         self.buffer = config.buffer
-        # TODO: use dynamic read_batch_size to support multi-round scenarios
-        # Get the experiences of one explore step
-        self.data.train_batch_size = config.buffer.batch_size
         self.synchronizer = config.synchronizer
         self.actor_rollout_ref.synchronizer = config.synchronizer
 
@@ -330,16 +325,13 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
         self.actor_rollout_ref.model.custom_chat_template = config.model.custom_chat_template
         self.critic.model.path = config.model.critic_model_path
         self.critic.model.tokenizer_path = config.model.critic_model_path
-        self.actor_rollout_ref.actor.ppo_mini_batch_size = (
-            config.buffer.batch_size
-        )  # TODO: may allow user to change
+        self.actor_rollout_ref.actor.ppo_mini_batch_size = config.buffer.train_batch_size
         self.actor_rollout_ref.rollout.temperature = (
             config.buffer.explorer_input.taskset.rollout_args.temperature
         )
         self.actor_rollout_ref.rollout.n = config.algorithm.repeat_times
-        self.critic.ppo_mini_batch_size = config.buffer.batch_size
+        self.critic.ppo_mini_batch_size = 
config.buffer.train_batch_size self.critic.rollout_n = self.actor_rollout_ref.rollout.n - self.critic.synchronizer = config.synchronizer if config.trainer.actor_grad_clip is not None: self.actor_rollout_ref.actor.grad_clip = config.trainer.actor_grad_clip diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py index 379a3f008a..f241482dca 100644 --- a/trinity/manager/config_manager.py +++ b/trinity/manager/config_manager.py @@ -131,9 +131,9 @@ def beginner_mode(self): st.header("Important Configs") self.get_configs("node_num", "gpu_per_node", "engine_num", "tensor_parallel_size") - self.get_configs("total_epochs", "train_batch_size", "ppo_epochs", "repeat_times") + self.get_configs("total_epochs", "explore_batch_size", "train_batch_size", "repeat_times") - self.get_configs("storage_type", "max_response_tokens", "max_model_len") + self.get_configs("storage_type", "max_response_tokens", "max_model_len", "ppo_epochs") self.get_configs("sync_interval", "eval_interval", "save_interval") @@ -169,7 +169,7 @@ def _expert_model_part(self): self.get_configs("max_response_tokens", "max_model_len") def _expert_buffer_part(self): - self.get_configs("total_epochs", "train_batch_size") + self.get_configs("total_epochs", "explore_batch_size", "train_batch_size") self.get_configs( "default_workflow_type", "default_eval_workflow_type", "default_reward_fn_type" @@ -510,7 +510,8 @@ def _gen_buffer_config(self): ) # TODO buffer_config = { - "batch_size": st.session_state["train_batch_size"], + "batch_size": st.session_state["explore_batch_size"], + "train_batch_size": st.session_state["train_batch_size"], "total_epochs": st.session_state["total_epochs"], "explorer_input": {}, "trainer_input": { diff --git a/trinity/manager/config_registry/buffer_config_manager.py b/trinity/manager/config_registry/buffer_config_manager.py index e7ae3ca7cb..37cf11ea75 100644 --- a/trinity/manager/config_registry/buffer_config_manager.py +++ b/trinity/manager/config_registry/buffer_config_manager.py @@ -12,6 +12,16 @@ def set_total_epochs(**kwargs): st.number_input("Total Epochs", min_value=1, **kwargs) +@CONFIG_GENERATORS.register_config(default_value=96) +def set_explore_batch_size(**kwargs): + st.number_input( + "Task Batch Size", + min_value=1, + help="Number of tasks to explore in one explore step", + **kwargs, + ) + + def _str_for_train_batch_size(): trainer_gpu_num_str = ( "`gpu_per_node * node_num - engine_num * tensor_parallel_size`" @@ -19,6 +29,7 @@ def _str_for_train_batch_size(): else "`gpu_per_node * node_num`" ) return ( + f"Usually set to `task_batch_size` * `repeat_times`." f"Please ensure that `train_batch_size` can be divided by " f"{trainer_gpu_num_str} = {st.session_state['trainer_gpu_num']}." ) @@ -222,6 +233,21 @@ def set_default_workflow_type(**kwargs): ) +@CONFIG_GENERATORS.register_config(default_value="math_workflow") +def set_default_eval_workflow_type(**kwargs): + st.selectbox( + "Default Eval Workflow Type :orange-badge[(Needs review)]", + WORKFLOWS.modules.keys(), + help=r"""`simple_workflow`: call 'model.chat()' to get responses. + +`math_workflow`: call 'model.chat()' with a pre-defined system prompt to get responses. + +Other workflows: conduct multi-turn task for the given dataset. 
+""", + **kwargs, + ) + + @CONFIG_GENERATORS.register_config(default_value="math_reward") def set_default_reward_fn_type(**kwargs): st.selectbox( @@ -241,7 +267,7 @@ def set_default_reward_fn_type(**kwargs): def set_system_prompt(**kwargs): st.text_area( "System Prompt", - placeholder="System prompt is used to guide the model behavior.", + placeholder="""You are a helpful assistant that solves MATH problems....""", **kwargs, ) diff --git a/trinity/trainer/verl/fsdp_workers.py b/trinity/trainer/verl/fsdp_workers.py index eac7f94443..3d99bec597 100644 --- a/trinity/trainer/verl/fsdp_workers.py +++ b/trinity/trainer/verl/fsdp_workers.py @@ -150,7 +150,7 @@ def __init__(self, config: DictConfig, role: str): # normalize config if self._is_actor: - self.config.actor.ppo_mini_batch_size *= self.config.rollout.n + # note: no need to conduct `ppo_mini_batch_size *= rollout_n` anymore self.config.actor.ppo_mini_batch_size //= ( self.device_mesh.size() // self.ulysses_sequence_parallel_size ) @@ -904,7 +904,7 @@ def __init__(self, config): self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload # normalize config - self.config.ppo_mini_batch_size *= self.config.rollout_n + # note: no need to conduct `ppo_mini_batch_size *= rollout_n` anymore self.config.ppo_mini_batch_size //= ( torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size )