
Conversation

justinvyu (Contributor) commented Nov 7, 2025

Summary

Introduces token-based binpacking for microbatches, enforcing a max_tokens_per_microbatch limit on the total number of real tokens in a microbatch.

Replaces the even, sample-based chunking BatchIterator with a BalancedBatchIterator that respects the microbatch token limit.

Problem

Previously, microbatches were created by chunking sequences evenly so that each microbatch had a fixed batch size (micro_train_batch_size_per_gpu), which could lead to:

  • Microbatches exceeding GPU memory limits when sequences vary in length, forcing a conservatively small micro batch size
  • Inefficient GPU utilization due to uneven token distribution across microbatches

Solution

Introduce balanced_binpacking and BalancedBatchIterator utilities that:

  • Enforce max_tokens_per_microbatch — no microbatch exceeds the token limit
  • Roughly balance microbatch sizes so that we avoid straggler microbatches.
  • Ensure that every worker still receives the same number of microbatches, as required by DeepSpeed/FSDP (every DP worker must do the same number of forward passes).

API changes

  • Adds a BalancedBatchIterator that follows the BatchIterator interface, which is just an iterator over Experiences.
  • Unifies Worker.forward to also use the BatchIterator interface, rather than calling TrainingBatch.chunk directly.
  • Replaces BatchIterator with BalancedBatchIterator when max_tokens_per_microbatch is set, for forward, Critic.ppo_train, and Policy.ppo_train (see the sketch below).
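
For illustration, here is a minimal sketch of how a worker might select and consume the iterator. The import location and constructor arguments are assumptions based on the description above, not the exact signatures in this PR:

# Sketch only: names and arguments assumed from the PR description.
from skyrl_train.workers.worker_utils import BatchIterator  # BalancedBatchIterator location assumed

def iter_microbatches(minibatch, cfg):
    if cfg.trainer.max_tokens_per_microbatch is not None:
        # Token-based binpacking: each yielded microbatch stays under the token cap.
        iterator = BalancedBatchIterator(
            minibatch, max_tokens_per_microbatch=cfg.trainer.max_tokens_per_microbatch
        )
    else:
        # Original behavior: even, sample-based chunking.
        iterator = BatchIterator(
            minibatch, sample_batch_size=cfg.trainer.micro_train_batch_size_per_gpu
        )
    for experience in iterator:  # both iterators yield Experience objects
        yield experience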

Example

Example ppo_train call from the unit test:

minibatch_sequences=tensor([[88, 96,  3, 23, 44, 65, 55,  0, 66, 49],
        [25, 85, 10, 86, 31, 55, 31, 51,  4, 56],
        [28, 38, 76,  8,  0,  0,  0,  0,  0,  0],
        [21, 97, 46, 75, 73,  0,  0,  0,  0,  0]])
minibatch_attention_mask=tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]])
Policy Train epoch [1/1]:   0%|          | 0/2 [00:00<?, ?it/s]
(DeepSpeedPolicyWorkerBase pid=182679) micro_batches_per_mini_batch=2
(DeepSpeedPolicyWorkerBase pid=182679) experience.sequences=tensor([[88, 96,  3, 23, 44, 65, 55,  0, 66, 49],
(DeepSpeedPolicyWorkerBase pid=182679)         [28, 38, 76,  8,  0,  0,  0,  0,  0,  0]])
Policy Train epoch [1/1]:  50%|█████     | 1/2 [00:01<00:01,  1.28s/it, pg=-0.203, glen=4, policy_lr=1e-6, ent=3.08]
(DeepSpeedPolicyWorkerBase pid=182679) experience.sequences=tensor([[25, 85, 10, 86, 31, 55, 31, 51,  4, 56],
(DeepSpeedPolicyWorkerBase pid=182679)         [21, 97, 46, 75, 73,  0,  0,  0,  0,  0]])

Follow-up

This only does balanced microbatching at the individual worker level. We can further balance the minibatches sent to each worker at a global level to ensure every DP worker has roughly the same amount of work.

There are also different algorithms that we could plug into the microbatch chunking step.

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
@SumanthRH SumanthRH marked this pull request as ready for review November 11, 2025 21:24


@pytest.mark.parametrize("worker_type", ["policy", "critic"])
def test_max_tokens_per_microbatch(ray_init_fixture, cfg, worker_type):
Contributor Author

TODO: test a multi-worker case where you have padding microbatches

@SumanthRH SumanthRH self-requested a review November 13, 2025 01:22
@SumanthRH SumanthRH self-assigned this Nov 13, 2025
from skyrl_train.distributed.utils import init_custom_process_group
from skyrl_train.utils.ppo_utils import PolicyLossRegistry, ppo_critic_loss, compute_approx_kl
- from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics
+ from skyrl_train.workers.worker_utils import BatchIterator, TokenBasedBatchIterator, reduce_metrics
Member

Suggested change:
- from skyrl_train.workers.worker_utils import BatchIterator, TokenBasedBatchIterator, reduce_metrics
+ from skyrl_train.workers.worker_utils import BaseBatchIterator, SampleBasedBatchIterator, TokenBasedBatchIterator, reduce_metrics

@SumanthRH SumanthRH self-requested a review November 15, 2025 01:55
Comment on lines +657 to +671
if self.cfg.trainer.max_tokens_per_microbatch is not None:
    dataloader = TokenBasedBatchIterator(
        train_data, max_tokens_per_microbatch=self.cfg.trainer.max_tokens_per_microbatch
    )
    accumulation_steps = dataloader.num_microbatches
else:
    dataloader = BatchIterator(
        train_data, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
    )

    micro_batches_per_mini_batch = (
        self.policy_mini_batch_size_per_gpu // self.cfg.trainer.micro_train_batch_size_per_gpu
    )
    # The number of steps (over micro batches) to accumulate gradients before taking an optimizer step.
    accumulation_steps = micro_batches_per_mini_batch
Member

QQ: Should the dynamic batch size related packing happen at the batch level or the mini-batch level?

The most efficient one would be at the batch level, but then we won't have a good way to get mini batches afterwards. It seems like we should just do this at the mini batch level first. (you'd need to introduce an explicit loop over mini batches)

Member

This also allows us to naturally define the accumulation steps

justinvyu (Contributor Author) Nov 17, 2025

Should the dynamic batch size related packing happen at the batch level or the mini-batch level?

Yeah, it should also be done at the global batch level, but I'm leaving that for a follow-up optimization (see the end of the PR description). Global batch-level balancing is useful for assigning each worker a similar number of tokens.

This PR's main goal is the max tokens per microbatch parameter. The balancing of microbatches is kind of a convenient feature add-on that is a slight optimization.

Here's why the balancing at the microbatch level is useful:

  • Worker 0 mini-batch contains 4 sequences of length: [10, 10, 5, 5]
  • Worker 0 has a max microbatch token length of 15.
  • If you just do a single front-to-back pass, you'll end up with 3 microbatches: [10], [10, 5], [5]
  • However, the better way to chunk this would be [10, 5], [10, 5], which is what the balanced microbatching method does here.
  • Basically, we want to minimize the number of microbatches while satisfying max_tokens_per_microbatch (see the sketch below).
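
For illustration, a minimal sketch of this kind of balanced binpacking (a hypothetical helper, not the exact implementation in this PR): compute a lower bound on the number of microbatches from the token cap, then assign the longest sequences first to the currently lightest microbatch, adding a microbatch only if the cap cannot be satisfied.

import heapq

def balanced_binpack(seq_lens, max_tokens):
    """Greedy sketch: longest sequence first, into the lightest microbatch that still fits."""
    assert all(n <= max_tokens for n in seq_lens), "each sequence must fit in one microbatch"
    # Lower bound on the number of microbatches implied by the token cap.
    num_bins = max(1, -(-sum(seq_lens) // max_tokens))
    while True:
        heap = [(0, b) for b in range(num_bins)]  # (token_count, bin_index)
        bins = [[] for _ in range(num_bins)]
        ok = True
        for idx in sorted(range(len(seq_lens)), key=lambda i: -seq_lens[i]):
            tokens, b = heapq.heappop(heap)
            if tokens + seq_lens[idx] > max_tokens:
                ok = False
                break
            bins[b].append(idx)
            heapq.heappush(heap, (tokens + seq_lens[idx], b))
        if ok:
            return bins
        num_bins += 1  # cap violated with this many microbatches; retry with one more

# Example from the bullets above: [10, 10, 5, 5] with max_tokens=15 -> [[0, 2], [1, 3]], i.e. [10, 5] and [10, 5].
print(balanced_binpack([10, 10, 5, 5], max_tokens=15))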

Contributor Author

This also allows us to naturally define the accumulation steps

Could you elaborate on this a bit more?

Member

Oh sorry, here's what I meant:

We have "mini batches" per worker that are set based on policy_mini_batch_size. This is separate from the per-worker shard you're referring to.

train_batch_size / policy_mini_batch_size dictates the number of optimization steps with the current training batch.

So for example, with train_batch_size=16, num_workers=8, policy_mini_batch_size=8, each worker will receive a shard of 2 samples, and this will be split further into train_batch_size / policy_mini_batch_size = 2 mini batches. Each mini batch will correspond to one optimization step.

For each mini batch per worker, we further split it into micro batches.

My point was that the dynamic batch size related binpacking should happen at the worker-level mini batches, not the worker-level overall batch.
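
As a small worked example of the hierarchy described above (numbers taken from the comment; micro_train_batch_size_per_gpu is an assumed value for illustration):

train_batch_size = 16
num_workers = 8                       # DP workers
policy_mini_batch_size = 8            # global mini batch size
micro_train_batch_size_per_gpu = 1    # assumed for illustration

samples_per_worker = train_batch_size // num_workers               # 2
num_mini_batches = train_batch_size // policy_mini_batch_size      # 2 -> 2 optimizer steps per training batch
mini_batch_size_per_worker = samples_per_worker // num_mini_batches        # 1 sample per worker per mini batch
micro_batches_per_mini_batch = mini_batch_size_per_worker // micro_train_batch_size_per_gpu  # 1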

erictang000 pushed a commit that referenced this pull request Jan 6, 2026
… two levels (minibatch -> microbatch) (#817)

## Summary

This PR refactors the training loop structure in both `PolicyWorkerBase`
and `CriticWorkerBase` to use a consistent two-level batching strategy
(minibatch → microbatch) with a `microbatch_weight` used for gradient
accumulation.

* Refactored ppo_train to use a two-level loop: iterate over minibatches,
then subdivide each minibatch into microbatches
* Changed forward_backward signature from accumulation_steps: int to
microbatch_weight: float
* Loss is now scaled by microbatch_weight (i.e., 1.0 / num_microbatches)
instead of dividing by accumulation_steps
* Optimizer step is now called once per minibatch (after all
microbatches are processed), rather than conditionally based on step
count

## Motivation

The motivation for this PR was for batch balancing which I tried adding
in this PR: #640

The problem was that minibatch boundaries are not explicitly defined
right now. We have 2 configurations `policy_mini_batch_size` and
`micro_train_batch_size_per_gpu`, and the minibatch is implicitly
constructed by doing gradient accumulation over `(policy_mini_batch_size
// micro_train_batch_size_per_gpu)` identical microbatches.
#640 breaks the "same sized
microbatches" assumption, which is why we need to partition explicitly
at the minibatch level first, before chunking into possibly uneven
microbatches. That way, it's simpler to tell when to stop accumulating
the gradient.

The introduction of a more general `microbatch_weight` is also
motivated by dynamic microbatch sizes introduced by
#640. Each microbatch should
contribute `N_i / sum(N_j) * loss_i` to the accumulated gradient. In the
default case, this is just `1/accumulation_steps`.
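
A minimal sketch of this weighting (illustrative names only, not the exact `forward_backward` in this repo; `microbatch_size` stands in for whatever N_i measures, e.g. samples or tokens):

```python
def train_one_minibatch(model, optimizer, microbatches, compute_loss, microbatch_size):
    sizes = [microbatch_size(mb) for mb in microbatches]
    total = sum(sizes)
    optimizer.zero_grad()
    for mb, n in zip(microbatches, sizes):
        microbatch_weight = n / total          # N_i / sum(N_j); equals 1/num_microbatches when sizes are equal
        loss = compute_loss(model, mb)         # caller-provided loss function
        (loss * microbatch_weight).backward()  # accumulate weighted gradients
    optimizer.step()                           # one optimizer step per minibatch
```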

## Testing

* ✅ [FSDP metrics are equivalent to
master](#817 (comment))
* ✅ [Megatron codepath still
works](#817 (comment))

---------

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
justinvyu (Contributor Author)

Closing this PR since it's a bit outdated. Will redo this now that #817 has landed.

@justinvyu justinvyu closed this Jan 6, 2026