Add a configuration for max tokens per microbatch #640
Conversation
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
```python
@pytest.mark.parametrize("worker_type", ["policy", "critic"])
def test_max_tokens_per_microbatch(ray_init_fixture, cfg, worker_type):
```
TODO: test a multi-worker case where you have padding microbatches
```diff
 from skyrl_train.distributed.utils import init_custom_process_group
 from skyrl_train.utils.ppo_utils import PolicyLossRegistry, ppo_critic_loss, compute_approx_kl
-from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics
+from skyrl_train.workers.worker_utils import BatchIterator, TokenBasedBatchIterator, reduce_metrics
```
Suggested change:
```diff
-from skyrl_train.workers.worker_utils import BatchIterator, TokenBasedBatchIterator, reduce_metrics
+from skyrl_train.workers.worker_utils import BaseBatchIterator, SampleBasedBatchIterator, TokenBasedBatchIterator, reduce_metrics
```
```python
if self.cfg.trainer.max_tokens_per_microbatch is not None:
    dataloader = TokenBasedBatchIterator(
        train_data, max_tokens_per_microbatch=self.cfg.trainer.max_tokens_per_microbatch
    )
    accumulation_steps = dataloader.num_microbatches
else:
    dataloader = BatchIterator(
        train_data, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
    )
```
```python
micro_batches_per_mini_batch = (
    self.policy_mini_batch_size_per_gpu // self.cfg.trainer.micro_train_batch_size_per_gpu
)
# The number of steps (over micro batches) to accumulate gradients before taking an optimizer step.
accumulation_steps = micro_batches_per_mini_batch
```
QQ: Should the dynamic-batch-size packing happen at the batch level or the mini-batch level?
The most efficient option would be the batch level, but then we won't have a good way to recover mini batches afterwards. It seems like we should just do this at the mini-batch level first. (You'd need to introduce an explicit loop over mini batches.)
This also allows us to naturally define the accumulation steps
> Should the dynamic-batch-size packing happen at the batch level or the mini-batch level?

Yeah, it should also be done at the global batch level, but I'm leaving that for a follow-up optimization; see the end of the PR description. Global-batch-level balancing is useful for assigning each worker a similar number of tokens.
This PR's main goal is the max tokens per microbatch parameter. The balancing of microbatches is a convenient add-on that gives a slight optimization.
Here's why balancing at the microbatch level is useful:
- Worker 0's mini-batch contains 4 sequences of lengths [10, 10, 5, 5].
- Worker 0 has a max microbatch token budget of 15.
- If you just do a single front-to-back pass, you'll end up with 3 microbatches: [10], [10, 5], [5].
- The better way to chunk this is [10, 5], [10, 5], which is what the balanced microbatching method does here.
- Basically, we want to minimize the number of microbatches while satisfying max_tokens_per_microbatch (see the sketch below).
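For concreteness, here is a minimal sketch of this kind of balanced packing using first-fit-decreasing binpacking. This is a hypothetical illustration of the idea, not the `balanced_binpacking` utility from the PR:

```python
from typing import List


def balanced_binpack(seq_lens: List[int], max_tokens: int) -> List[List[int]]:
    """First-fit-decreasing binpacking of sequence indices into microbatches.

    Each microbatch's total token count stays <= max_tokens (except for a single
    sequence that is already longer than the budget, which gets its own bin),
    and placing sequences longest-first tends to minimize the number of microbatches.
    """
    # Visit sequences from longest to shortest.
    order = sorted(range(len(seq_lens)), key=lambda i: seq_lens[i], reverse=True)
    bins: List[List[int]] = []   # microbatches, as lists of sequence indices
    bin_tokens: List[int] = []   # running token count per microbatch
    for i in order:
        # Put the sequence into the first microbatch that still has room.
        for b, used in enumerate(bin_tokens):
            if used + seq_lens[i] <= max_tokens:
                bins[b].append(i)
                bin_tokens[b] += seq_lens[i]
                break
        else:
            # No existing microbatch fits; open a new one.
            bins.append([i])
            bin_tokens.append(seq_lens[i])
    return bins


# The example above: [10, 10, 5, 5] with a 15-token budget packs into two
# microbatches of 15 tokens each instead of three.
print(balanced_binpack([10, 10, 5, 5], max_tokens=15))  # [[0, 2], [1, 3]]
```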
> This also allows us to naturally define the accumulation steps
Could you elaborate on this a bit more?
Oh sorry, here's what I meant:
We have "mini batches" per worker whose size is set by policy_mini_batch_size. This is separate from the per-worker shard you're referring to.
train_batch_size / policy_mini_batch_size dictates the number of optimization steps taken over the current training batch.
So for example, with train_batch_size=16, num_workers=8, policy_mini_batch_size=8, each worker will receive a shard of 2 samples, and this will be split further into train_batch_size / policy_mini_batch_size = 2 mini batches. Each mini batch corresponds to one optimization step.
Each per-worker mini batch is then further split into micro batches.
My point was that the dynamic-batch-size binpacking should happen on the worker-level mini batches, not on the worker's overall batch.
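A quick sketch of that arithmetic with the example numbers above, just to make the hierarchy explicit (plain Python with illustrative values, not the actual config objects):

```python
# Illustrative values from the example above (not actual config objects).
train_batch_size = 16
num_workers = 8
policy_mini_batch_size = 8

# Per-worker shard of the global training batch.
samples_per_worker = train_batch_size // num_workers              # 2 samples

# Number of optimization steps taken over the training batch.
num_mini_batches = train_batch_size // policy_mini_batch_size     # 2 mini batches

# Size of each worker-level mini batch; each mini batch is then further split
# into micro batches (by sample count today, or by a max_tokens_per_microbatch
# token budget with this PR).
mini_batch_size_per_worker = samples_per_worker // num_mini_batches  # 1 sample

print(samples_per_worker, num_mini_batches, mini_batch_size_per_worker)  # 2 2 1
```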
… two levels (minibatch -> microbatch) (#817)

## Summary

This PR refactors the training loop structure in both `PolicyWorkerBase` and `CriticWorkerBase` to use a consistent two-level batching strategy (minibatch → microbatch) with a `microbatch_weight` used for gradient accumulation.

Refactored `ppo_train` to use a two-level loop: iterate over minibatches, then subdivide each minibatch into microbatches.

* Changed the `forward_backward` signature from `accumulation_steps: int` to `microbatch_weight: float`
* Loss is now scaled by `microbatch_weight` (i.e., `1.0 / num_microbatches`) instead of being divided by `accumulation_steps`
* The optimizer step is now called once per minibatch (after all microbatches are processed), rather than conditionally based on step count

## Motivation

The motivation for this PR was the batch balancing I tried adding in #640.

The problem was that minibatch boundaries are not explicitly defined right now. We have two configurations, `policy_mini_batch_size` and `micro_train_batch_size_per_gpu`, and the minibatch is implicitly constructed by doing gradient accumulation over `(policy_mini_batch_size // micro_train_batch_size_per_gpu)` identical microbatches. #640 breaks the "same sized microbatches" assumption, which is why we need to partition explicitly at the minibatch level first, before chunking into possibly uneven microbatches. That way, it's simpler to tell when to stop accumulating the gradient.

The introduction of a more general `microbatch_weight` is also motivated by the dynamic microbatch sizes introduced by #640. Each microbatch should contribute `N_i / sum(N_j) * loss_i` to the accumulated gradient. In the default case, this is just `1 / accumulation_steps`.

## Testing

* ✅ [FSDP metrics are equivalent to master](#817 (comment))
* ✅ [Megatron codepath still works](#817 (comment))

---------

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
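To make the two-level structure concrete, here is a rough sketch of the loop shape and the `microbatch_weight` scaling described above. This is an illustration of the idea, not the actual `ppo_train` implementation; the helpers `chunk_into_microbatches` and `compute_loss` are placeholders:

```python
def ppo_train_sketch(minibatches, model, optimizer, chunk_into_microbatches, compute_loss):
    """Two-level loop: one optimizer step per minibatch, with gradient
    accumulation over the (possibly uneven) microbatches inside it."""
    for minibatch in minibatches:
        optimizer.zero_grad()
        microbatches = chunk_into_microbatches(minibatch)  # may have uneven sizes
        total = sum(len(mb) for mb in microbatches)
        for microbatch in microbatches:
            # Each microbatch contributes N_i / sum(N_j) of the minibatch loss.
            # With equal-sized microbatches this reduces to 1 / num_microbatches.
            microbatch_weight = len(microbatch) / total
            loss = compute_loss(model, microbatch) * microbatch_weight
            loss.backward()
        # Optimizer step once per minibatch, after all microbatches are accumulated.
        optimizer.step()
```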
Closing this PR since it's a bit outdated. Will redo this now that #817 has landed.
Summary
Introduces token-based binpacking for microbatches, enforcing a `max_tokens_per_microbatch` limit on the total number of real tokens in a microbatch. Replaces the even, sample-based chunking `BatchIterator` with a `BalancedBatchIterator` that respects the microbatch token limits.

Problem

Previously, microbatches were created by chunking sequences evenly so that each microbatch held a fixed number of samples (`micro_train_batch_size_per_gpu`), which could lead to:
- microbatches with widely varying total token counts
- no bound on the number of tokens packed into a single microbatch

Solution

Introduce `balanced_binpacking` and `BalancedBatchIterator` utilities that:
- pack sequences into microbatches so that each microbatch's total real-token count stays under `max_tokens_per_microbatch`
- balance token counts across microbatches while minimizing the number of microbatches

API changes

- `BalancedBatchIterator` follows the `BatchIterator` interface, which is just an iterator over `Experience`s.
- `Worker.forward` is updated to also use the `BatchIterator` interface, rather than calling `TrainingBatch.chunk` directly.
- `BatchIterator` is replaced with `BalancedBatchIterator` when `max_tokens_per_microbatch` is set, for `forward`, `Critic.ppo_train`, and `Policy.ppo_train`.

Example
Example ppo_train call from the unit test:
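The actual test snippet isn't reproduced here; as a hedged stand-in, this is roughly what exercising the new config looks like, following the diff shown earlier in the conversation (the 15-token budget and the `train_data` batch are illustrative, and the iterator is named `TokenBasedBatchIterator` as in that diff):

```python
from skyrl_train.workers.worker_utils import TokenBasedBatchIterator

# train_data: the worker's shard of training data (assumed to already exist in scope).
dataloader = TokenBasedBatchIterator(train_data, max_tokens_per_microbatch=15)

# With uneven microbatches, gradient accumulation spans however many
# microbatches the iterator produced for this batch.
accumulation_steps = dataloader.num_microbatches

for microbatch in dataloader:
    ...  # forward/backward on one token-bounded microbatch
```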
Follow-up
This only does balanced microbatching at the individual worker level. We can further balance the minibatches sent to each worker at a global level to ensure every DP worker has roughly the same amount of work.
There are also different algorithms that we could plug into the microbatch chunking step.