diff --git a/trinity/common/config.py b/trinity/common/config.py
index 2efec5b24a..ea9f18e8b9 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -646,6 +646,7 @@ class TrainerConfig:
     # if None, automatically set to ceil(2 * model.max_model_len / ulysses_sequence_parallel_size)
     max_token_len_per_gpu: Optional[int] = None
     ulysses_sequence_parallel_size: int = 1  # sp size
+    fix_actor_microbatch_loss_scale: bool = False  # EXPERIMENTAL
     # TODO: extract more train-related params from underlying trainer engine
     save_strategy: SaveStrategy = SaveStrategy.UNRESTRICTED
 
diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py
index ff017340df..49a241393a 100644
--- a/trinity/common/verl_config.py
+++ b/trinity/common/verl_config.py
@@ -136,6 +136,7 @@ class Actor:
     ppo_micro_batch_size_per_gpu: int = 1
     use_dynamic_bsz: Optional[bool] = None
     ppo_max_token_len_per_gpu: Optional[int] = None
+    fix_actor_microbatch_loss_scale: Optional[bool] = None  # EXPERIMENTAL
     grad_clip: Optional[float] = None
     ppo_epochs: int = 1
     shuffle: bool = False
@@ -427,6 +428,10 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
             self.actor_rollout_ref.actor.ppo_max_token_len_per_gpu = (
                 config.trainer.max_token_len_per_gpu
             )
+        if self.actor_rollout_ref.actor.fix_actor_microbatch_loss_scale is None:
+            self.actor_rollout_ref.actor.fix_actor_microbatch_loss_scale = (
+                config.trainer.fix_actor_microbatch_loss_scale
+            )
         if self.actor_rollout_ref.actor.ulysses_sequence_parallel_size is None:
             self.actor_rollout_ref.actor.ulysses_sequence_parallel_size = (
                 config.trainer.ulysses_sequence_parallel_size
diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py
index 3ca2d65cc7..5b9a5256c7 100644
--- a/trinity/trainer/verl/dp_actor.py
+++ b/trinity/trainer/verl/dp_actor.py
@@ -88,6 +88,16 @@ def update_policy(self, data: DataProto):  # noqa: C901
 
         mini_batches = data.split(self.config.ppo_mini_batch_size)
 
+        # EXPERIMENTAL: apply loss scale fix
+        loss_agg_mode = (
+            self.policy_loss_fn.loss_agg_mode
+            if hasattr(self.policy_loss_fn, "loss_agg_mode")
+            else "token-mean"
+        )
+        do_fix_actor_microbatch_loss_scale = self.config.fix_actor_microbatch_loss_scale and (
+            loss_agg_mode == "token-mean"
+        )
+
         metrics = {}
         for _ in range(self.config.ppo_epochs):
             for batch_idx, mini_batch in enumerate(mini_batches):
@@ -104,6 +114,12 @@ def update_policy(self, data: DataProto):  # noqa: C901
                     )
                     micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)
 
+                if do_fix_actor_microbatch_loss_scale:
+                    # calculate the total number of response tokens in the minibatch
+                    mini_batch_token_num = torch.sum(
+                        mini_batch.batch["response_mask"].to(get_device_id())
+                    ).item()
+
                 self.actor_optimizer.zero_grad()
 
                 for micro_batch in micro_batches:
@@ -156,13 +172,19 @@ def update_policy(self, data: DataProto):  # noqa: C901
                         )
                         policy_loss = policy_loss + kl_loss
 
-                    if self.config.use_dynamic_bsz:
-                        # relative to the dynamic bsz
-                        loss = policy_loss * (
-                            response_mask.shape[0] / self.config.ppo_mini_batch_size
-                        )
+                    # set loss scale for the microbatch
+                    if not do_fix_actor_microbatch_loss_scale:
+                        # original implementation of microbatch loss scale
+                        if self.config.use_dynamic_bsz:
+                            loss_scale = response_mask.shape[0] / self.config.ppo_mini_batch_size
+                        else:
+                            loss_scale = 1.0 / self.gradient_accumulation
                     else:
-                        loss = policy_loss / self.gradient_accumulation
+                        # EXPERIMENTAL: fix for token-mean loss aggregation
+                        # scale microbatch loss according to the number of tokens (rather than sequences)
+                        loss_scale = torch.sum(response_mask).item() / (mini_batch_token_num + 1e-6)
+
+                    loss = policy_loss * loss_scale
 
                     loss.backward()
                     append_to_dict(metrics, micro_batch_metrics)
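
For context (not part of the patch): a minimal toy sketch of the arithmetic the new loss_scale implements. With loss_agg_mode == "token-mean", each micro-batch returns its own token-mean, so dividing by gradient_accumulation weights every micro-batch equally regardless of how many response tokens it holds; scaling by tokens_in_microbatch / mini_batch_token_num instead recovers the exact token-mean over the mini-batch. The loss values and token counts below are made up, and since backward() is linear the same weighting carries over to the accumulated gradients.

# illustrative sketch only; toy per-token losses for one mini-batch split
# into two micro-batches of unequal length (4 tokens vs. 2 tokens)
import torch

micro_losses = [torch.tensor([1.0, 1.0, 1.0, 1.0]), torch.tensor([5.0, 5.0])]
mini_batch_token_num = sum(t.numel() for t in micro_losses)  # 6 tokens

# ground truth: token-mean over the whole mini-batch = (4*1 + 2*5) / 6
true_loss = torch.cat(micro_losses).mean()

# original scaling: each micro-batch token-mean divided by gradient_accumulation
gradient_accumulation = len(micro_losses)
old_scaling = sum(t.mean() / gradient_accumulation for t in micro_losses)  # (1 + 5) / 2 = 3.0

# fixed scaling: weight each micro-batch by its share of mini-batch tokens
fixed_scaling = sum(t.mean() * (t.numel() / mini_batch_token_num) for t in micro_losses)

print(true_loss.item(), old_scaling.item(), fixed_scaling.item())  # 2.333..., 3.0, 2.333...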