
@justinvyu (Contributor) commented Dec 31, 2025

Summary

This PR refactors the training loop structure in both PolicyWorkerBase and CriticWorkerBase to use a consistent two-level batching strategy (minibatch → microbatch) with a microbatch_weight used for gradient accumulation (a minimal sketch of the resulting loop shape follows the list below).

Refactored ppo_train to use a two-level loop: iterate over minibatches, then subdivide each minibatch into microbatches

  • Changed forward_backward signature from accumulation_steps: int to microbatch_weight: float
  • Loss is now scaled by microbatch_weight (i.e., 1.0 / num_microbatches) instead of dividing by accumulation_steps
  • Optimizer step is now called once per minibatch (after all microbatches are processed), rather than conditionally based on step count
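To make the new structure concrete, here is a minimal, self-contained sketch of the two-level loop with microbatch_weight scaling. It is not the actual SkyRL code: the model, data, and slicing are toy stand-ins, and only the spirit of the config names policy_mini_batch_size and micro_train_batch_size_per_gpu is borrowed from this PR.

```python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
data = torch.randn(16, 4)
target = torch.randn(16, 1)

policy_mini_batch_size = 8
micro_train_batch_size_per_gpu = 4

for start in range(0, len(data), policy_mini_batch_size):            # minibatch level
    minibatch = data[start : start + policy_mini_batch_size]
    mini_target = target[start : start + policy_mini_batch_size]

    # Chunk the minibatch into (possibly uneven) microbatches.
    offsets = list(range(0, len(minibatch), micro_train_batch_size_per_gpu))
    microbatch_weight = 1.0 / len(offsets)

    optimizer.zero_grad()
    for off in offsets:                                               # microbatch level
        mb = minibatch[off : off + micro_train_batch_size_per_gpu]
        mt = mini_target[off : off + micro_train_batch_size_per_gpu]
        loss = torch.nn.functional.mse_loss(model(mb), mt)
        (loss * microbatch_weight).backward()                         # scale before backward
    optimizer.step()                                                  # one optimizer step per minibatch
```

Scaling each microbatch loss by microbatch_weight before backward() makes the accumulated gradient equal to the gradient of the minibatch-mean loss (for equal-sized microbatches), which is why the optimizer step can move to the end of the minibatch loop.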

Motivation

The motivation for this PR is batch balancing, which I tried adding in #640.

The problem is that minibatch boundaries are not explicitly defined right now. We have two configurations, policy_mini_batch_size and micro_train_batch_size_per_gpu, and the minibatch is implicitly constructed by doing gradient accumulation over (policy_mini_batch_size // micro_train_batch_size_per_gpu) identically sized microbatches. #640 breaks the "same-sized microbatches" assumption, which is why we need to partition explicitly at the minibatch level first, before chunking into possibly uneven microbatches. That way, it's simpler to tell when to stop accumulating the gradient.

The introduction of a more general microbatch_weight is also motivated by the dynamic microbatch sizes introduced by #640. Each microbatch should contribute N_i / sum(N_j) * loss_i to the accumulated gradient; in the default case of equal-sized microbatches, this is just 1 / accumulation_steps.
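As a quick illustration (the sizes below are made up), the generalized weights reduce to 1 / accumulation_steps when all microbatches are the same size:

```python
# Each microbatch i of size n_i contributes n_i / sum(n_j) of the minibatch loss.
microbatch_sizes = [4, 4, 3]                        # uneven chunking, e.g. token-balanced
weights = [n / sum(microbatch_sizes) for n in microbatch_sizes]
print(weights)                                      # [0.3636..., 0.3636..., 0.2727...]

equal_sizes = [4, 4, 4]                             # the current default case
print([n / sum(equal_sizes) for n in equal_sizes])  # [0.333..., 0.333..., 0.333...] == 1 / accumulation_steps
```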

Testing

@gemini-code-assist bot left a comment

Code Review

This pull request refactors the training loop to use a two-level batching strategy (minibatch → microbatch), which is a good architectural improvement for supporting uneven microbatches in the future. The core logic changes are sound, but I've identified a few critical issues. The most significant one is that in both policy and critic training loops, metrics are not being aggregated correctly across microbatches, leading to inaccurate logging. Only the status from the last microbatch of a minibatch is being recorded. Additionally, the optimizer step has been unintentionally removed from the critic's training_step method, which will affect tests relying on it. I've provided detailed comments and suggestions to address these points.

Comment on lines 732 to 738
```python
microbatch_iterator = BatchIterator(
    minibatch, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
)
num_microbatches = len(microbatch_iterator)
microbatch_weight = 1.0 / num_microbatches

for microbatch in microbatch_iterator:
```
@justinvyu (Contributor, Author) commented:

The next step is basically just to change this microbatch iterator from one that's doing sample-based chunking to one that's token-based.
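For illustration only, a token-based chunker along those lines might look like the sketch below; the function name, signature, and budget parameter are hypothetical and not part of this PR.

```python
def chunk_by_tokens(sample_lengths, max_tokens_per_microbatch):
    """Greedily pack sample indices into microbatches under a token budget."""
    microbatches, current, current_tokens = [], [], 0
    for i, n_tokens in enumerate(sample_lengths):
        if current and current_tokens + n_tokens > max_tokens_per_microbatch:
            microbatches.append(current)
            current, current_tokens = [], 0
        current.append(i)
        current_tokens += n_tokens
    if current:
        microbatches.append(current)
    return microbatches

print(chunk_by_tokens([300, 500, 200, 800, 100], max_tokens_per_microbatch=1000))
# -> [[0, 1, 2], [3, 4]]
```

Because the resulting microbatches can contain different numbers of samples (and tokens), the microbatch_weight generalization above is what keeps the accumulated gradient correct.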

@justinvyu changed the title [skyrl-train] Refactor training loop structure to explicitly batch at two-levels (minibatch -> microbatch) [skyrl-train] Refactor training loop structure to explicitly batch at two levels (minibatch -> microbatch) Dec 31, 2025
```diff
         return self

-    def __next__(self) -> Experience:
+    def __next__(self) -> TrainingInputBatch:
```
@erictang000 (Collaborator) commented Dec 31, 2025:

This change also affects the use of BatchIterator for the megatron backend, which implements ppo_train differently from FSDP/DeepSpeed.

Could you make sure the conversion to Experience is also handled correctly for the megatron code path? Making sure one of these tests passes:

async def test_megatron_train(

is probably a good way to check this.

@justinvyu (Contributor, Author) replied:

Sounds good. I added a followup TODO to update the megatron worker's ppo loop as well. Made a minimal change for now to prevent this PR from getting too large.

@erictang000 (Collaborator) left a comment

Left a comment about megatron, and I think the gemini comments are worth a look: essentially we want to make sure that the metrics are aggregated in the same way before and after this PR. Maybe it would be nice to show, by running a test before and after this PR (maybe this one:

def test_ppo_train_basic_execution(ray_init_fixture, cfg, use_entropy_loss, use_kl_loss):

), that this PR doesn't change the metrics.

@justinvyu (Contributor, Author) commented:
To sanity check that this PR didn't introduce any regressions for metrics, I printed the output status from this test: pytest tests/gpu/gpu_ci/test_ppo_train.py::test_gradient_accumulation_scenarios[accumulation_calculation] -s

Printed status on master:

train_status={'final_loss': 0.0018931262311525643, 'policy_loss': -0.0024274957249872386, 'ppo_clip_ratio': 0.0, 'policy_entropy': 5.9921875, 'policy_kl': 10.0, 'policy_lr': 9.999999974752427e-07, 'raw_grad_norm': 0.33137789368629456, 'policy_update_steps': 1.0}, actual_optimizer_steps=1.0

Printed status with this PR:

train_status={'final_loss': 0.0018931262311525643, 'policy_loss': -0.0024274957249872386, 'ppo_clip_ratio': 0.0, 'policy_entropy': 5.9921875, 'policy_kl': 10.0, 'policy_lr': 9.999999974752427e-07, 'raw_grad_norm': 0.33137789368629456, 'policy_update_steps': 1}, actual_optimizer_steps=1

@justinvyu (Contributor, Author) commented:

Tested the Megatron code path as well:
uv run --isolated --extra dev --extra mcore -- pytest tests/gpu/gpu_ci/test_megatron_worker.py::test_megatron_train[tp2_pp2_policy_seq_packing]


megatron results:  {'final_loss': -0.021974159637466073, 'policy_loss': -0.030495601160509977, 'policy_lr': 9.999999974752427e-07, 'ppo_clip_ratio': 0.0, 'policy_entropy': 8.893153190612793, 'policy_kl': 8.521441221237183, 'raw_grad_norm': 4.0175676345825195, 'policy_update_steps': 1}

fsdp results:  {'final_loss': -0.021487861638888717, 'policy_loss': -0.030021082626679796, 'ppo_clip_ratio': 0.0, 'policy_entropy': 0.5429687462747097, 'policy_kl': 8.533220887184143, 'raw_grad_norm': 3.5146021320670116, 'policy_lr': 9.999999974752427e-07, 'policy_update_steps': 4}

@justinvyu requested a review from erictang000 January 6, 2026 00:37
@erictang000 (Collaborator) commented:

/gemini review

@gemini-code-assist bot left a comment

Code Review

This PR refactors the training loops in PolicyWorkerBase and CriticWorkerBase to use a two-level batching structure (minibatch -> microbatch), which is a great improvement for supporting batch balancing and dynamic microbatch sizes. The changes are well-motivated and mostly well-executed.

My feedback focuses on a few key areas:

  • An inconsistency in megatron_worker.py where the new training loop structure has not been applied.
  • A potential bug in the calculation of critic_update_steps.
  • Several opportunities for minor refactoring to improve code clarity and reduce duplication, such as extracting helper methods for status recording and memory snapshotting.
  • Identifying potentially unused code marked with TODO comments that should be cleaned up.

Overall, this is a solid refactoring. Addressing these points will improve the consistency and maintainability of the codebase.

@erictang000 (Collaborator) left a comment

looking good! I think just one super minor gemini comment, and we need to fix the failing cpu test and we should be good to merge this

cpu test:

=========================== short test summary info ============================
FAILED tests/cpu/test_trainer.py::test_ppo_train_batch_calculations - TypeError: test_ppo_train_batch_calculations.<locals>.mock_policy_forward_backward() got an unexpected keyword argument 'microbatch_weight'
============ 1 failed, 353 passed, 83 warnings in 101.86s (0:01:41) ============
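For reference, a minimal sketch of the kind of change the failing mock needs; only the function name and the microbatch_weight keyword come from the error message above, while the parameter name, body, and return value are assumptions.

```python
recorded_microbatch_weights = []

def mock_policy_forward_backward(batch, microbatch_weight: float = 1.0):
    # Accept the keyword that forward_backward now receives, and record it so the
    # test can assert on gradient-accumulation behavior.
    recorded_microbatch_weights.append(microbatch_weight)
    return {"policy_loss": 0.0}
```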


```diff
 status_mean = reduce_metrics(all_metrics)
-status_mean["policy_update_steps"] = policy_update_steps / accumulation_steps
+status_mean["policy_update_steps"] = num_minibatches * self.cfg.trainer.update_epochs_per_batch
```
@erictang000 (Collaborator) commented:

good catch...

@erictang000 (Collaborator) left a comment

🚀🚀🚀

@erictang000 merged commit 2a7a572 into NovaSky-AI:main Jan 6, 2026
3 checks passed
@justinvyu deleted the minibatch_refactor branch January 6, 2026 20:00