feat: dynamic bsz #170

Open

addiaddiaddi wants to merge 15 commits into NovaSky-AI:main from addiaddiaddi:dynamic-bsz

Conversation

@addiaddiaddi

@addiaddiaddi addiaddiaddi commented Aug 20, 2025

Usage

  • Adjust the forward micro-batch size by token count with trainer.max_token_len_per_gpu_forward
  • Adjust the backward micro-batch size by token count with trainer.max_token_len_per_gpu
  • Enable with trainer.use_dynamic_batching

Details

  • Added dynamic batching functionality to BatchIterator (see the sketch after this list)
    • BatchIterator chunks received data by mini_batch_size_per_gpu
    • Each worker calculates the required number of micro batches per mini batch
    • An all-reduce communicates the number of micro batches per mini batch across workers
    • Each worker balances its mini batches into the agreed number of micro batches
  • Adjusted metric calculation in ppo_train to be a weighted average
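
A minimal sketch of this flow, for illustration only: the function names, the MAX all-reduce, and the ceil-division heuristic below are assumptions, not the PR's actual BatchIterator internals.

import torch
import torch.distributed as dist


def micro_batches_needed(seq_token_counts, max_token_len_per_gpu):
    # Smallest number of micro-batches such that each stays under the per-GPU token budget.
    total = sum(seq_token_counts)
    return max(1, -(-total // max_token_len_per_gpu))  # ceil division


def synced_micro_batch_count(seq_token_counts, max_token_len_per_gpu, dp_group=None):
    # Each worker computes its local need, then takes the max across the DP group so that
    # every rank performs the same number of forward/backward steps per mini-batch.
    local = micro_batches_needed(seq_token_counts, max_token_len_per_gpu)
    if dp_group is not None and dist.is_initialized():
        counts = torch.tensor([local], dtype=torch.int64)  # with NCCL this would need to live on the GPU
        dist.all_reduce(counts, op=dist.ReduceOp.MAX, group=dp_group)
        local = int(counts.item())
    return local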

Testing

  • E2e test comparing loss equivalence between fixed and dynamic batch sizing
  • A few tests for BatchIterator functionality

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces dynamic token-based batching to improve GPU utilization. The core logic is implemented in the new skyrl_train/utils/dynamic_batching.py file, which uses the Karmarkar-Karp algorithm for balanced partitioning. The BatchIterator in skyrl_train/workers/worker_utils.py has been significantly refactored to support both fixed-size and dynamic batching, encapsulating the batching logic that was previously in the worker's training loops. The training loops in skyrl_train/workers/worker.py have been updated to use the new iterator and handle variable-sized micro-batches with weighted loss accumulation. Comprehensive tests for the new functionality have been added in skyrl-train/tests/gpu/test_dynamic_batching.py.

My review includes suggestions to remove some debugging print statements, clean up unused variables, and fix a few issues, including a typo that would cause a runtime error and a broken test case. Overall, this is a solid implementation of a valuable feature.
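
For intuition about the balanced-partitioning goal, here is a minimal sketch. It uses a greedy largest-first (LPT) heuristic rather than Karmarkar-Karp proper, and the names are illustrative, not the functions in skyrl_train/utils/dynamic_batching.py.

import heapq
from typing import List


def balance_by_tokens(token_counts: List[int], k: int) -> List[List[int]]:
    # Assign sequence indices to k micro-batches, keeping token totals roughly equal:
    # place the largest sequences first, each into the currently lightest micro-batch.
    heap = [(0, p) for p in range(k)]  # (token_total, partition_index)
    heapq.heapify(heap)
    partitions: List[List[int]] = [[] for _ in range(k)]
    for idx in sorted(range(len(token_counts)), key=lambda i: -token_counts[i]):
        total, p = heapq.heappop(heap)
        partitions[p].append(idx)
        heapq.heappush(heap, (total + token_counts[idx], p))
    return partitions


print(balance_by_tokens([512, 128, 300, 900, 64, 256], k=2))
# [[3, 1, 4], [0, 2, 5]] -> token totals 1092 and 1068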

Comment on lines 215 to 221
total_sequences = 0
for exp in iterator:
    assert hasattr(exp, "sequences"), "Experience should have sequences"
    assert hasattr(exp, "attention_mask"), "Experience should have attention_mask"
    assert "should_step" in exp.info, "Experience info should have should_step"
    assert "accumulation_weight" in exp.info, "Experience info should have accumulation_weight"
    total_sequences += exp.sequences.shape[0]
Contributor

high

This test appears to be broken. The BatchIterator yields a tuple (experience, should_step), but the loop for exp in iterator: attempts to treat exp as a single Experience object. This will cause an error when trying to access attributes like exp.sequences.

Additionally, the assertions for should_step and accumulation_weight in exp.info are incorrect, as the iterator does not add these to the info dictionary.

The loop should be for experience, should_step in iterator: and the assertions should be adjusted accordingly.

Suggested change
total_sequences = 0
for exp in iterator:
    assert hasattr(exp, "sequences"), "Experience should have sequences"
    assert hasattr(exp, "attention_mask"), "Experience should have attention_mask"
    assert "should_step" in exp.info, "Experience info should have should_step"
    assert "accumulation_weight" in exp.info, "Experience info should have accumulation_weight"
    total_sequences += exp.sequences.shape[0]
total_sequences = 0
for experience, should_step in iterator:
    assert hasattr(experience, "sequences"), "Experience should have sequences"
    assert hasattr(experience, "attention_mask"), "Experience should have attention_mask"
    assert isinstance(should_step, bool)
    total_sequences += experience.sequences.shape[0]

response_mask: Integer[torch.Tensor, "batch_size seq_len"]
action_log_probs: Float[torch.Tensor, "batch_size seq_len"]
base_action_log_probs: Float[torch.Tensor, "batch_size seq_len"]
rollout_log_probs: Float[torch.Tensor, "batch_size seq_len"]
Contributor

medium

The key rollout_log_probs appears to be duplicated in the TrainingInput TypedDict. Another definition exists further down in the class. Please remove one of them to avoid ambiguity.

rollout_action_logprobs = experience.rollout_logprobs
accumulation_weight = len(experience) / self.policy_mini_batch_size_per_gpu

print(f"Accumulation weight: {accumulation_weight}")
Contributor

medium

This print statement appears to be for debugging purposes. It should be removed or converted to a logger.debug call before merging.

Comment on lines 763 to 764
logger.info(f"Update: Loss: {loss.item()} | Accumulation weight: {accumulation_weight} | new loss: {loss * accumulation_weight}")
print(f"Update: Loss: {loss.item()} | Accumulation weight: {accumulation_weight} | new loss: {loss * accumulation_weight}")
Contributor

medium

These logger.info and print statements appear to be for debugging. The print statement is redundant. Please consider removing them or converting them to logger.debug calls.

Comment on lines 77 to 80
self._accumulation_weights = []
self._should_step_flags = []
self._micro_batch_sizes = []
self._micro_batch_indices = []
Contributor

medium

The attributes _accumulation_weights, _micro_batch_sizes, and _micro_batch_indices are initialized here and populated later, but they are never used within the BatchIterator class. Please remove them and their associated assignments to simplify the code.

Member

@tyler-griggs tyler-griggs left a comment

This is great. Added a few small comments, mostly just nits. I want to spin this up and run tests as well, so please give me some time for that.

By the way, have you had a chance to test with sequence parallelism enabled as well?


dp_group = self.device_mesh["dp"].get_group() if hasattr(self, 'device_mesh') else None

dynamic_bsz = getattr(self.cfg.trainer, 'use_dynamic_batching', False)
Member

nit: We've been trying to a) set reasonable defaults in the base training config, and b) do config validation in utils/utils.py so that we don't have to add these getattrs in-code, which can often surprise people. We could just use self.cfg.trainer.use_dynamic_batching here.

Author

Sounds good - would you prefer it raise an exception if the dynamic batching config is missing, or just default it to false in utils.py?

Member

In this case, you can get away with doing neither. The base config will always be loaded, so the default should be False anyway. The only exception to this is tests, but it's actually helpful for tests to crash when the necessary config params aren't passed in.
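
For illustration, a minimal sketch of that suggestion, assuming an OmegaConf-backed config (which the Hydra-style cfg.trainer access suggests); the key name is the PR's, everything else is hypothetical.

from omegaconf import OmegaConf

# Base config carries the default, so downstream code can rely on the key existing.
base_cfg = OmegaConf.create({"trainer": {"use_dynamic_batching": False}})
run_cfg = OmegaConf.merge(base_cfg, OmegaConf.create({"trainer": {"use_dynamic_batching": True}}))

# Direct access instead of getattr(...): a missing key fails loudly, which is what we want in tests.
print(run_cfg.trainer.use_dynamic_batching)  # True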

return num_micro_batches


def balance_partitions_by_compute(partitions: List[List[int]], token_counts: List[int]) -> List[List[int]]:
Member

nit: I didn't see this used anywhere -- is there an intended use for it in this PR?

Author

Removed


grad_norm = None
if (local_step + 1) % accumulation_steps == 0:
if should_policy_update:
Member

Yeah this is very clearly improved ergonomics of training_step, thanks :D

from tests.gpu.utils import make_dummy_training_batch, init_worker_with_type, get_test_actor_config


def make_variable_length_training_batch(
Member

nit: most of the tests in the top-half of the file are cpu-only, so it'd be good to add them in the tests/cpu path so that they are run quickly in CI/CD.

Author

Done!

@tyler-griggs
Member

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant and valuable feature: dynamic batching based on token count. This is an excellent enhancement for improving GPU utilization, particularly with variable sequence lengths. The implementation is thorough, spanning configuration, data handling via a refactored BatchIterator, and updates to the core training loops. The use of the Karmarkar-Karp algorithm for balanced partitioning is a sophisticated and effective choice. The changes are well-supported by a comprehensive new test suite, including a crucial end-to-end test that verifies loss consistency between fixed and dynamic batching. My review includes a few suggestions to enhance efficiency and code clarity in some of the new utility functions and data structures. Overall, this is a high-quality contribution.

Comment on lines 19 to 24
def get_reverse_idx(idx_map: List[int]) -> List[int]:
    """Build the inverse of an index mapping."""
    reverse_idx_map = copy.deepcopy(idx_map)
    for i, idx in enumerate(idx_map):
        reverse_idx_map[idx] = i
    return reverse_idx_map
Contributor

medium

Using copy.deepcopy for a list of integers is inefficient. Since you are overwriting the list elements anyway, you can pre-allocate a list of zeros and then fill it. This avoids the overhead of deep copying and is more performant.

Suggested change
def get_reverse_idx(idx_map: List[int]) -> List[int]:
    """Build the inverse of an index mapping."""
    reverse_idx_map = copy.deepcopy(idx_map)
    for i, idx in enumerate(idx_map):
        reverse_idx_map[idx] = i
    return reverse_idx_map
def get_reverse_idx(idx_map: List[int]) -> List[int]:
    """Build the inverse of an index mapping."""
    reverse_idx_map = [0] * len(idx_map)
    for i, idx in enumerate(idx_map):
        reverse_idx_map[idx] = i
    return reverse_idx_map

Comment on lines 233 to 242
micro_batch_data = {}
for key, value in self.items():
    if value is not None:
        if isinstance(value, torch.Tensor):
            indices = torch.tensor(partition, dtype=torch.long)
            micro_batch_data[key] = value[indices]
        else:
            raise ValueError(f"Unsupported type {type(value)} for key {key}")
    else:
        micro_batch_data[key] = value
Contributor

medium

Creating the indices tensor inside the loop is inefficient, as it will be recreated for every key in the batch. It's better to create it once outside the loop. Additionally, it's good practice to specify the device for the new tensor to match the device of the batch tensors (self.device), which will avoid implicit CPU-to-GPU copies during indexing if the batch is on a GPU.

Suggested change
micro_batch_data = {}
for key, value in self.items():
    if value is not None:
        if isinstance(value, torch.Tensor):
            indices = torch.tensor(partition, dtype=torch.long)
            micro_batch_data[key] = value[indices]
        else:
            raise ValueError(f"Unsupported type {type(value)} for key {key}")
    else:
        micro_batch_data[key] = value
micro_batch_data = {}
indices = torch.tensor(partition, dtype=torch.long, device=self.device)
for key, value in self.items():
    if value is not None:
        if isinstance(value, torch.Tensor):
            micro_batch_data[key] = value[indices]
        else:
            raise ValueError(f"Unsupported type {type(value)} for key {key}")
    else:
        micro_batch_data[key] = value

Comment on lines 201 to 217
"""Return the iterator itself."""
self._reset_iterator()
for i in range(len(self._all_micro_batches)):
if self._current_idx >= len(self._all_micro_batches):
self._reset_iterator()
raise StopIteration

micro_batch = self._all_micro_batches[self._current_idx]

exp = self.batch_to_experience(micro_batch)

if self.dynamic_bsz:
exp.info["micro_batch_utilization"] = micro_batch["attention_mask"].sum().item() / self.max_token_len

self._current_idx += 1

def __next__(self) -> Experience:
try:
batch = next(self._iter)
exp = self.batch_to_experience(batch)
return exp
except StopIteration:
self._iter = iter(self._chunks)
raise StopIteration
yield exp
Contributor

medium

This generator function can be simplified. The for i in range(...) loop with manual index tracking (self._current_idx) and the StopIteration check are overly complex for a generator. You can directly iterate over self._all_micro_batches, which makes the code cleaner, more Pythonic, and less prone to off-by-one errors.

Suggested change
"""Return the iterator itself."""
self._reset_iterator()
for i in range(len(self._all_micro_batches)):
if self._current_idx >= len(self._all_micro_batches):
self._reset_iterator()
raise StopIteration
micro_batch = self._all_micro_batches[self._current_idx]
exp = self.batch_to_experience(micro_batch)
if self.dynamic_bsz:
exp.info["micro_batch_utilization"] = micro_batch["attention_mask"].sum().item() / self.max_token_len
self._current_idx += 1
def __next__(self) -> Experience:
try:
batch = next(self._iter)
exp = self.batch_to_experience(batch)
return exp
except StopIteration:
self._iter = iter(self._chunks)
raise StopIteration
yield exp
"""Return the iterator itself."""
self._reset_iterator()
for micro_batch in self._all_micro_batches:
exp = self.batch_to_experience(micro_batch)
if self.dynamic_bsz:
exp.info["micro_batch_utilization"] = micro_batch["attention_mask"].sum().item() / self.max_token_len
yield exp

@tyler-griggs tyler-griggs mentioned this pull request Aug 21, 2025
@addiaddiaddi
Author

/gemini review

@addiaddiaddi
Author

This is great. Added a few small comments, mostly just nits. I want to spin this up and run tests as well, so please give me some time for that.

By the way, have you had a chance to test with sequence parallelism enabled as well?

Yes, tested briefly with SP


status_mean = reduce_metrics(all_metrics)
status_mean["policy_update_steps"] = policy_update_steps / accumulation_steps
status_mean = reduce_metrics(all_metrics, dataloader.total_batch_size * self.cfg.trainer.update_epochs_per_batch * self._world_size)
Author

Pretty sure this is right, but worth taking a second look to make sure the weighted average works in all cases.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant and well-implemented feature: dynamic token-based batching. The use of the Karmarkar-Karp algorithm for creating balanced partitions is a solid approach. The related changes to gradient accumulation, metric calculations, and the new BatchIterator are thorough and correctly handle variable micro-batch sizes. The accompanying tests are comprehensive, particularly the end-to-end test ensuring loss consistency between fixed and dynamic batching, which adds a great deal of confidence in the changes. I've identified a critical issue with a duplicate key in a TypedDict that will cause a SyntaxError, along with a few medium-severity suggestions to clean up unused code and improve type correctness. Overall, this is a high-quality contribution.

response_mask: Integer[torch.Tensor, "batch_size seq_len"]
action_log_probs: Float[torch.Tensor, "batch_size seq_len"]
base_action_log_probs: Float[torch.Tensor, "batch_size seq_len"]
rollout_log_probs: Float[torch.Tensor, "batch_size seq_len"]
Contributor

critical

The key rollout_log_probs is defined twice in the TrainingInput TypedDict (here and on line 357, where it appears as the typo rollout_logprobs). The duplicated definition is ambiguous, since one annotation silently overrides the other. Please remove one of them. Given that the rest of the code uses rollout_log_probs, the entry on line 357 is likely the typo and this one is the intended definition.

Member

^^ Yes please

adapted from the reference implementation to work with skyrl-train's architecture.
"""

import copy
Contributor

medium

The copy module is imported but not used in this file. It should be removed to keep the imports clean.

Member

nit: remove copy

Comment on lines 19 to 24
def get_reverse_idx(idx_map: List[int]) -> List[int]:
    """Build the inverse of an index mapping."""
    reverse_idx_map = [0] * len(idx_map)
    for i, idx in enumerate(idx_map):
        reverse_idx_map[idx] = i
    return reverse_idx_map
Contributor

medium

The function get_reverse_idx is defined but is not used anywhere in the project. To maintain code clarity and remove dead code, it should be removed.

Member

nit: remove get_reverse_idx

worker_type: Type of worker ("policy" or "critic")
dp_group: Distributed process group for synchronization
dynamic_bsz: Whether to use dynamic token-based batching
mini_batch_size_per_gpu: Override mini-batch size per GPU
Contributor

medium

The docstring for BatchIterator.__init__ lists mini_batch_size_per_gpu as a parameter in the Args section, but it is not actually a parameter of the function. This should be removed from the docstring to avoid confusion.

Member

nit: remove mini_batch_size_per_gpu

"""Synchronize micro-batch counts across distributed workers."""
if self.dp_group is not None and dist.is_initialized():
local_copy = local_counts.copy()
counts_tensor = torch.tensor(local_counts, dtype=torch.float32, device="cuda")
Contributor

medium

The counts_tensor is created with dtype=torch.float32 to store integer counts of micro-batches. While this may work for small numbers, using floating-point numbers for counts can introduce precision issues. It's more robust to use an integer type like torch.int64 for counts, which is also supported by dist.all_reduce.

Suggested change
counts_tensor = torch.tensor(local_counts, dtype=torch.float32, device="cuda")
counts_tensor = torch.tensor(local_counts, dtype=torch.int64, device="cuda")

Member

@tyler-griggs tyler-griggs left a comment

Making a brief pass, I hit a couple errors when first launching training (e.g., the missing mini_batch_size_per_gpu field in BatchIterator)


# Truncated Importance Sampling as proposed in https://fengyao.notion.site/off-policy-rl
tis_imp_ratio_cap: -1.0
use_tis: false
max_seq_len: 512
Member

Is this used? Or can it be removed?

Author

Yea, I was confused by this because it is needed here but wasn't already in the config.

Author

Never mind, I see that this is added in config validation. I'll just hardcode it in the test.

@addiaddiaddi
Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant and valuable feature: dynamic token-based batching. The goal is to improve GPU utilization by creating micro-batches based on the number of tokens rather than a fixed number of sequences. The implementation is comprehensive, touching the configuration, data structures, training loop, and adding a new BatchIterator with the core logic. The changes also include a sophisticated partitioning algorithm (Karmarkar-Karp) to create balanced batches. The inclusion of extensive unit and end-to-end tests, especially the one for loss consistency, is commendable and provides high confidence in the correctness of this complex change. I've identified a couple of critical issues, including a typo in an argument name and an incorrect attribute assignment that would lead to runtime errors, along with some suggestions for improving maintainability and robustness. Overall, this is a well-executed feature addition.

addiaddiaddi and others added 3 commits August 22, 2025 18:32
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@addiaddiaddi
Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant feature: dynamic token-based batching to improve GPU utilization. The implementation includes a new dynamic_batching.py module with the Karmarkar-Karp partitioning algorithm and a major refactoring of the BatchIterator. The worker logic is updated to handle variable-sized batches for gradient accumulation and metric calculation. The changes are well-supported by extensive new tests, including an end-to-end test for loss consistency, which is excellent.

I've found one critical issue and one of medium severity. The critical issue is a bug in the CriticWorkerBase.ppo_train method where the BatchIterator is initialized with incorrect arguments, which will cause a runtime error. The medium severity issue is a confusing error message in an assertion. Once these are addressed, the PR should be in great shape.

Comment on lines 919 to 927
cfg=self.cfg,
fg=self.cfg,
dp_size=self.mesh_rank.dp_size,
drop_last=False,
worker_type="critic",
dp_group=dp_group,
dynamic_bsz=self.cfg.trainer.use_dynamic_batching
)

Contributor

critical

The BatchIterator is being initialized incorrectly in CriticWorkerBase.ppo_train. The required positional argument train_data is missing, and an unexpected keyword argument fg is provided. This will cause a TypeError at runtime.

Suggested change
cfg=self.cfg,
fg=self.cfg,
dp_size=self.mesh_rank.dp_size,
drop_last=False,
worker_type="critic",
dp_group=dp_group,
dynamic_bsz=self.cfg.trainer.use_dynamic_batching
)
dataloader = BatchIterator(
train_data,
cfg=self.cfg,
dp_size=self.mesh_rank.dp_size,
drop_last=False,
worker_type="critic",
dp_group=dp_group,
dynamic_bsz=self.cfg.trainer.use_dynamic_batching
)

)
if not dynamic_bsz:
    assert dataloader.micro_batches_per_mini_batch <= len(dataloader), f"Accumulation Steps {dataloader.micro_batches_per_mini_batch} cannot be less than number of micro batches in total {len(dataloader)}"

Contributor

medium

The assertion logic A <= B is correct, but the error message f"A cannot be less than B" is inverted. It should state that A cannot be greater than B to avoid confusion during debugging.

Suggested change
assert dataloader.micro_batches_per_mini_batch <= len(dataloader), f"Accumulation Steps {dataloader.micro_batches_per_mini_batch} cannot be greater than number of micro batches in total {len(dataloader)}"

@addiaddiaddi
Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant feature: dynamic token-based batching to improve GPU utilization. The implementation leverages the Karmarkar-Karp algorithm for balanced partitioning and refactors the BatchIterator to handle both fixed and dynamic batching modes. The changes are extensive, touching configuration, data structures, worker training loops, and adding a comprehensive suite of new tests, including an important end-to-end loss consistency test. Overall, this is a well-structured enhancement. I've identified a couple of critical issues that would lead to runtime errors, which should be addressed.

Comment on lines 918 to 927
dataloader = BatchIterator(
train_data, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
train_data,
cfg=self.cfg,
fg=self.cfg,
dp_size=self.mesh_rank.dp_size,
drop_last=False,
worker_type="critic",
dp_group=dp_group,
dynamic_bsz=self.cfg.trainer.use_dynamic_batching
)
Contributor

critical

There's an invalid keyword argument fg being passed to the BatchIterator constructor. This will raise a TypeError at runtime as BatchIterator does not accept an fg parameter. This appears to be a typo and should be removed.

        dataloader = BatchIterator(
            train_data,
            cfg=self.cfg,
            dp_size=self.mesh_rank.dp_size,
            drop_last=False,
            worker_type="critic",
            dp_group=dp_group,
            dynamic_bsz=self.cfg.trainer.use_dynamic_batching
        )

@tyler-griggs
Member

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant and valuable feature: dynamic batching based on token count. This should improve GPU utilization by packing batches more efficiently. The implementation is well-structured, with new utilities for balanced partitioning and a refactored BatchIterator. The addition of comprehensive unit and integration tests is excellent. However, I've found a critical issue in the gradient accumulation logic for both the policy and critic trainers that could lead to skipped optimizer steps. The fix is straightforward and I've provided suggestions.

Comment on lines 637 to 648
samples_seen += len(experience)

should_policy_update = (samples_seen % self.policy_mini_batch_size_per_gpu == 0)

status = self.training_step(
    experience,
    global_step,
    local_step,
    accumulation_steps,
    should_policy_update
)
policy_update_steps += 1
if should_policy_update:
    policy_update_steps += 1
Contributor

critical

The logic for determining when to perform an optimizer step (should_policy_update) is flawed. The use of samples_seen % self.policy_mini_batch_size_per_gpu == 0 is not robust for variable-sized micro-batches from dynamic batching, as samples_seen may never be an exact multiple of the mini-batch size, causing optimizer steps to be skipped. Additionally, samples_seen accumulates over the entire epoch without being reset after an optimizer step, which is incorrect for mini-batch-based gradient accumulation.

A more robust approach is to use a counter that tracks samples since the last update, triggers an update when the threshold is met, and then resets.

Suggested change
samples_seen += len(experience)
should_policy_update = (samples_seen % self.policy_mini_batch_size_per_gpu == 0)
status = self.training_step(
    experience,
    global_step,
    local_step,
    accumulation_steps,
    should_policy_update
)
policy_update_steps += 1
if should_policy_update:
    policy_update_steps += 1
samples_seen += len(experience)
should_policy_update = samples_seen >= self.policy_mini_batch_size_per_gpu
status = self.training_step(
    experience,
    global_step,
    local_step,
    should_policy_update
)
if should_policy_update:
    policy_update_steps += 1
    samples_seen = 0

Member

This is incorrect, right? Since we first split the training batch into mini batches, then split each mini batch into micro batches, it is guaranteed that we always sum up to the mini batch size.
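
A toy check of that reasoning (illustrative numbers, not the worker's actual code):

# Chunking a mini-batch into micro-batches partitions exactly its samples, so the
# running count always lands on the mini-batch size at the last micro-batch.
mini_batch_size_per_gpu = 8
mini_batch = list(range(mini_batch_size_per_gpu))
micro_batches = [mini_batch[i:i + 3] for i in range(0, len(mini_batch), 3)]  # sizes 3, 3, 2

samples_seen = 0
for mb in micro_batches:
    samples_seen += len(mb)
    should_update = samples_seen % mini_batch_size_per_gpu == 0
print(samples_seen, should_update)  # 8 True: the modulo check fires on the final micro-batch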

Comment on lines 944 to 949
samples_seen += len(experience)
should_critic_update = (samples_seen % self.critic_mini_batch_size_per_gpu == 0)
status = self.training_step(experience, global_step, local_step, should_critic_update)

if should_critic_update:
    critic_update_steps += 1
Contributor

critical

The logic for determining should_critic_update has the same flaw as in the policy worker. Using the modulo operator (%) is not safe for variable-sized micro-batches, and the samples_seen counter is not reset after an update. This can lead to skipped optimizer steps.

The logic should be to trigger an update when the accumulated sample count meets or exceeds the mini-batch threshold, and then reset the counter.

Suggested change
samples_seen += len(experience)
should_critic_update = (samples_seen % self.critic_mini_batch_size_per_gpu == 0)
status = self.training_step(experience, global_step, local_step, should_critic_update)
if should_critic_update:
    critic_update_steps += 1
samples_seen += len(experience)
should_critic_update = (samples_seen >= self.critic_mini_batch_size_per_gpu)
status = self.training_step(experience, global_step, local_step, should_critic_update)
if should_critic_update:
    critic_update_steps += 1
    samples_seen = 0

Comment on lines +94 to +97
raise ValueError(
f"mini_batch_size_per_gpu must be positive, got {self.mini_batch_size_per_gpu}. "
f"Check your configuration: worker_type={self.worker_type}, dp_size={self.dp_size}"
)
Contributor

medium

The logic for calculating mini_batch_size_per_gpu can result in a value of 0 if base_mini_batch_size * self.cfg.generator.n_samples_per_prompt is less than self.dp_size. While there is a check for this that raises a ValueError, the error message could be more informative by including the values of base_mini_batch_size and n_samples_per_prompt to help with debugging configuration issues.

Suggested change
raise ValueError(
f"mini_batch_size_per_gpu must be positive, got {self.mini_batch_size_per_gpu}. "
f"Check your configuration: worker_type={self.worker_type}, dp_size={self.dp_size}"
)
raise ValueError(
f"mini_batch_size_per_gpu must be positive, got {self.mini_batch_size_per_gpu}. "
f"Check your configuration: base_mini_batch_size={base_mini_batch_size}, "
f"n_samples_per_prompt={self.cfg.generator.n_samples_per_prompt}, dp_size={self.dp_size}"
)

Member

@tyler-griggs tyler-griggs left a comment

Leaving a handful of nits. I'm sanity checking the weighted average now, then should be done!

action_log_probs=batch["action_log_probs"],
base_action_log_probs=batch["base_action_log_probs"],
values=batch["values"],
rollout_logprobs=batch["rollout_log_probs"],
Member

I believe this key has an extra _ and should be rollout_logprobs

micro_train_batch_size_per_gpu: 1
micro_forward_batch_size_per_gpu: 1
use_dynamic_batching: false # Enable dynamic token-based batching instead of fixed batch sizes
max_token_len_per_gpu: 4096 # Maximum tokens per micro-batch when dynamic batching is enabled
Member

nit: can we suffix this with _train to mirror the micro_train and micro_forward params above?

self._chunks = self.data.chunk(self.sample_batch_size)
self._iter = iter(self._chunks)

logger.info(f"Total batch size: {self.total_batch_size}")
Member

nit: Is this helpful to see for you during training? I would recommend either removing it, or adding more detail so that it's clear what it means. In the training log, if I just see "Total batch size", that doesn't really give me enough information to know what's going on :D


logger.info(f"Total batch size: {self.total_batch_size}")

logger.info(f"Sizes: {[len(seq) for seq in data["sequences"]]} {dp_size=}")
Member

Nit: is this helpful for you to see during training? It's a bit noisy in my logs because it can be quite a long list.

synced_counts = [int(x) for x in counts_tensor.tolist()]

logger.info(
f"[Rank {dist.get_rank()}] BatchIterator sync - " f"Local counts: {local_copy}, Synced: {synced_counts}"
Member

nit: consider removing this, or at least making this a debug statement?

    return synced_counts
else:
    if self.dynamic_bsz:
        logger.info(f"BatchIterator - No distributed sync, using local: {local_counts}")
Member

Ditto the above

accumulation_steps = micro_batches_per_mini_batch
if dynamic_bsz:
    logger.info(
        f"Data {len(train_data)} | Dynamic Batching | Max Tokens per GPU {self.cfg.trainer.max_token_len_per_gpu} | Num Micro Batches {len(dataloader)}"
Member

Nit: can you change Data to Training Batch Size so that it's more clear in the logs what this is?

@tyler-griggs
Member

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a significant and valuable feature: dynamic batching based on token counts. This should improve GPU utilization by packing batches more efficiently. The implementation is well-structured, with the core logic encapsulated in skyrl_train/utils/dynamic_batching.py and integrated into a refactored BatchIterator. The addition of comprehensive unit and end-to-end tests, especially the loss consistency test, is commendable and crucial for such a complex change.

However, I've identified a critical issue in the training loop's logic for triggering optimizer steps in both the policy and critic workers. The current implementation can lead to missed gradient updates, especially with variable-sized micro-batches. I've provided detailed comments and suggestions to address this. Once this is fixed, this will be an excellent addition to the codebase.

Member

@tyler-griggs tyler-griggs left a comment

One last handful of nits, then I think we're ready to merge!

def _calculate_dynamic_micro_batch_counts(self, mini_batches: List[TrainingInputBatch]) -> List[int]:
    """Calculate the number of micro-batches needed for each mini-batch based on token counts."""
    if self.for_inference:
        self.max_token_len = getattr(
Member

nit: this is another case where you can assume the config value exists and you don't need to use getattr

else:
    if self.dynamic_bsz:
        logger.info(f"BatchIterator - No distributed sync, using local: {local_counts}")
        logger.info(f"[DEBUG] BatchIterator - No distributed sync, using local: {local_counts}")
Member

Oh, I actually meant this should be logger.debug so that we can optionally filter it out by setting the log level to info and avoid the debug statements.

# We assume that outputs are replicated within tp or sp group, otherwise this is not correct.
status = self.strategy.all_reduce(status)
for k in status:
    status[k] = len(experience) * status[k]
Member

To sanity check my understanding: this is now the sum total of all metrics across all DP and SP workers and summed across all samples (by weighting each sample by its micro-batch's final metrics), so there is technically a lot of overcounting. But, when we later call reduce_metrics we divide by the total number of samples (accounting for multiple epoch passes over the batch and for world_size), so it will take the avg across all of the samples (weighted by their micro batch's metrics). So then it all works out in the end. Does that match your understanding?
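
A toy check of that bookkeeping (illustrative numbers; reduce_metrics' real signature may differ):

# Each micro-batch's (already averaged) metric is scaled by its sample count and summed...
micro_batches = [
    {"n": 3, "loss": 0.9},
    {"n": 5, "loss": 0.5},
    {"n": 2, "loss": 1.1},
]
weighted_sum = sum(mb["n"] * mb["loss"] for mb in micro_batches)
total_samples = sum(mb["n"] for mb in micro_batches)

# ...and dividing by the total sample count recovers the per-sample weighted average,
# no matter how unevenly dynamic batching sized the micro-batches.
print(weighted_sum / total_samples)  # 0.74, vs. a naive per-micro-batch mean of ~0.83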

@staticmethod
def batch_to_experience(batch: TrainingInputBatch):
# TODO (sumanthrh): other keys are not permitted right now, can go into info
# TODO: this conversion is hidden right now, might need to be surfaced in worker explicitly.
Member

Can you please keep these comments? Same for the ones below (e.g., additional info, metadata)

- Multi-GPU distributed training produces consistent results
- Full e2e pipeline with inference engines and generator input
"""
# import asyncio
Member

nit: delete these?

shared_pg=None,
gpu_memory_utilization=0.8,
inference_engine_enable_sleep=True,
inference_engine_enable_sleep=inference_engine_enable_sleep, # Use parameter
Member

nit: delete comment?

max_turns=1,
use_conversation_multi_turn=True,
max_env_workers=10,
inference_engine_enable_sleep=True, # Add as parameter with default
Member

nit: delete comment?

@tyler-griggs
Member

Hm, okay I launched a run comparing dynamic batch sizes to main, and I'm getting some weird results (both in the metrics, and the actual rewards). The blue/green line is from running at main, and the gold line is from this PR. I think metric calculation needs some debugging:

[Three screenshots: training metric charts comparing the run from main (blue/green) with this PR (gold), captured 2025-08-24]

