Add a configuration for max tokens per microbatch #640
@@ -32,7 +32,7 @@
 from skyrl_train.distributed.ulysses import set_ulysses_sequence_parallel_group, apply_monkey_patch
 from skyrl_train.distributed.utils import init_custom_process_group
 from skyrl_train.utils.ppo_utils import PolicyLossRegistry, ppo_critic_loss, compute_approx_kl
-from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics
+from skyrl_train.workers.worker_utils import BatchIterator, TokenBasedBatchIterator, reduce_metrics
 from skyrl_train.dataset.replay_buffer import Experience
 from skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch
 from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient
@@ -322,16 +322,29 @@ def forward(
     ) -> TrainingOutputBatch:
         """Run forward pass on the input batch in inference mode.

-        This is a wrapper around `_forward_micro_batch` that runs in micro batches of `cfg.trainer.micro_forward_batch_size_per_gpu`.
+        This is a wrapper around `_forward_micro_batch` that runs in micro batches.
+        Uses token-based chunking if `max_tokens_per_microbatch` is configured, otherwise
+        falls back to sample-based chunking with `micro_forward_batch_size_per_gpu`.
         """
-        # run in micro batches of cfg.trainer.micro_forward_batch_size_per_gpu
-        # TODO (sumanthrh): this can be in the policy/critic impl if the micro batch size can be specific to policy, critic, etc.
-        micro_batches = data.chunk(self.cfg.trainer.micro_forward_batch_size_per_gpu)
+        # Check if token-based chunking is enabled
+        if self.cfg.trainer.max_tokens_per_microbatch > 0:
+            # Use token-based chunking
+            micro_batch_iterator = TokenBasedBatchIterator(
+                data=data,
+                max_tokens_per_microbatch=self.cfg.trainer.max_tokens_per_microbatch,
+            )
+        else:
+            micro_batch_iterator = BatchIterator(
+                data=data,
+                sample_batch_size=self.cfg.trainer.micro_forward_batch_size_per_gpu,
+                drop_last=False,
+            )

         outputs = []
-        for micro_batch in micro_batches:
+        for micro_batch in micro_batch_iterator:
             outputs.append(self._forward_micro_batch(micro_batch))
-        output = TrainingOutputBatch.cat(outputs)
+        output = micro_batch_iterator.reorder_microbatches(outputs)
         if output.device is not None and output.device != torch.device("cpu"):
             output = output.to("cpu")
         return output
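The implementation of `TokenBasedBatchIterator` itself is not shown in this diff. As a rough mental model only (the function below is a hypothetical sketch, not the SkyRL code; it assumes a per-sample token count is available, e.g. from an attention mask), token-based chunking greedily packs samples into microbatches under a token budget and keeps the original indices so per-microbatch outputs can be restored to input order afterwards, which is what `reorder_microbatches` is used for above.

```python
# Hypothetical sketch of token-budget chunking -- not the actual SkyRL implementation.
from typing import List


def chunk_by_token_budget(sample_lengths: List[int], max_tokens_per_microbatch: int) -> List[List[int]]:
    """Greedily group sample indices so each group stays under the token budget.

    Returning groups of original indices is what makes it possible to reorder
    per-microbatch outputs back into the original input order later.
    """
    microbatches: List[List[int]] = []
    current: List[int] = []
    current_tokens = 0
    for idx, n_tokens in enumerate(sample_lengths):
        # Start a new microbatch if adding this sample would exceed the budget.
        if current and current_tokens + n_tokens > max_tokens_per_microbatch:
            microbatches.append(current)
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += n_tokens
    if current:
        microbatches.append(current)
    return microbatches


# Example: with a 1024-token budget, long samples end up in smaller microbatches.
print(chunk_by_token_budget([800, 300, 200, 900, 100], max_tokens_per_microbatch=1024))
# [[0], [1, 2], [3, 4]]
```

The trade-off versus sample-based chunking is that microbatch sizes become data-dependent: batches of short sequences pack many samples per forward pass, while a single long sequence can occupy a microbatch by itself.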
@@ -640,15 +653,22 @@ def _normalize_mini_batch_size(self):

     def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch:
         global_step = train_data.metadata["global_step"]
-        dataloader = BatchIterator(
-            train_data, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
-        )
-
-        micro_batches_per_mini_batch = (
-            self.policy_mini_batch_size_per_gpu // self.cfg.trainer.micro_train_batch_size_per_gpu
-        )
-        # The number of steps (over micro batches) to accumulate gradients before taking an optimizer step.
-        accumulation_steps = micro_batches_per_mini_batch
+        if self.cfg.trainer.max_tokens_per_microbatch is not None:
+            dataloader = TokenBasedBatchIterator(
+                train_data, max_tokens_per_microbatch=self.cfg.trainer.max_tokens_per_microbatch
+            )
+            accumulation_steps = dataloader.num_microbatches
+        else:
+            dataloader = BatchIterator(
+                train_data, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
+            )
+
+            micro_batches_per_mini_batch = (
+                self.policy_mini_batch_size_per_gpu // self.cfg.trainer.micro_train_batch_size_per_gpu
+            )
+            # The number of steps (over micro batches) to accumulate gradients before taking an optimizer step.
+            accumulation_steps = micro_batches_per_mini_batch

         status_list = []
         all_metrics = defaultdict(list)
Comment on lines +657 to +671

Member: QQ: should the dynamic-batch-size packing happen at the batch level or the mini-batch level? The most efficient option would be the batch level, but then we won't have a good way to get mini batches afterwards. It seems like we should just do this at the mini-batch level first (you'd need to introduce an explicit loop over mini batches).

Member: This also allows us to naturally define the accumulation steps.

Contributor (author): Yeah, it should also be done at the global batch level, but I'm leaving that for a follow-up optimization; see the end of the PR description. Global-batch-level balancing is useful for assigning each worker a similar number of tokens. This PR's main goal is the max-tokens-per-microbatch parameter; the balancing of microbatches is a convenient add-on that gives a slight optimization. Here's why balancing at the microbatch level is useful:

Contributor (author): Could you elaborate on this a bit more?

Member: Oh sorry, here's what I meant: we have "mini batches" per worker that are set based on
So for example, with
For each mini batch per worker, we further split it into micro batches. My point was that the dynamic-batch-size bin-packing should happen on the worker-level mini batches, not the worker-level overall batch.
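The thread above is about where `accumulation_steps` comes from in each mode. As a concrete illustration (the numbers below are made up; only the config names come from the diff): with sample-based chunking it is a fixed ratio of configured sizes, while with token-based chunking it has to be read off the iterator after packing.

```python
# Illustrative numbers only; config names mirror the diff above.

# Sample-based chunking: accumulation steps are a fixed ratio of configured sizes.
policy_mini_batch_size_per_gpu = 8   # samples per optimizer step, per GPU
micro_train_batch_size_per_gpu = 2   # samples per forward/backward pass
accumulation_steps = policy_mini_batch_size_per_gpu // micro_train_batch_size_per_gpu
assert accumulation_steps == 4

# Token-based chunking: the split depends on actual sequence lengths, so the count
# must come from the iterator after packing (dataloader.num_microbatches in the diff).
# For example, a 4096-token budget over samples of length [3000, 2500, 1500, 900, 800, 600]
# packs greedily into [[3000], [2500, 1500], [900, 800, 600]], i.e. 3 microbatches,
# and that 3 becomes the accumulation step count for the whole worker batch.
```

That last point is what the reviewer is getting at: packing over the whole worker batch yields one accumulation-step count per batch, whereas packing per mini batch would keep one optimizer step per mini batch.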
@@ -660,7 +680,9 @@ def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch:
                 desc=f"Policy Train epoch [{epoch + 1}/{self.cfg.trainer.update_epochs_per_batch}]",
                 disable=not self.strategy.is_rank_0(),
             )
-            for local_step, experience in enumerate(pbar):
+            for local_step, batch in enumerate(pbar):
+                experience = BatchIterator.batch_to_experience(batch)
+                print(f"{experience.sequences=}")
                 status = self.training_step(
                     experience,
                     global_step,
@@ -717,6 +739,7 @@ def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch:

         # should return an `TrainingOutputBatch`
         output = TrainingOutputBatch()
+        # NOTE: No need to reorder anything here beacuse we average across the entire batch.
         output.metadata = {"train_status": status_mean}
         return output

@@ -932,19 +955,29 @@ def save_hf_model(self, export_dir: str, tokenizer):

     def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch:
         global_step = train_data.metadata["global_step"]
-        dataloader = BatchIterator(
-            train_data, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
-        )
+        # TODO: Move this to the base class since it's common to both policy and critic workers.
+        if self.cfg.trainer.max_tokens_per_microbatch is not None:
+            dataloader = TokenBasedBatchIterator(
+                train_data, max_tokens_per_microbatch=self.cfg.trainer.max_tokens_per_microbatch
+            )
+            accumulation_steps = dataloader.num_microbatches
+        else:
+            dataloader = BatchIterator(
+                train_data, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
+            )
+
+            # TODO: Make `num_microbatches` a property of the dataloader instead of computing it here.
+            # Ex: see the TokenBasedBatchIterator class.
+            micro_batches_per_mini_batch = (
+                self.policy_mini_batch_size_per_gpu // self.cfg.trainer.micro_train_batch_size_per_gpu
+            )
+            # The number of steps (over micro batches) to accumulate gradients before taking an optimizer step.
+            accumulation_steps = micro_batches_per_mini_batch

         torch.cuda.empty_cache()
         self.model.train()

-        micro_batches_per_mini_batch = (
-            self.critic_mini_batch_size_per_gpu // self.cfg.trainer.micro_train_batch_size_per_gpu
-        )
-        # The number of steps (over micro batches) to accumulate gradients before taking an optimizer step.
-        accumulation_steps = micro_batches_per_mini_batch

         all_metrics = defaultdict(list)
         critic_update_steps = 0
         for epoch in range(self.cfg.trainer.update_epochs_per_batch):
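The `# TODO: Move this to the base class` comment flags that the policy and critic `ppo_train` methods now duplicate the same dataloader/accumulation-steps selection. One possible shape for that refactor, purely as a hypothetical sketch (the function name and `mini_batch_size_per_gpu` parameter are illustrative, not part of this PR):

```python
from skyrl_train.workers.worker_utils import BatchIterator, TokenBasedBatchIterator


def build_train_dataloader(cfg, train_data, mini_batch_size_per_gpu):
    """Return (dataloader, accumulation_steps) for either chunking mode.

    The policy worker would pass its policy mini batch size and the critic worker
    its critic mini batch size, removing the duplicated if/else from both
    ppo_train methods.
    """
    if cfg.trainer.max_tokens_per_microbatch is not None:
        dataloader = TokenBasedBatchIterator(
            train_data, max_tokens_per_microbatch=cfg.trainer.max_tokens_per_microbatch
        )
        # With token packing, the microbatch count is only known after packing.
        accumulation_steps = dataloader.num_microbatches
    else:
        dataloader = BatchIterator(
            train_data,
            sample_batch_size=cfg.trainer.micro_train_batch_size_per_gpu,
            drop_last=False,
        )
        # Sample-based chunking: fixed ratio of mini batch size to micro batch size.
        accumulation_steps = mini_batch_size_per_gpu // cfg.trainer.micro_train_batch_size_per_gpu
    return dataloader, accumulation_steps
```

Passing the mini batch size explicitly would also avoid the critic path reading `self.policy_mini_batch_size_per_gpu`, as it does in the hunk above.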
@@ -953,7 +986,8 @@ def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch:
                 desc=f"Critic Train epoch [{epoch + 1}/{self.cfg.trainer.update_epochs_per_batch}]",
                 disable=not self.strategy.is_rank_0(),
             )
-            for local_step, experience in enumerate(pbar):
+            for local_step, batch in enumerate(pbar):
+                experience = BatchIterator.batch_to_experience(batch)
                 status = self.training_step(experience, global_step, local_step, accumulation_steps)
                 critic_update_steps += 1