Conversation
Code Review
This pull request introduces dynamic token-based batching to improve GPU utilization. The core logic is implemented in the new skyrl_train/utils/dynamic_batching.py file, which uses the Karmarkar-Karp algorithm for balanced partitioning. The BatchIterator in skyrl_train/workers/worker_utils.py has been significantly refactored to support both fixed-size and dynamic batching, encapsulating the batching logic that was previously in the worker's training loops. The training loops in skyrl_train/workers/worker.py have been updated to use the new iterator and handle variable-sized micro-batches with weighted loss accumulation. Comprehensive tests for the new functionality have been added in skyrl-train/tests/gpu/test_dynamic_batching.py.
My review includes suggestions to remove some debugging print statements, clean up unused variables, and fix a few issues, including a typo that would cause a runtime error and a broken test case. Overall, this is a solid implementation of a valuable feature.
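For readers unfamiliar with the partitioning idea behind dynamic_batching.py, below is a minimal greedy (LPT-style) sketch of splitting sequences into k token-balanced micro-batches. It is not the PR's Karmarkar-Karp implementation, and the function name and signature are illustrative only; the real algorithm achieves a tighter balance, but the input/output shape (per-sequence token counts in, index partitions out) is the same idea.

```python
import heapq
from typing import List


def greedy_balanced_partitions(token_counts: List[int], k: int) -> List[List[int]]:
    """Assign sequence indices to k partitions so per-partition token totals stay roughly equal."""
    # Min-heap of (tokens_assigned_so_far, partition_id): always grow the lightest partition.
    heap = [(0, pid) for pid in range(k)]
    heapq.heapify(heap)
    partitions: List[List[int]] = [[] for _ in range(k)]
    # Place the longest sequences first for a better balance.
    for idx in sorted(range(len(token_counts)), key=lambda i: -token_counts[i]):
        tokens, pid = heapq.heappop(heap)
        partitions[pid].append(idx)
        heapq.heappush(heap, (tokens + token_counts[idx], pid))
    return partitions


# Example: 6 sequences of varying length split into 2 token-balanced micro-batches.
print(greedy_balanced_partitions([512, 128, 256, 384, 64, 300], k=2))
```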
```python
total_sequences = 0
for exp in iterator:
    assert hasattr(exp, "sequences"), "Experience should have sequences"
    assert hasattr(exp, "attention_mask"), "Experience should have attention_mask"
    assert "should_step" in exp.info, "Experience info should have should_step"
    assert "accumulation_weight" in exp.info, "Experience info should have accumulation_weight"
    total_sequences += exp.sequences.shape[0]
```
This test appears to be broken. The BatchIterator yields a tuple (experience, should_step), but the loop for exp in iterator: attempts to treat exp as a single Experience object. This will cause an error when trying to access attributes like exp.sequences.
Additionally, the assertions for should_step and accumulation_weight in exp.info are incorrect, as the iterator does not add these to the info dictionary.
The loop should be for experience, should_step in iterator: and the assertions should be adjusted accordingly.
Suggested change:

```python
total_sequences = 0
for experience, should_step in iterator:
    assert hasattr(experience, "sequences"), "Experience should have sequences"
    assert hasattr(experience, "attention_mask"), "Experience should have attention_mask"
    assert isinstance(should_step, bool)
    total_sequences += experience.sequences.shape[0]
```
```python
response_mask: Integer[torch.Tensor, "batch_size seq_len"]
action_log_probs: Float[torch.Tensor, "batch_size seq_len"]
base_action_log_probs: Float[torch.Tensor, "batch_size seq_len"]
rollout_log_probs: Float[torch.Tensor, "batch_size seq_len"]
```
```python
rollout_action_logprobs = experience.rollout_logprobs
accumulation_weight = len(experience) / self.policy_mini_batch_size_per_gpu

print(f"Accumulation weight: {accumulation_weight}")
logger.info(f"Update: Loss: {loss.item()} | Accumulation weight: {accumulation_weight} | new loss: {loss * accumulation_weight}")
print(f"Update: Loss: {loss.item()} | Accumulation weight: {accumulation_weight} | new loss: {loss * accumulation_weight}")
```
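Aside from the debug prints (which the review suggests removing), the key line above is accumulation_weight = len(experience) / self.policy_mini_batch_size_per_gpu. A small illustration with made-up numbers of why that scaling makes accumulated variable-sized micro-batch losses reproduce the mini-batch mean loss:

```python
import torch

# Hypothetical sizes: one mini-batch of 8 samples split into two variable-sized micro-batches.
policy_mini_batch_size_per_gpu = 8
micro_batches = [(3, torch.tensor(0.9)), (5, torch.tensor(0.5))]  # (num_samples, mean loss)

accumulated = torch.tensor(0.0)
for num_samples, mean_loss in micro_batches:
    accumulation_weight = num_samples / policy_mini_batch_size_per_gpu
    accumulated = accumulated + mean_loss * accumulation_weight

# (3 * 0.9 + 5 * 0.5) / 8 = 0.65, i.e. the mean loss over the full mini-batch.
print(accumulated.item())
```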
```python
self._accumulation_weights = []
self._should_step_flags = []
self._micro_batch_sizes = []
self._micro_batch_indices = []
```
tyler-griggs left a comment:
This is great. Added a few small comments, mostly just nits. I want to spin this up and run tests as well, so please give me some time for that.
By the way, have you had a chance to test with sequence parallelism enabled as well?
```python
dp_group = self.device_mesh["dp"].get_group() if hasattr(self, 'device_mesh') else None

dynamic_bsz = getattr(self.cfg.trainer, 'use_dynamic_batching', False)
```
nit: We've been trying to a) set reasonable defaults in the base training config, and b) do config validation in utils/utils.py so that we don't have to add these getattr calls in code, which can often surprise people. We could just use self.cfg.trainer.use_dynamic_batching here.
Sounds good - would you prefer it raise an exception if the dynamic batching config is missing, or just default it to false in utils.py?
In this case, you can get away with doing neither. The base config will always be loaded, so the default will be False anyway. The only exception to this is tests, but it's actually helpful for tests to crash when the necessary config params aren't passed in.
```python
    return num_micro_batches


def balance_partitions_by_compute(partitions: List[List[int]], token_counts: List[int]) -> List[List[int]]:
```
nit: I didn't see this used anywhere -- is there an intended use for it in this PR?
```python
grad_norm = None
if (local_step + 1) % accumulation_steps == 0:
if should_policy_update:
```
Yeah, this very clearly improves the ergonomics of training_step, thanks :D
```python
from tests.gpu.utils import make_dummy_training_batch, init_worker_with_type, get_test_actor_config


def make_variable_length_training_batch(
```
nit: most of the tests in the top half of the file are CPU-only, so it'd be good to put them under the tests/cpu path so that they run quickly in CI/CD.
/gemini review |
Code Review
This pull request introduces a significant and valuable feature: dynamic batching based on token count. This is an excellent enhancement for improving GPU utilization, particularly with variable sequence lengths. The implementation is thorough, spanning configuration, data handling via a refactored BatchIterator, and updates to the core training loops. The use of the Karmarkar-Karp algorithm for balanced partitioning is a sophisticated and effective choice. The changes are well-supported by a comprehensive new test suite, including a crucial end-to-end test that verifies loss consistency between fixed and dynamic batching. My review includes a few suggestions to enhance efficiency and code clarity in some of the new utility functions and data structures. Overall, this is a high-quality contribution.
```python
def get_reverse_idx(idx_map: List[int]) -> List[int]:
    """Build the inverse of an index mapping."""
    reverse_idx_map = copy.deepcopy(idx_map)
    for i, idx in enumerate(idx_map):
        reverse_idx_map[idx] = i
    return reverse_idx_map
```
Using copy.deepcopy for a list of integers is inefficient. Since you are overwriting the list elements anyway, you can pre-allocate a list of zeros and then fill it. This avoids the overhead of deep copying and is more performant.
Suggested change:

```python
def get_reverse_idx(idx_map: List[int]) -> List[int]:
    """Build the inverse of an index mapping."""
    reverse_idx_map = [0] * len(idx_map)
    for i, idx in enumerate(idx_map):
        reverse_idx_map[idx] = i
    return reverse_idx_map
```
```python
micro_batch_data = {}
for key, value in self.items():
    if value is not None:
        if isinstance(value, torch.Tensor):
            indices = torch.tensor(partition, dtype=torch.long)
            micro_batch_data[key] = value[indices]
        else:
            raise ValueError(f"Unsupported type {type(value)} for key {key}")
    else:
        micro_batch_data[key] = value
```
Creating the indices tensor inside the loop is inefficient, as it will be recreated for every key in the batch. It's better to create it once outside the loop. Additionally, it's good practice to specify the device for the new tensor to match the device of the batch tensors (self.device), which will avoid implicit CPU-to-GPU copies during indexing if the batch is on a GPU.
Suggested change:

```python
micro_batch_data = {}
indices = torch.tensor(partition, dtype=torch.long, device=self.device)
for key, value in self.items():
    if value is not None:
        if isinstance(value, torch.Tensor):
            micro_batch_data[key] = value[indices]
        else:
            raise ValueError(f"Unsupported type {type(value)} for key {key}")
    else:
        micro_batch_data[key] = value
```
```python
"""Return the iterator itself."""
self._reset_iterator()
for i in range(len(self._all_micro_batches)):
    if self._current_idx >= len(self._all_micro_batches):
        self._reset_iterator()
        raise StopIteration

    micro_batch = self._all_micro_batches[self._current_idx]

    exp = self.batch_to_experience(micro_batch)

    if self.dynamic_bsz:
        exp.info["micro_batch_utilization"] = micro_batch["attention_mask"].sum().item() / self.max_token_len

    self._current_idx += 1

def __next__(self) -> Experience:
    try:
        batch = next(self._iter)
        exp = self.batch_to_experience(batch)
        return exp
    except StopIteration:
        self._iter = iter(self._chunks)
        raise StopIteration
    yield exp
```
This generator function can be simplified. The for i in range(...) loop with manual index tracking (self._current_idx) and the StopIteration check are overly complex for a generator. You can directly iterate over self._all_micro_batches, which makes the code cleaner, more Pythonic, and less prone to off-by-one errors.
Suggested change:

```python
"""Return the iterator itself."""
self._reset_iterator()
for micro_batch in self._all_micro_batches:
    exp = self.batch_to_experience(micro_batch)
    if self.dynamic_bsz:
        exp.info["micro_batch_utilization"] = micro_batch["attention_mask"].sum().item() / self.max_token_len
    yield exp
```
/gemini review |
Yes, tested briefly with SP |
```python
status_mean = reduce_metrics(all_metrics)
status_mean["policy_update_steps"] = policy_update_steps / accumulation_steps
status_mean = reduce_metrics(all_metrics, dataloader.total_batch_size * self.cfg.trainer.update_epochs_per_batch * self._world_size)
```
Pretty sure this is right, but worth taking a second look to make sure the weighted average works in all cases.
Code Review
This pull request introduces a significant and well-implemented feature: dynamic token-based batching. The use of the Karmarkar-Karp algorithm for creating balanced partitions is a solid approach. The related changes to gradient accumulation, metric calculations, and the new BatchIterator are thorough and correctly handle variable micro-batch sizes. The accompanying tests are comprehensive, particularly the end-to-end test ensuring loss consistency between fixed and dynamic batching, which adds a great deal of confidence in the changes. I've identified a critical issue with a duplicate key in a TypedDict that will cause a SyntaxError, along with a few medium-severity suggestions to clean up unused code and improve type correctness. Overall, this is a high-quality contribution.
```python
response_mask: Integer[torch.Tensor, "batch_size seq_len"]
action_log_probs: Float[torch.Tensor, "batch_size seq_len"]
base_action_log_probs: Float[torch.Tensor, "batch_size seq_len"]
rollout_log_probs: Float[torch.Tensor, "batch_size seq_len"]
```
The key rollout_log_probs is defined twice in the TrainingInput TypedDict (here and on line 357, where it appears as rollout_logprobs, which looks like a typo). This will raise a SyntaxError at runtime because of the duplicate key. Please remove one of the definitions to resolve the error. Given that the rest of the code uses rollout_log_probs, the entry on line 357 is likely the typo and this one is the intended definition.
```python
adapted from the reference implementation to work with skyrl-train's architecture.
"""

import copy
```
```python
def get_reverse_idx(idx_map: List[int]) -> List[int]:
    """Build the inverse of an index mapping."""
    reverse_idx_map = [0] * len(idx_map)
    for i, idx in enumerate(idx_map):
        reverse_idx_map[idx] = i
    return reverse_idx_map
```
nit: remove get_reverse_idx
```python
worker_type: Type of worker ("policy" or "critic")
dp_group: Distributed process group for synchronization
dynamic_bsz: Whether to use dynamic token-based batching
mini_batch_size_per_gpu: Override mini-batch size per GPU
```
nit: remove mini_batch_size_per_gpu
```python
"""Synchronize micro-batch counts across distributed workers."""
if self.dp_group is not None and dist.is_initialized():
    local_copy = local_counts.copy()
    counts_tensor = torch.tensor(local_counts, dtype=torch.float32, device="cuda")
```
The counts_tensor is created with dtype=torch.float32 to store integer counts of micro-batches. While this may work for small numbers, using floating-point numbers for counts can introduce precision issues. It's more robust to use an integer type like torch.int64 for counts, which is also supported by dist.all_reduce.
Suggested change:

```python
counts_tensor = torch.tensor(local_counts, dtype=torch.int64, device="cuda")
```
tyler-griggs left a comment:
Making a brief pass, I hit a couple of errors when first launching training (e.g., the missing mini_batch_size_per_gpu field in BatchIterator)
```yaml
# Truncated Importance Sampling as proposed in https://fengyao.notion.site/off-policy-rl
tis_imp_ratio_cap: -1.0
use_tis: false
max_seq_len: 512
```
Is this used? Or can it be removed?
Yea, I was confused by this because it is needed here but wasn't already in the config.
Never mind, I see that this is added in config validation. I'll just hardcode it in the test.
/gemini review |
Code Review
This pull request introduces a significant and valuable feature: dynamic token-based batching. The goal is to improve GPU utilization by creating micro-batches based on the number of tokens rather than a fixed number of sequences. The implementation is comprehensive, touching the configuration, data structures, training loop, and adding a new BatchIterator with the core logic. The changes also include a sophisticated partitioning algorithm (Karmarkar-Karp) to create balanced batches. The inclusion of extensive unit and end-to-end tests, especially the one for loss consistency, is commendable and provides high confidence in the correctness of this complex change. I've identified a couple of critical issues, including a typo in an argument name and an incorrect attribute assignment that would lead to runtime errors, along with some suggestions for improving maintainability and robustness. Overall, this is a well-executed feature addition.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
/gemini review |
Code Review
This pull request introduces a significant feature: dynamic token-based batching to improve GPU utilization. The implementation includes a new dynamic_batching.py module with the Karmarkar-Karp partitioning algorithm and a major refactoring of the BatchIterator. The worker logic is updated to handle variable-sized batches for gradient accumulation and metric calculation. The changes are well-supported by extensive new tests, including an end-to-end test for loss consistency, which is excellent.
I've found one critical issue and one of medium severity. The critical issue is a bug in the CriticWorkerBase.ppo_train method where the BatchIterator is initialized with incorrect arguments, which will cause a runtime error. The medium severity issue is a confusing error message in an assertion. Once these are addressed, the PR should be in great shape.
```python
    cfg=self.cfg,
    fg=self.cfg,
    dp_size=self.mesh_rank.dp_size,
    drop_last=False,
    worker_type="critic",
    dp_group=dp_group,
    dynamic_bsz=self.cfg.trainer.use_dynamic_batching
)
```
The BatchIterator is being initialized incorrectly in CriticWorkerBase.ppo_train. The required positional argument train_data is missing, and an unexpected keyword argument fg is provided. This will cause a TypeError at runtime.
Suggested change:

```python
dataloader = BatchIterator(
    train_data,
    cfg=self.cfg,
    dp_size=self.mesh_rank.dp_size,
    drop_last=False,
    worker_type="critic",
    dp_group=dp_group,
    dynamic_bsz=self.cfg.trainer.use_dynamic_batching
)
```
```python
)
if not dynamic_bsz:
    assert dataloader.micro_batches_per_mini_batch <= len(dataloader), f"Accumulation Steps {dataloader.micro_batches_per_mini_batch} cannot be less than number of micro batches in total {len(dataloader)}"
```
The assertion logic A <= B is correct, but the error message f"A cannot be less than B" is inverted. It should state that A cannot be greater than B to avoid confusion during debugging.
Suggested change:

```python
assert dataloader.micro_batches_per_mini_batch <= len(dataloader), f"Accumulation Steps {dataloader.micro_batches_per_mini_batch} cannot be greater than number of micro batches in total {len(dataloader)}"
```
/gemini review |
Code Review
This pull request introduces a significant feature: dynamic token-based batching to improve GPU utilization. The implementation leverages the Karmarkar-Karp algorithm for balanced partitioning and refactors the BatchIterator to handle both fixed and dynamic batching modes. The changes are extensive, touching configuration, data structures, worker training loops, and adding a comprehensive suite of new tests, including an important end-to-end loss consistency test. Overall, this is a well-structured enhancement. I've identified a couple of critical issues that would lead to runtime errors, which should be addressed.
```python
dataloader = BatchIterator(
    train_data, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
    train_data,
    cfg=self.cfg,
    fg=self.cfg,
    dp_size=self.mesh_rank.dp_size,
    drop_last=False,
    worker_type="critic",
    dp_group=dp_group,
    dynamic_bsz=self.cfg.trainer.use_dynamic_batching
)
```
There's an invalid keyword argument fg being passed to the BatchIterator constructor. This will raise a TypeError at runtime as BatchIterator does not accept an fg parameter. This appears to be a typo and should be removed.
Suggested change:

```python
dataloader = BatchIterator(
    train_data,
    cfg=self.cfg,
    dp_size=self.mesh_rank.dp_size,
    drop_last=False,
    worker_type="critic",
    dp_group=dp_group,
    dynamic_bsz=self.cfg.trainer.use_dynamic_batching
)
```
/gemini review |
Code Review
This pull request introduces a significant and valuable feature: dynamic batching based on token count. This should improve GPU utilization by packing batches more efficiently. The implementation is well-structured, with new utilities for balanced partitioning and a refactored BatchIterator. The addition of comprehensive unit and integration tests is excellent. However, I've found a critical issue in the gradient accumulation logic for both the policy and critic trainers that could lead to skipped optimizer steps. The fix is straightforward and I've provided suggestions.
```python
samples_seen += len(experience)

should_policy_update = (samples_seen % self.policy_mini_batch_size_per_gpu == 0)

status = self.training_step(
    experience,
    global_step,
    local_step,
    accumulation_steps,
    should_policy_update
)
policy_update_steps += 1
if should_policy_update:
    policy_update_steps += 1
```
The logic for determining when to perform an optimizer step (should_policy_update) is flawed. The use of samples_seen % self.policy_mini_batch_size_per_gpu == 0 is not robust for variable-sized micro-batches from dynamic batching, as samples_seen may never be an exact multiple of the mini-batch size, causing optimizer steps to be skipped. Additionally, samples_seen accumulates over the entire epoch without being reset after an optimizer step, which is incorrect for mini-batch-based gradient accumulation.
A more robust approach is to use a counter that tracks samples since the last update, triggers an update when the threshold is met, and then resets.
Suggested change:

```python
samples_seen += len(experience)
should_policy_update = samples_seen >= self.policy_mini_batch_size_per_gpu
status = self.training_step(
    experience,
    global_step,
    local_step,
    should_policy_update
)
if should_policy_update:
    policy_update_steps += 1
    samples_seen = 0
```
This is incorrect, right? Since we first split the training batch into mini batches, then split each mini batch into micro batches, it is guaranteed that we always sum up to the mini batch size.
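A quick numeric illustration of that reply (hypothetical sizes): because the micro-batches are produced by splitting a mini-batch, the running sample count lands exactly on the mini-batch boundary, so the modulo check does fire once per mini-batch.

```python
policy_mini_batch_size_per_gpu = 8
micro_batch_sizes = [3, 3, 2]  # variable-sized micro-batches carved out of one 8-sample mini-batch

samples_seen = 0
for size in micro_batch_sizes:
    samples_seen += size
    should_policy_update = (samples_seen % policy_mini_batch_size_per_gpu == 0)
    print(samples_seen, should_policy_update)  # 3 False, 6 False, 8 True
```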
```python
samples_seen += len(experience)
should_critic_update = (samples_seen % self.critic_mini_batch_size_per_gpu == 0)
status = self.training_step(experience, global_step, local_step, should_critic_update)

if should_critic_update:
    critic_update_steps += 1
```
The logic for determining should_critic_update has the same flaw as in the policy worker. Using the modulo operator (%) is not safe for variable-sized micro-batches, and the samples_seen counter is not reset after an update. This can lead to skipped optimizer steps.
The logic should be to trigger an update when the accumulated sample count meets or exceeds the mini-batch threshold, and then reset the counter.
Suggested change:

```python
samples_seen += len(experience)
should_critic_update = (samples_seen >= self.critic_mini_batch_size_per_gpu)
status = self.training_step(experience, global_step, local_step, should_critic_update)
if should_critic_update:
    critic_update_steps += 1
    samples_seen = 0
```
```python
raise ValueError(
    f"mini_batch_size_per_gpu must be positive, got {self.mini_batch_size_per_gpu}. "
    f"Check your configuration: worker_type={self.worker_type}, dp_size={self.dp_size}"
)
```
The logic for calculating mini_batch_size_per_gpu can result in a value of 0 if base_mini_batch_size * self.cfg.generator.n_samples_per_prompt is less than self.dp_size. While there is a check for this that raises a ValueError, the error message could be more informative by including the values of base_mini_batch_size and n_samples_per_prompt to help with debugging configuration issues.
Suggested change:

```python
raise ValueError(
    f"mini_batch_size_per_gpu must be positive, got {self.mini_batch_size_per_gpu}. "
    f"Check your configuration: base_mini_batch_size={base_mini_batch_size}, "
    f"n_samples_per_prompt={self.cfg.generator.n_samples_per_prompt}, dp_size={self.dp_size}"
)
```
tyler-griggs left a comment:
Leaving a handful of nits. I'm sanity-checking the weighted average now, then I should be done!
```python
action_log_probs=batch["action_log_probs"],
base_action_log_probs=batch["base_action_log_probs"],
values=batch["values"],
rollout_logprobs=batch["rollout_log_probs"],
```
I believe this key has an extra _ and should be rollout_logprobs
```yaml
micro_train_batch_size_per_gpu: 1
micro_forward_batch_size_per_gpu: 1
use_dynamic_batching: false # Enable dynamic token-based batching instead of fixed batch sizes
max_token_len_per_gpu: 4096 # Maximum tokens per micro-batch when dynamic batching is enabled
```
nit: can we suffix this with _train to mirror the micro_train and micro_forward params above?
```python
self._chunks = self.data.chunk(self.sample_batch_size)
self._iter = iter(self._chunks)

logger.info(f"Total batch size: {self.total_batch_size}")
```
nit: Is this helpful to see for you during training? I would recommend either removing it, or adding more detail so that it's clear what it means. In the training log, if I just see "Total batch size", that doesn't really give me enough information to know what's going on :D
```python
logger.info(f"Total batch size: {self.total_batch_size}")

logger.info(f"Sizes: {[len(seq) for seq in data["sequences"]]} {dp_size=}")
```
Nit: is this helpful for you to see during training? It's a bit noisy in my logs because it can be quite a long list.
```python
synced_counts = [int(x) for x in counts_tensor.tolist()]

logger.info(
    f"[Rank {dist.get_rank()}] BatchIterator sync - " f"Local counts: {local_copy}, Synced: {synced_counts}"
```
nit: consider removing this, or at least making this a debug statement?
```python
    return synced_counts
else:
    if self.dynamic_bsz:
        logger.info(f"BatchIterator - No distributed sync, using local: {local_counts}")
```
```python
accumulation_steps = micro_batches_per_mini_batch
if dynamic_bsz:
    logger.info(
        f"Data {len(train_data)} | Dynamic Batching | Max Tokens per GPU {self.cfg.trainer.max_token_len_per_gpu} | Num Micro Batches {len(dataloader)}"
```
Nit: can you change Data to Training Batch Size so that it's clearer in the logs what this is?
/gemini review |
Code Review
This pull request introduces a significant and valuable feature: dynamic batching based on token counts. This should improve GPU utilization by packing batches more efficiently. The implementation is well-structured, with the core logic encapsulated in skyrl_train/utils/dynamic_batching.py and integrated into a refactored BatchIterator. The addition of comprehensive unit and end-to-end tests, especially the loss consistency test, is commendable and crucial for such a complex change.
However, I've identified a critical issue in the training loop's logic for triggering optimizer steps in both the policy and critic workers. The current implementation can lead to missed gradient updates, especially with variable-sized micro-batches. I've provided detailed comments and suggestions to address this. Once this is fixed, this will be an excellent addition to the codebase.
tyler-griggs left a comment:
One last handful of nits, then I think we're ready to merge!
```python
def _calculate_dynamic_micro_batch_counts(self, mini_batches: List[TrainingInputBatch]) -> List[int]:
    """Calculate the number of micro-batches needed for each mini-batch based on token counts."""
    if self.for_inference:
        self.max_token_len = getattr(
```
nit: this is another case where you can assume the config value exists and you don't need to use getattr
```python
else:
    if self.dynamic_bsz:
        logger.info(f"BatchIterator - No distributed sync, using local: {local_counts}")
        logger.info(f"[DEBUG] BatchIterator - No distributed sync, using local: {local_counts}")
```
Oh, I actually meant this should be logger.debug so that we can optionally filter it out by setting the log level to info and avoid the debug statements.
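That is, something like the following (sketch of the suggested change):

```python
logger.debug(f"BatchIterator - No distributed sync, using local: {local_counts}")
```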
```python
# We assume that outputs are replicated within tp or sp group, otherwise this is not correct.
status = self.strategy.all_reduce(status)
for k in status:
    status[k] = len(experience) * status[k]
```
To sanity-check my understanding: this is now the sum total of all metrics across all DP and SP workers, summed across all samples (by weighting each sample by its micro-batch's final metrics), so there is technically a lot of overcounting. But when we later call reduce_metrics we divide by the total number of samples (accounting for multiple epoch passes over the batch and for world_size), so it takes the average across all of the samples (weighted by their micro-batch's metrics). So it all works out in the end. Does that match your understanding?
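A small numeric check of that reasoning (hypothetical values; the actual reduce_metrics signature may differ): each step contributes len(experience) * metric, and the later division by the total sample count recovers the per-sample weighted average.

```python
# Two micro-batches on one worker, each reporting its mean loss.
micro_batches = [{"num_samples": 3, "loss": 0.9}, {"num_samples": 5, "loss": 0.5}]

# Each training step contributes len(experience) * metric, so the raw sums "overcount" ...
weighted_sum = sum(mb["num_samples"] * mb["loss"] for mb in micro_batches)  # 2.7 + 2.5 = 5.2

# ... but dividing by the total number of samples at reduce time recovers the weighted average.
total_samples = sum(mb["num_samples"] for mb in micro_batches)  # 8
print(weighted_sum / total_samples)  # 0.65
```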
```python
@staticmethod
def batch_to_experience(batch: TrainingInputBatch):
    # TODO (sumanthrh): other keys are not permitted right now, can go into info
    # TODO: this conversion is hidden right now, might need to be surfaced in worker explicitly.
```
Can you please keep these comments? Same for the ones below (e.g., additional info, metadata)
```python
    - Multi-GPU distributed training produces consistent results
    - Full e2e pipeline with inference engines and generator input
    """
    # import asyncio
```
```python
shared_pg=None,
gpu_memory_utilization=0.8,
inference_engine_enable_sleep=True,
inference_engine_enable_sleep=inference_engine_enable_sleep,  # Use parameter
```
```python
max_turns=1,
use_conversation_multi_turn=True,
max_env_workers=10,
inference_engine_enable_sleep=True,  # Add as parameter with default
```



Usage

- trainer.max_token_len_per_gpu_forward
- trainer.max_token_len_per_gpu
- trainer.use_dynamic_batching

Details

- BatchIterator …
- BatchIterator chunks received data by mini_batch_size_per_gpu …
- ppo_train … to be a weighted average

Testing

- BatchIterator functionality