1 change: 1 addition & 0 deletions skyrl-train/skyrl_train/config/ppo_base_config.yaml
@@ -152,6 +152,7 @@ trainer:
critic_mini_batch_size: 256
micro_train_batch_size_per_gpu: 1
micro_forward_batch_size_per_gpu: 1
max_tokens_per_microbatch: -1 # TODO: Maybe split this between forward and train; -1 means no token-based chunking
update_ref_every_epoch: false
use_sample_packing: true
eval_batch_size: 1024
86 changes: 60 additions & 26 deletions skyrl-train/skyrl_train/workers/worker.py
@@ -32,7 +32,7 @@
from skyrl_train.distributed.ulysses import set_ulysses_sequence_parallel_group, apply_monkey_patch
from skyrl_train.distributed.utils import init_custom_process_group
from skyrl_train.utils.ppo_utils import PolicyLossRegistry, ppo_critic_loss, compute_approx_kl
from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics
from skyrl_train.workers.worker_utils import BatchIterator, TokenBasedBatchIterator, reduce_metrics
Member:

Suggested change:
- from skyrl_train.workers.worker_utils import BatchIterator, TokenBasedBatchIterator, reduce_metrics
+ from skyrl_train.workers.worker_utils import BaseBatchIterator, SampleBasedBatchIterator, TokenBasedBatchIterator, reduce_metrics

from skyrl_train.dataset.replay_buffer import Experience
from skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch
from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient
@@ -322,16 +322,29 @@ def forward(
) -> TrainingOutputBatch:
"""Run forward pass on the input batch in inference mode.

This is a wrapper around `_forward_micro_batch` that runs in micro batches of `cfg.trainer.micro_forward_batch_size_per_gpu`.
This is a wrapper around `_forward_micro_batch` that runs in micro batches.
Uses token-based chunking if `max_tokens_per_microbatch` is configured, otherwise
falls back to sample-based chunking with `micro_forward_batch_size_per_gpu`.
"""
# run in micro batches of cfg.trainer.micro_forward_batch_size_per_gpu
# TODO (sumanthrh): this can be in the policy/critic impl if the micro batch size can be specific to policy, critic, etc.
micro_batches = data.chunk(self.cfg.trainer.micro_forward_batch_size_per_gpu)
# Check if token-based chunking is enabled
if self.cfg.trainer.max_tokens_per_microbatch > 0:
# Use token-based chunking
micro_batch_iterator = TokenBasedBatchIterator(
data=data,
max_tokens_per_microbatch=self.cfg.trainer.max_tokens_per_microbatch,
)
else:
micro_batch_iterator = BatchIterator(
data=data,
sample_batch_size=self.cfg.trainer.micro_forward_batch_size_per_gpu,
drop_last=False,
)

outputs = []
for micro_batch in micro_batches:
for micro_batch in micro_batch_iterator:
outputs.append(self._forward_micro_batch(micro_batch))
output = TrainingOutputBatch.cat(outputs)

output = micro_batch_iterator.reorder_microbatches(outputs)
if output.device is not None and output.device != torch.device("cpu"):
output = output.to("cpu")
return output
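
For context on the `reorder_microbatches` call above: token-based chunking can regroup samples out of their original order, so the concatenated per-microbatch outputs must be permuted back before being returned. A minimal sketch of that bookkeeping, assuming the iterator records the original sample indices of each microbatch (illustrative names, not the actual worker_utils API):

```python
from typing import List
import torch

def reorder_outputs(outputs: List[torch.Tensor], index_groups: List[List[int]]) -> torch.Tensor:
    """Concatenate per-microbatch outputs and restore the original sample order."""
    flat = torch.cat(outputs, dim=0)                      # rows in microbatch order
    order = torch.tensor([i for g in index_groups for i in g])
    inverse = torch.empty_like(order)
    inverse[order] = torch.arange(len(order))             # position in `flat` of each original sample
    return flat[inverse]
```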
@@ -640,15 +653,22 @@ def _normalize_mini_batch_size(self):

def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch:
global_step = train_data.metadata["global_step"]
dataloader = BatchIterator(
train_data, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
)

micro_batches_per_mini_batch = (
self.policy_mini_batch_size_per_gpu // self.cfg.trainer.micro_train_batch_size_per_gpu
)
# The number of steps (over micro batches) to accumulate gradients before taking an optimizer step.
accumulation_steps = micro_batches_per_mini_batch
if self.cfg.trainer.max_tokens_per_microbatch > 0:
dataloader = TokenBasedBatchIterator(
train_data, max_tokens_per_microbatch=self.cfg.trainer.max_tokens_per_microbatch
)
accumulation_steps = dataloader.num_microbatches
else:
dataloader = BatchIterator(
train_data, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
)

micro_batches_per_mini_batch = (
self.policy_mini_batch_size_per_gpu // self.cfg.trainer.micro_train_batch_size_per_gpu
)
# The number of steps (over micro batches) to accumulate gradients before taking an optimizer step.
accumulation_steps = micro_batches_per_mini_batch
Comment on lines +657 to +671
Member:

QQ: Should the dynamic batch size related packing happen at the batch level or the mini-batch level?

The most efficient one would be at the batch level, but then we won't have a good way to get mini batches afterwards. It seems like we should just do this at the mini batch level first. (you'd need to introduce an explicit loop over mini batches)

Member:

This also allows us to naturally define the accumulation steps

Contributor Author (@justinvyu, Nov 17, 2025):

> Should the dynamic batch size related packing happen at the batch level or the mini-batch level?

Yeah, it should also be done at the global batch level, but I'm leaving that for a follow-up optimization; see the end of the PR description. Batch balancing at the global batch level is useful for assigning each worker a similar number of tokens.

This PR's main goal is the max_tokens_per_microbatch parameter. The balancing of microbatches is a convenient add-on that gives a slight optimization.

Here's why the balancing at the microbatch level is useful:

  • Worker 0's mini-batch contains 4 sequences of lengths [10, 10, 5, 5].
  • Worker 0 has a max microbatch token length of 15.
  • If you just do a single front-to-back pass, you'll end up with 3 microbatches: [10], [10, 5], [5].
  • However, the better way to chunk this would be [10, 5], [10, 5], which is what the balanced microbatching method does here.
  • Basically, we want to minimize the number of microbatches while satisfying max_tokens_per_microbatch; see the sketch below.
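
A minimal sketch of that balancing idea, assuming a greedy first-fit-decreasing pass over sequence lengths (the helper name and strategy are illustrative, not the actual TokenBasedBatchIterator implementation):

```python
from typing import List

def balanced_microbatches(seq_lens: List[int], max_tokens: int) -> List[List[int]]:
    """Greedy first-fit-decreasing bin packing of sample indices under a token cap."""
    order = sorted(range(len(seq_lens)), key=lambda i: seq_lens[i], reverse=True)
    bins: List[List[int]] = []   # sample indices per microbatch
    loads: List[int] = []        # token count per microbatch
    for i in order:
        feasible = [b for b in range(len(bins)) if loads[b] + seq_lens[i] <= max_tokens]
        if feasible:
            b = max(feasible, key=lambda b: loads[b])  # pack into the fullest bin that still fits
            bins[b].append(i)
            loads[b] += seq_lens[i]
        else:
            bins.append([i])
            loads.append(seq_lens[i])
    return bins

# The example above: lengths [10, 10, 5, 5] with a 15-token cap
# -> 2 microbatches, e.g. [[0, 2], [1, 3]] with token loads [15, 15].
print(balanced_microbatches([10, 10, 5, 5], 15))
```

Because packing like this reorders samples, the per-microbatch outputs later have to be mapped back to the original sample order (hence the reorder_microbatches call in the diff above).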

Contributor Author (@justinvyu):

> This also allows us to naturally define the accumulation steps

Could you elaborate on this a bit more?

Member:

Oh sorry, here's what I meant:

We have "mini batches" per worker that is set based on policy_mini_batch_size. This is separate from the per-worker shard you're referring to.

train_batch_size / policy_mini_batch_size dictates the number of optimization steps with the current training batch.

So for example, with train_batch_size=16, num_workers=8, policy_mini_batch_size=8, each worker will receive a shard of 2 samples, and this will be split further into train_batch_size / policy_mini_batch_size = 2 mini batches. Each mini batch will correspond to one optimization step.

For each mini batch per worker, we further split it into micro batches.

My point was that the dynamic-batch-size bin-packing should happen over the worker-level mini batches, not the worker-level overall batch.
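
To make the arithmetic concrete, here is a small worked version of that example (micro_train_batch_size_per_gpu=1 and the simple division by num_workers are assumptions added for illustration, not the exact _normalize_mini_batch_size logic):

```python
train_batch_size = 16
num_workers = 8                       # data-parallel workers
policy_mini_batch_size = 8
micro_train_batch_size_per_gpu = 1    # assumed for illustration

samples_per_worker = train_batch_size // num_workers                    # 2-sample shard per worker
optimizer_steps = train_batch_size // policy_mini_batch_size            # 2 mini batches -> 2 optimizer steps
policy_mini_batch_size_per_gpu = policy_mini_batch_size // num_workers  # 1 sample per worker per mini batch

# Sample-based chunking: accumulation steps = micro batches per worker-level mini batch.
accumulation_steps = policy_mini_batch_size_per_gpu // micro_train_batch_size_per_gpu  # 1

# Token-based chunking: the micro batch count depends on sequence lengths, so the
# iterator has to report it (accumulation_steps = dataloader.num_microbatches).
```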


status_list = []
all_metrics = defaultdict(list)
@@ -660,7 +680,9 @@ def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch:
desc=f"Policy Train epoch [{epoch + 1}/{self.cfg.trainer.update_epochs_per_batch}]",
disable=not self.strategy.is_rank_0(),
)
for local_step, experience in enumerate(pbar):
for local_step, batch in enumerate(pbar):
experience = BatchIterator.batch_to_experience(batch)
print(f"{experience.sequences=}")
status = self.training_step(
experience,
global_step,
@@ -717,6 +739,7 @@ def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch:

# should return an `TrainingOutputBatch`
output = TrainingOutputBatch()
# NOTE: No need to reorder anything here because we average across the entire batch.
output.metadata = {"train_status": status_mean}
return output

@@ -932,19 +955,29 @@ def save_hf_model(self, export_dir: str, tokenizer):

def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch:
global_step = train_data.metadata["global_step"]
dataloader = BatchIterator(
train_data, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
)

# TODO: Move this to the base class since it's common to both policy and critic workers.
if self.cfg.trainer.max_tokens_per_microbatch > 0:
dataloader = TokenBasedBatchIterator(
train_data, max_tokens_per_microbatch=self.cfg.trainer.max_tokens_per_microbatch
)
accumulation_steps = dataloader.num_microbatches
else:
dataloader = BatchIterator(
train_data, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
)

# TODO: Make `num_microbatches` a property of the dataloader instead of computing it here.
# Ex: see the TokenBasedBatchIterator class.
micro_batches_per_mini_batch = (
self.critic_mini_batch_size_per_gpu // self.cfg.trainer.micro_train_batch_size_per_gpu
)
# The number of steps (over micro batches) to accumulate gradients before taking an optimizer step.
accumulation_steps = micro_batches_per_mini_batch

torch.cuda.empty_cache()
self.model.train()

micro_batches_per_mini_batch = (
self.critic_mini_batch_size_per_gpu // self.cfg.trainer.micro_train_batch_size_per_gpu
)
# The number of steps (over micro batches) to accumulate gradients before taking an optimizer step.
accumulation_steps = micro_batches_per_mini_batch

all_metrics = defaultdict(list)
critic_update_steps = 0
for epoch in range(self.cfg.trainer.update_epochs_per_batch):
@@ -953,7 +986,8 @@ def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch:
desc=f"Critic Train epoch [{epoch + 1}/{self.cfg.trainer.update_epochs_per_batch}]",
disable=not self.strategy.is_rank_0(),
)
for local_step, experience in enumerate(pbar):
for local_step, batch in enumerate(pbar):
experience = BatchIterator.batch_to_experience(batch)
status = self.training_step(experience, global_step, local_step, accumulation_steps)
critic_update_steps += 1
