[skyrl-train] Refactor training loop structure to explicitly batch at two levels (minibatch -> microbatch) #817
Changes from all commits: 87fcd3f, 990cb1a, 94b5983, 108d960, d301ebf, c9ad4c9, e3b9b08, 5038324, a31d8a4, 25d30a1, c5fbc25, 74dade4, 6c8212d, 70776f2
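For orientation, here is a minimal sketch of the two-level structure this PR moves to: the outer level iterates minibatches (one optimizer step each), and the inner level iterates microbatches (one forward/backward each, with gradients accumulated across them). This is an editor's illustration in plain PyTorch; `train_on_batch`, `chunks`, and the toy model are assumptions for the sketch, not code from the PR.

```python
import torch
from torch import nn

def chunks(t: torch.Tensor, size: int):
    """Yield successive chunks of at most `size` rows."""
    for i in range(0, t.shape[0], size):
        yield t[i : i + size]

def train_on_batch(model, optimizer, batch, targets, mini_bs=8, micro_bs=2):
    loss_fn = nn.MSELoss()
    # Level 1: one optimizer step per minibatch.
    for mb_x, mb_y in zip(chunks(batch, mini_bs), chunks(targets, mini_bs)):
        micro = list(zip(chunks(mb_x, micro_bs), chunks(mb_y, micro_bs)))
        weight = 1.0 / len(micro)  # mirrors `microbatch_weight` in the PR
        # Level 2: one forward/backward per microbatch; gradients accumulate.
        for x, y in micro:
            loss = loss_fn(model(x), y) * weight
            loss.backward()
        optimizer.step()
        optimizer.zero_grad()

model = nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
train_on_batch(model, opt, torch.randn(16, 4), torch.randn(16, 1))
```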
The first file in the diff contains the policy-worker and critic-worker training code.
```diff
@@ -622,9 +622,18 @@ def _normalize_mini_batch_size(self):
             self.cfg.trainer.policy_mini_batch_size * self.cfg.generator.n_samples_per_prompt // dp_size
         )

-    def forward_backward(self, experience: Experience, accumulation_steps: int) -> Dict[str, float]:
-        """
-        Perform the forward and backward pass for one micro-batch.
-        """
+    def forward_backward(self, experience: Experience, microbatch_weight: float) -> Dict[str, float]:
+        """Perform the forward and backward pass for one micro-batch.
+
+        Args:
+            experience: The microbatch data to run the forward and backward pass on.
+            microbatch_weight: Weight of the microbatch, used to scale the loss contribution
+                for the microbatch. For example, if you accumulate gradients over 2 microbatches,
+                then each microbatch should have a weight of 1/2.
+
+        Returns:
+            Dict containing the status (including loss and some other metrics)
+            for the microbatch.
+        """
         self.model.train()
         experience.to_device(torch.cuda.current_device())
```
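The docstring's `1/2` example generalizes: scaling each microbatch loss by `1 / num_microbatches` before `backward()` makes the accumulated gradient equal the gradient of the mean loss over the whole minibatch, for equal-sized microbatches and mean-reduced losses. A quick self-contained check of that identity, using a toy linear model rather than anything from skyrl-train:

```python
import copy
import torch
from torch import nn

torch.manual_seed(0)
x, y = torch.randn(6, 3), torch.randn(6, 1)
base = nn.Linear(3, 1)

# One backward over the full minibatch.
m1 = copy.deepcopy(base)
nn.functional.mse_loss(m1(x), y).backward()

# Three microbatch backwards, each loss scaled by microbatch_weight = 1/3.
m2 = copy.deepcopy(base)
for xb, yb in zip(x.chunk(3), y.chunk(3)):
    (nn.functional.mse_loss(m2(xb), yb) * (1 / 3)).backward()

# The accumulated gradients match the full-minibatch gradient.
for p1, p2 in zip(m1.parameters(), m2.parameters()):
    assert torch.allclose(p1.grad, p2.grad, atol=1e-6)
print("accumulated gradients match the full-minibatch gradient")
```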
```diff
@@ -689,7 +698,7 @@ def forward_backward(self, experience: Experience, accumulation_steps: int) -> Dict[str, float]:
             kl_loss_term = kl_loss * self.cfg.trainer.algorithm.kl_loss_coef

         loss = policy_loss + kl_loss_term - entropy_loss_term
-        loss = loss / accumulation_steps
+        loss = loss * microbatch_weight
         self.strategy.backward(loss, self.model, self.optimizer)

         status = {
```
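There is a subtle behavioral fix hiding in this one-line change. Previously the scale was a fixed `1 / accumulation_steps` derived from the configured minibatch size; with `drop_last=False`, a trailing minibatch can contain fewer microbatches, in which case a fixed divisor under-weights it (and, in the old loop below, the `(local_step + 1) % accumulation_steps == 0` condition appears it could leave trailing microbatches without an optimizer step of their own). Computing `microbatch_weight = 1.0 / num_microbatches` per minibatch sidesteps both. An illustrative calculation, with made-up sizes (not from the PR's configs):

```python
# Made-up sizes for illustration only.
mini_batch_size_per_gpu = 8
micro_batch_size_per_gpu = 2

# Old scheme: one fixed divisor for every microbatch.
accumulation_steps = mini_batch_size_per_gpu // micro_batch_size_per_gpu  # 4

# New scheme: the weight is computed per minibatch.
minibatches = [8, 8, 6]  # drop_last=False leaves a short tail
for mb in minibatches:
    num_microbatches = -(-mb // micro_batch_size_per_gpu)  # ceil division
    microbatch_weight = 1.0 / num_microbatches
    print(mb, num_microbatches, microbatch_weight)
# 8 4 0.25
# 8 4 0.25
# 6 3 0.333...  <- the old fixed 1/4 would under-weight this minibatch
```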
```diff
@@ -715,86 +724,102 @@ def optim_step(self) -> float:

     def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch:
         global_step = train_data.metadata["global_step"]
-        dataloader = BatchIterator(
-            train_data, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
+        minibatch_iterator = BatchIterator(
+            train_data, sample_batch_size=self.policy_mini_batch_size_per_gpu, drop_last=False
         )

-        micro_batches_per_mini_batch = (
-            self.policy_mini_batch_size_per_gpu // self.cfg.trainer.micro_train_batch_size_per_gpu
-        )
-        # The number of steps (over micro batches) to accumulate gradients before taking an optimizer step.
-        accumulation_steps = micro_batches_per_mini_batch
-
         status_list = []
         all_metrics = defaultdict(list)
         policy_update_steps = 0
+        num_minibatches = len(minibatch_iterator)
+        local_step = 0
+
+        def record_status(status: Dict[str, float]):
+            """Record the aggregated (all-reduced) training status for the latest microbatch.
+            Also, update the progress bar with the latest status."""
+            status["policy_lr"] = self.scheduler.get_last_lr()[0]
+
+            # for DP
+            # TODO (sumanthrh): this assumes all workers are data parallel.
+            # We assume that outputs are replicated within tp or sp group, otherwise this is not correct.
+            status = self.strategy.all_reduce(status)
+
+            # weighted mean for kl
+            # TODO (sumanthrh): this weighted mean is no longer correct since we use the max response length in the batch.
+            # we can log this in the driver
+            # if "kl" in status:
+            #     status["kl"] *= status["response_length"]
+            #     status["kl"] /= status["response_length"]
+
+            short_status = {}
+
+            if "policy_loss" in status:
+                short_status = {
+                    "pg": status["policy_loss"],
+                    "glen": status["response_length"],
+                    "policy_lr": status["policy_lr"],
+                    "ent": status["policy_entropy"],
+                }
+                if "raw_grad_norm" in status:
+                    short_status["grad_norm"] = status["raw_grad_norm"]
+                if "reward" in status:
+                    short_status["rm"] = status["reward"]
+
+            if "critic_loss" in status:
+                short_status["cri"] = status["critic_loss"]
+                short_status["vals"] = status["values"]
+                short_status["cri_lr"] = status["critic_lr"]
+
+            if "ptx_loss" in status:
+                short_status["ptx"] = status["ptx_loss"]
+
+            status_list.append(status)
+            for k, v in status.items():
+                all_metrics[k].append(v)
+            minibatch_pbar.set_postfix(short_status)

         for epoch in range(self.cfg.trainer.update_epochs_per_batch):
-            pbar = tqdm(
-                dataloader,
+            minibatch_pbar = tqdm(
+                minibatch_iterator,
                 desc=f"Policy Train epoch [{epoch + 1}/{self.cfg.trainer.update_epochs_per_batch}]",
                 disable=not self.strategy.is_rank_0(),
             )
-            for local_step, experience in enumerate(pbar):
-                status = self.forward_backward(experience, accumulation_steps)
-
-                if (local_step + 1) % accumulation_steps == 0:
-                    grad_norm = self.optim_step()
-                    if grad_norm is not None:
-                        status["raw_grad_norm"] = grad_norm
+            for minibatch in minibatch_pbar:
+                microbatch_iterator = BatchIterator(
+                    minibatch, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
+                )
+                num_microbatches = len(microbatch_iterator)
+                microbatch_weight = 1.0 / num_microbatches
+
+                for microbatch_idx, microbatch in enumerate(microbatch_iterator):
+                    microbatch_experience = BatchIterator.batch_to_experience(microbatch)
+                    status = self.forward_backward(microbatch_experience, microbatch_weight=microbatch_weight)
+
+                    # Record status for all but the last microbatch in the minibatch.
+                    # The last microbatch should be recorded after the optimizer step.
+                    if microbatch_idx < num_microbatches - 1:
+                        if self.record_memory:
+                            self.save_memory_snapshot(global_step, local_step)
+                        record_status(status)
+
+                    # Local step counts the number of processed microbatches.
+                    local_step += 1
+
+                grad_norm = self.optim_step()
+                if grad_norm is not None:
+                    status["raw_grad_norm"] = grad_norm

                 if self.record_memory:
                     self.save_memory_snapshot(global_step, local_step)

-                status["policy_lr"] = self.scheduler.get_last_lr()[0]
-
                 policy_update_steps += 1

-                # for DP
-                # TODO (sumanthrh): this assumes all workers are data parallel.
-                # We assume that outputs are replicated within tp or sp group, otherwise this is not correct.
-                status = self.strategy.all_reduce(status)
-
-                # weighted mean for kl
-                # TODO (sumanthrh): this weighted mean is no longer correct since we use the max response length in the batch.
-                # we can log this in the driver
-                # if "kl" in status:
-                #     status["kl"] *= status["response_length"]
-                #     status["kl"] /= status["response_length"]
-
-                short_status = {}
-
-                if "policy_loss" in status:
-                    short_status = {
-                        "pg": status["policy_loss"],
-                        "glen": status["response_length"],
-                        "policy_lr": status["policy_lr"],
-                        "ent": status["policy_entropy"],
-                    }
-                    if "raw_grad_norm" in status:
-                        short_status["grad_norm"] = status["raw_grad_norm"]
-                    if "reward" in status:
-                        short_status["rm"] = status["reward"]
-
-                if "critic_loss" in status:
-                    short_status["cri"] = status["critic_loss"]
-                    short_status["vals"] = status["values"]
-                    short_status["cri_lr"] = status["critic_lr"]
-
-                if "ptx_loss" in status:
-                    short_status["ptx"] = status["ptx_loss"]
-
-                status_list.append(status)
-                for k, v in status.items():
-                    all_metrics[k].append(v)
-                pbar.set_postfix(short_status)
+                # Record status for the last microbatch in the minibatch.
+                record_status(status)

         torch.distributed.barrier()
         # not needed beyond status logging
         all_metrics.pop("response_length", None)

         status_mean = reduce_metrics(all_metrics)
-        status_mean["policy_update_steps"] = policy_update_steps / accumulation_steps
+        status_mean["policy_update_steps"] = num_minibatches * self.cfg.trainer.update_epochs_per_batch

         # should return an `TrainingOutputBatch`
         output = TrainingOutputBatch()
```

> **Collaborator** (on the new `policy_update_steps` line): good catch...
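The `policy_update_steps` bookkeeping flagged by the reviewer above also becomes easier to reason about. The old loop counted microbatches and divided by a fixed `accumulation_steps`; the new loop takes exactly one optimizer step per minibatch per epoch, so the count can be computed directly. A quick check with hypothetical numbers (illustration only):

```python
# Hypothetical numbers for illustration.
num_minibatches = 3
microbatches_per_minibatch = [4, 4, 3]  # short tail, drop_last=False
update_epochs_per_batch = 2
accumulation_steps = 4  # old fixed value

# Old: count microbatches, then divide by the fixed accumulation_steps.
microbatch_count = sum(microbatches_per_minibatch) * update_epochs_per_batch
old_value = microbatch_count / accumulation_steps  # 22 / 4 = 5.5, not an integer

# New: one optimizer step per minibatch per epoch.
new_value = num_minibatches * update_epochs_per_batch  # 6

print(old_value, new_value)
```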
```diff
@@ -882,10 +907,11 @@ def _normalize_mini_batch_size(self):
             self.cfg.trainer.critic_mini_batch_size * self.cfg.generator.n_samples_per_prompt // dp_size
         )

-    def forward_backward(self, experience: Experience, accumulation_steps: int) -> Dict[str, float]:
+    def forward_backward(self, experience: Experience, microbatch_weight: float) -> Dict[str, float]:
         """
         Perform the forward and backward pass for one micro-batch.
         """
         self.model.train()
         experience.to_device(torch.cuda.current_device())

         sequences = experience.sequences
```
```diff
@@ -911,7 +937,7 @@ def forward_backward(self, experience: Experience, accumulation_steps: int) -> Dict[str, float]:
                 config=self.cfg.trainer.algorithm,
                 loss_mask=loss_mask,
             )
-        loss = loss / accumulation_steps
+        loss = loss * microbatch_weight
         self.strategy.backward(loss, self.model, self.optimizer)

         status = {
```
```diff
@@ -964,52 +990,64 @@ def save_hf_model(self, export_dir: str, tokenizer):
         )

     def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch:
-        dataloader = BatchIterator(
-            train_data, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
+        global_step = train_data.metadata["global_step"]
+        minibatch_iterator = BatchIterator(
+            train_data, sample_batch_size=self.critic_mini_batch_size_per_gpu, drop_last=False
         )

         torch.cuda.empty_cache()
         self.model.train()
+        all_metrics = defaultdict(list)
+        num_minibatches = len(minibatch_iterator)
+        local_step = 0

-        micro_batches_per_mini_batch = (
-            self.critic_mini_batch_size_per_gpu // self.cfg.trainer.micro_train_batch_size_per_gpu
-        )
-        # The number of steps (over micro batches) to accumulate gradients before taking an optimizer step.
-        accumulation_steps = micro_batches_per_mini_batch
+        def record_status(status: Dict[str, float]):
+            status["critic_lr"] = self.scheduler.get_last_lr()[0]
+
+            # for DP
+            # TODO (sumanthrh): this assumes all workers are data parallel.
+            # We should get more accurate metrics with seq parallel or TP.
+            # There are metrics like entropy where we get average over local data size
+            status = self.strategy.all_reduce(status)
+
+            for k, v in status.items():
+                all_metrics[k].append(v)
+            minibatch_pbar.set_postfix(status)

-        all_metrics = defaultdict(list)
-        critic_update_steps = 0
         for epoch in range(self.cfg.trainer.update_epochs_per_batch):
-            pbar = tqdm(
-                dataloader,
+            minibatch_pbar = tqdm(
+                minibatch_iterator,
                 desc=f"Critic Train epoch [{epoch + 1}/{self.cfg.trainer.update_epochs_per_batch}]",
                 disable=not self.strategy.is_rank_0(),
             )
-            for local_step, experience in enumerate(pbar):
-                status = self.forward_backward(experience, accumulation_steps)
+            for minibatch in minibatch_pbar:
+                microbatch_iterator = BatchIterator(
+                    minibatch, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
+                )
+                num_microbatches = len(microbatch_iterator)
+                microbatch_weight = 1.0 / num_microbatches
+
+                for microbatch_idx, microbatch in enumerate(microbatch_iterator):
+                    microbatch_experience = BatchIterator.batch_to_experience(microbatch)
+                    status = self.forward_backward(microbatch_experience, microbatch_weight=microbatch_weight)

-                if (local_step + 1) % accumulation_steps == 0:
-                    grad_norm = self.optim_step()
-                    if grad_norm is not None:
-                        status["raw_grad_norm"] = grad_norm
+                    if microbatch_idx < num_microbatches - 1:
+                        if self.record_memory:
+                            self.save_memory_snapshot(global_step, local_step)
+                        record_status(status)

-                status["critic_lr"] = self.scheduler.get_last_lr()[0]
-                critic_update_steps += 1
+                    local_step += 1

-                # for DP
-                # TODO (sumanthrh): this assumes all workers are data parallel.
-                # We should get more accurate metrics with seq parallel or TP.
-                # There are metrics like entropy where we get average over local data size
-                status = self.strategy.all_reduce(status)
+                grad_norm = self.optim_step()
+                if grad_norm is not None:
+                    status["raw_grad_norm"] = grad_norm

-                for k, v in status.items():
-                    all_metrics[k].append(v)
-                pbar.set_postfix(status)
+                if self.record_memory:
+                    self.save_memory_snapshot(global_step, local_step)
+                record_status(status)

         torch.distributed.barrier()

         status_mean = reduce_metrics(all_metrics)
-        status_mean["critic_update_steps"] = critic_update_steps / accumulation_steps
+        status_mean["critic_update_steps"] = num_minibatches * self.cfg.trainer.update_epochs_per_batch

         output = TrainingOutputBatch()
         output.metadata = {"train_status": status_mean}
```
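One design choice shared by both `ppo_train` loops is worth calling out: the status of the final microbatch in each minibatch is recorded only after `optim_step()`, so `grad_norm` and the post-step learning rate make it into the logged metrics. A stripped-down sketch of that pattern; the callables here are hypothetical stand-ins for the worker's methods, not the PR's API:

```python
from typing import Callable, Dict, List, Optional

def run_minibatch(
    microbatches: List[object],
    forward_backward: Callable[[object, float], Dict[str, float]],
    optim_step: Callable[[], Optional[float]],
    record_status: Callable[[Dict[str, float]], None],
) -> None:
    """Record all-but-last microbatch immediately; the last one after the optimizer step."""
    weight = 1.0 / len(microbatches)
    status: Dict[str, float] = {}
    for i, mb in enumerate(microbatches):
        status = forward_backward(mb, weight)
        if i < len(microbatches) - 1:
            record_status(status)  # no grad_norm available yet
    grad_norm = optim_step()
    if grad_norm is not None:
        status["raw_grad_norm"] = grad_norm
    record_status(status)  # the last microbatch now carries the step's grad_norm
```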
The second file in the diff updates `BatchIterator` so that iteration yields raw batches; conversion to `Experience` now happens at the call site.
```diff
@@ -37,11 +37,9 @@ def __len__(self):
     def __iter__(self):
         return self

-    def __next__(self) -> Experience:
+    def __next__(self) -> TrainingInputBatch:
         try:
-            batch = next(self._iter)
-            exp = self.batch_to_experience(batch)
-            return exp
+            return next(self._iter)
         except StopIteration:
             self._iter = iter(self._chunks)
             raise StopIteration
```

> **Collaborator:** This change also affects the use of `BatchIterator` in the megatron code path; could you make sure the conversion to experience is also handled correctly there? Making sure one of the related tests passes is probably a good way to check this.
>
> **Contributor (author):** Sounds good. I added a followup TODO to update the megatron worker's ppo loop as well. Made a minimal change for now to prevent this PR from getting too large.
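Returning `TrainingInputBatch` from `__next__` is what makes the nested iteration in `ppo_train` possible: a minibatch produced by the outer `BatchIterator` can be fed straight into an inner `BatchIterator`, and conversion to `Experience` happens only at the microbatch level. A minimal stand-in, just to show the shape of the pattern; the real `BatchIterator` and `TrainingInputBatch` carry more machinery:

```python
from typing import Iterator, List

class Batch:
    """Stand-in for TrainingInputBatch: a container that can be re-chunked."""
    def __init__(self, samples: List[int]):
        self.samples = samples
    def __len__(self) -> int:
        return len(self.samples)

class BatchIteratorSketch:
    """Yields Batch chunks, so its output can be chunked again by a nested iterator."""
    def __init__(self, data: Batch, sample_batch_size: int):
        self._chunks = [
            Batch(data.samples[i : i + sample_batch_size])
            for i in range(0, len(data), sample_batch_size)
        ]
    def __len__(self) -> int:
        return len(self._chunks)
    def __iter__(self) -> Iterator[Batch]:
        return iter(self._chunks)

data = Batch(list(range(10)))
for minibatch in BatchIteratorSketch(data, sample_batch_size=4):            # level 1
    for microbatch in BatchIteratorSketch(minibatch, sample_batch_size=2):  # level 2
        pass  # convert to an Experience and run forward/backward here
```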