[skyrl-train] Refactor training loop structure to explicitly batch at two levels (minibatch -> microbatch) #817
Conversation
Code Review
This pull request refactors the training loop to use a two-level batching strategy (minibatch → microbatch), which is a good architectural improvement for supporting uneven microbatches in the future. The core logic changes are sound, but I've identified a few critical issues. The most significant one is that in both policy and critic training loops, metrics are not being aggregated correctly across microbatches, leading to inaccurate logging. Only the status from the last microbatch of a minibatch is being recorded. Additionally, the optimizer step has been unintentionally removed from the critic's training_step method, which will affect tests relying on it. I've provided detailed comments and suggestions to address these points.
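For illustration, a minimal sketch of per-minibatch metric aggregation; the helper and the `forward_backward`/`all_metrics` names are assumptions, not the actual worker API:

```python
from collections import defaultdict
from typing import Dict, List


def aggregate_microbatch_statuses(statuses: List[Dict[str, float]]) -> Dict[str, float]:
    """Average metrics over every microbatch in a minibatch instead of
    keeping only the last microbatch's status."""
    sums: Dict[str, float] = defaultdict(float)
    counts: Dict[str, int] = defaultdict(int)
    for status in statuses:
        for key, value in status.items():
            sums[key] += value
            counts[key] += 1
    return {key: sums[key] / counts[key] for key in sums}


# Inside the minibatch loop (illustrative only):
# microbatch_statuses = [self.forward_backward(mb, microbatch_weight=w) for mb in microbatch_iterator]
# all_metrics.append(aggregate_microbatch_statuses(microbatch_statuses))
```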
```python
microbatch_iterator = BatchIterator(
    minibatch, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
)
num_microbatches = len(microbatch_iterator)
microbatch_weight = 1.0 / num_microbatches

for microbatch in microbatch_iterator:
```
The next step is basically just to change this microbatch iterator from one that's doing sample-based chunking to one that's token-based.
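For reference, token-based chunking could look roughly like the sketch below; the per-sample token count accessor and the `max_tokens_per_microbatch` cap are assumptions, not the existing `BatchIterator` API:

```python
from typing import Dict, List


def chunk_by_tokens(samples: List[Dict], max_tokens_per_microbatch: int) -> List[List[Dict]]:
    """Greedily pack samples into microbatches capped by total token count,
    instead of by a fixed number of samples per microbatch."""
    microbatches: List[List[Dict]] = []
    current: List[Dict] = []
    current_tokens = 0
    for sample in samples:
        num_tokens = len(sample["input_ids"])  # assumed per-sample token count
        if current and current_tokens + num_tokens > max_tokens_per_microbatch:
            microbatches.append(current)
            current, current_tokens = [], 0
        current.append(sample)
        current_tokens += num_tokens
    if current:
        microbatches.append(current)
    return microbatches
```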
```diff
         return self

-    def __next__(self) -> Experience:
+    def __next__(self) -> TrainingInputBatch:
```
this change also affects the use of `BatchIterator` for the megatron backend (the `dataloader = BatchIterator(...)` call in the megatron worker), which implements `ppo_train` differently from FSDP/DeepSpeed.
could you make sure the conversion to `Experience` is also handled correctly for the megatron code path? Making sure one of these tests (e.g. `async def test_megatron_train(...)`) passes is probably a good way to check this.
Sounds good. I added a follow-up TODO to update the megatron worker's ppo loop as well, and made a minimal change for now to prevent this PR from getting too large.
erictang000 left a comment:
Left a comment about megatron + I think the gemini comments are worth a look - essentially we want to make sure that the metrics are aggregated in the same way before and after this PR. Maybe it would be nice to show this by running a test before and after this PR (maybe this one: `def test_ppo_train_basic_execution(ray_init_fixture, cfg, use_entropy_loss, use_kl_loss):`).
To sanity check that this PR didn't introduce any regressions for metrics, I printed the output status from this test.

Printed status on master:

Printed status with this PR:
Tested the Megatron code path as well.
/gemini review
Code Review
This PR refactors the training loops in PolicyWorkerBase and CriticWorkerBase to use a two-level batching structure (minibatch -> microbatch), which is a great improvement for supporting batch balancing and dynamic microbatch sizes. The changes are well-motivated and mostly well-executed.
My feedback focuses on a few key areas:
- An inconsistency in `megatron_worker.py` where the new training loop structure has not been applied.
- A potential bug in the calculation of `critic_update_steps`.
- Several opportunities for minor refactoring to improve code clarity and reduce duplication, such as extracting helper methods for status recording and memory snapshotting.
- Identifying potentially unused code marked with `TODO` comments that should be cleaned up.
Overall, this is a solid refactoring. Addressing these points will improve the consistency and maintainability of the codebase.
looking good! I think just one super minor gemini comment, and we need to fix the failing cpu test and we should be good to merge this
cpu test:

```
=========================== short test summary info ============================
FAILED tests/cpu/test_trainer.py::test_ppo_train_batch_calculations - TypeError: test_ppo_train_batch_calculations.<locals>.mock_policy_forward_backward() got an unexpected keyword argument 'microbatch_weight'
============ 1 failed, 353 passed, 83 warnings in 101.86s (0:01:41) ============
```
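The failure suggests the test's mock just needs to accept the new keyword; a sketch of the kind of change (the actual mock signature in the test may differ):

```python
# Let the test's mock tolerate the new microbatch_weight keyword argument.
def mock_policy_forward_backward(*args, microbatch_weight: float = 1.0, **kwargs):
    return {}
```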
```diff
  status_mean = reduce_metrics(all_metrics)
- status_mean["policy_update_steps"] = policy_update_steps / accumulation_steps
+ status_mean["policy_update_steps"] = num_minibatches * self.cfg.trainer.update_epochs_per_batch
```
good catch...
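As a quick sanity check of the new formula, with illustrative numbers (names loosely mirror the config keys; the values are made up):

```python
# Illustrative values only.
train_batch_size_per_gpu = 512
policy_mini_batch_size_per_gpu = 128
update_epochs_per_batch = 2

num_minibatches = train_batch_size_per_gpu // policy_mini_batch_size_per_gpu  # 4 minibatches
policy_update_steps = num_minibatches * update_epochs_per_batch               # 8 optimizer steps total
print(policy_update_steps)  # 8
```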
erictang000 left a comment:
🚀🚀🚀
Summary
This PR refactors the training loop structure in both `PolicyWorkerBase` and `CriticWorkerBase` to use a consistent two-level batching strategy (minibatch → microbatch) with a `microbatch_weight` used for gradient accumulation.

- Refactored `ppo_train` to use a two-level loop: iterate over minibatches, then subdivide each minibatch into microbatches.
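Roughly, the new structure can be pictured as the toy loop below. This is a simplified sketch, not the actual worker code: `forward_backward` and `optimizer_step` are assumed hooks standing in for the real methods.

```python
from typing import Callable, Dict, List, Sequence


def two_level_train_loop(
    train_batch: Sequence,
    mini_batch_size: int,
    micro_batch_size: int,
    update_epochs: int,
    forward_backward: Callable[..., Dict[str, float]],  # accumulates grads scaled by microbatch_weight
    optimizer_step: Callable[[], None],                 # applies and zeroes the accumulated gradient
) -> List[Dict[str, float]]:
    """Toy minibatch -> microbatch loop: exactly one optimizer step per minibatch."""
    statuses: List[Dict[str, float]] = []
    for _ in range(update_epochs):
        # Outer level: explicit minibatch boundaries decide when the optimizer steps.
        for start in range(0, len(train_batch), mini_batch_size):
            minibatch = train_batch[start : start + mini_batch_size]
            # Inner level: chunk the minibatch into microbatches for gradient accumulation.
            microbatches = [
                minibatch[i : i + micro_batch_size]
                for i in range(0, len(minibatch), micro_batch_size)
            ]
            # Equal-sized microbatches today; with uneven sizes this becomes N_i / sum(N_j).
            microbatch_weight = 1.0 / len(microbatches)
            for microbatch in microbatches:
                statuses.append(forward_backward(microbatch, microbatch_weight=microbatch_weight))
            optimizer_step()  # one optimizer step per minibatch
    return statuses
```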
Motivation
The motivation for this PR was batch balancing, which I tried adding in this PR: #640
The problem was that minibatch boundaries are not explicitly defined right now. We have 2 configurations, `policy_mini_batch_size` and `micro_train_batch_size_per_gpu`, and the minibatch is implicitly constructed by doing gradient accumulation over `(policy_mini_batch_size // micro_train_batch_size_per_gpu)` identical microbatches. #640 breaks the "same sized microbatches" assumption, which is why we need to partition explicitly at the minibatch level first, before chunking into possibly uneven microbatches. That way, it's simpler to tell when to stop accumulating the gradient.
microbatch_weightis also motivated by dynamic microbatch sizes introduced by #640. Each microbatch should contributeN_i / sum(N_j) * loss_ito the accumulated gradient. In the default case, this is just1/accumulation_steps.Testing