From 87fcd3f4f8e73d2ef1aa0246ce496d88193241a1 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Tue, 30 Dec 2025 15:56:32 -0800 Subject: [PATCH 01/11] update ppo_train to have 2 levels of batch iteration (mini and micro) Signed-off-by: Justin Yu --- skyrl-train/skyrl_train/workers/worker.py | 49 +++++++++++-------- .../skyrl_train/workers/worker_utils.py | 6 +-- 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index 8bbeba30b..94e8bdba3 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -621,7 +621,7 @@ def _normalize_mini_batch_size(self): self.cfg.trainer.policy_mini_batch_size * self.cfg.generator.n_samples_per_prompt // dp_size ) - def forward_backward(self, experience: Experience, accumulation_steps: int) -> Dict[str, float]: + def forward_backward(self, experience: Experience, microbatch_weight: float) -> Dict[str, float]: """ Perform the forward and backward pass for one micro-batch. """ @@ -688,7 +688,7 @@ def forward_backward(self, experience: Experience, accumulation_steps: int) -> D kl_loss_term = kl_loss * self.cfg.trainer.algorithm.kl_loss_coef loss = policy_loss + kl_loss_term - entropy_loss_term - loss = loss / accumulation_steps + loss = loss * microbatch_weight self.strategy.backward(loss, self.model, self.optimizer) status = { @@ -714,39 +714,47 @@ def optim_step(self) -> float: def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch: global_step = train_data.metadata["global_step"] - dataloader = BatchIterator( - train_data, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False + minibatch_iterator = BatchIterator( + train_data, sample_batch_size=self.policy_mini_batch_size_per_gpu, drop_last=False ) - micro_batches_per_mini_batch = ( - self.policy_mini_batch_size_per_gpu // self.cfg.trainer.micro_train_batch_size_per_gpu - ) - # The number of steps (over micro batches) to accumulate gradients before taking an optimizer step. - accumulation_steps = micro_batches_per_mini_batch + # micro_batches_per_mini_batch = ( + # self.policy_mini_batch_size_per_gpu // self.cfg.trainer.micro_train_batch_size_per_gpu + # ) + # # The number of steps (over micro batches) to accumulate gradients before taking an optimizer step. + # accumulation_steps = micro_batches_per_mini_batch status_list = [] all_metrics = defaultdict(list) - policy_update_steps = 0 + num_minibatches = len(minibatch_iterator) for epoch in range(self.cfg.trainer.update_epochs_per_batch): - pbar = tqdm( - dataloader, + minibatch_pbar = tqdm( + minibatch_iterator, desc=f"Policy Train epoch [{epoch + 1}/{self.cfg.trainer.update_epochs_per_batch}]", disable=not self.strategy.is_rank_0(), ) - for local_step, experience in enumerate(pbar): - status = self.forward_backward(experience, accumulation_steps) + for local_step, minibatch in enumerate(minibatch_pbar): + microbatch_iterator = BatchIterator( + minibatch, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False + ) + num_microbatches = len(microbatch_iterator) + microbatch_weight = 1.0 / num_microbatches - if (local_step + 1) % accumulation_steps == 0: - grad_norm = self.optim_step() - status["raw_grad_norm"] = grad_norm + for microbatch in microbatch_iterator: + microbatch_experience = BatchIterator.batch_to_experience(microbatch) + status = self.forward_backward(microbatch_experience, microbatch_weight=microbatch_weight) + + grad_norm = self.optim_step() + status["raw_grad_norm"] = grad_norm if self.record_memory: + # NOTE: local_step == minibatch index now instead of microbatch index self.save_memory_snapshot(global_step, local_step) status["policy_lr"] = self.scheduler.get_last_lr()[0] - policy_update_steps += 1 + # TODO: Move all the progress bar stuff into a utility function, and then add it back in the inner loop. # for DP # TODO (sumanthrh): this assumes all workers are data parallel. @@ -785,14 +793,14 @@ def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch: status_list.append(status) for k, v in status.items(): all_metrics[k].append(v) - pbar.set_postfix(short_status) + minibatch_pbar.set_postfix(short_status) torch.distributed.barrier() # not needed beyond status logging all_metrics.pop("response_length", None) status_mean = reduce_metrics(all_metrics) - status_mean["policy_update_steps"] = policy_update_steps / accumulation_steps + status_mean["policy_update_steps"] = num_minibatches # should return an `TrainingOutputBatch` output = TrainingOutputBatch() @@ -803,6 +811,7 @@ def training_step(self, experience: Experience, global_step, local_step, accumul """ Perform one micro-batch of training, accumulate gradients, and step the optimizer only after `accumulation_steps` micro-batches. """ + # TODO: Is this method actually used? status = self.forward_backward(experience, accumulation_steps) if (local_step + 1) % accumulation_steps == 0: diff --git a/skyrl-train/skyrl_train/workers/worker_utils.py b/skyrl-train/skyrl_train/workers/worker_utils.py index 43b99d91a..897d032ea 100644 --- a/skyrl-train/skyrl_train/workers/worker_utils.py +++ b/skyrl-train/skyrl_train/workers/worker_utils.py @@ -37,11 +37,9 @@ def __len__(self): def __iter__(self): return self - def __next__(self) -> Experience: + def __next__(self) -> TrainingInputBatch: try: - batch = next(self._iter) - exp = self.batch_to_experience(batch) - return exp + return next(self._iter) except StopIteration: self._iter = iter(self._chunks) raise StopIteration From 990cb1a8f408c0e076fccd4f82897e205bb83788 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Tue, 30 Dec 2025 15:59:38 -0800 Subject: [PATCH 02/11] do for critic base as well Signed-off-by: Justin Yu --- skyrl-train/skyrl_train/workers/worker.py | 66 ++++++++++++----------- 1 file changed, 35 insertions(+), 31 deletions(-) diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index 94e8bdba3..50a597108 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -812,7 +812,7 @@ def training_step(self, experience: Experience, global_step, local_step, accumul Perform one micro-batch of training, accumulate gradients, and step the optimizer only after `accumulation_steps` micro-batches. """ # TODO: Is this method actually used? - status = self.forward_backward(experience, accumulation_steps) + status = self.forward_backward(experience, microbatch_weight=1.0 / accumulation_steps) if (local_step + 1) % accumulation_steps == 0: grad_norm = self.optim_step() @@ -906,10 +906,11 @@ def _normalize_mini_batch_size(self): self.cfg.trainer.critic_mini_batch_size * self.cfg.generator.n_samples_per_prompt // dp_size ) - def forward_backward(self, experience: Experience, accumulation_steps: int) -> Dict[str, float]: + def forward_backward(self, experience: Experience, microbatch_weight: float) -> Dict[str, float]: """ Perform the forward and backward pass for one micro-batch. """ + self.model.train() experience.to_device(torch.cuda.current_device()) sequences = experience.sequences @@ -935,7 +936,7 @@ def forward_backward(self, experience: Experience, accumulation_steps: int) -> D config=self.cfg.trainer.algorithm, loss_mask=loss_mask, ) - loss = loss / accumulation_steps + loss = loss * microbatch_weight self.strategy.backward(loss, self.model, self.optimizer) status = { @@ -988,36 +989,38 @@ def save_hf_model(self, export_dir: str, tokenizer): ) def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch: - dataloader = BatchIterator( - train_data, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False - ) - - torch.cuda.empty_cache() - self.model.train() - - micro_batches_per_mini_batch = ( - self.critic_mini_batch_size_per_gpu // self.cfg.trainer.micro_train_batch_size_per_gpu + global_step = train_data.metadata["global_step"] + minibatch_iterator = BatchIterator( + train_data, sample_batch_size=self.critic_mini_batch_size_per_gpu, drop_last=False ) - # The number of steps (over micro batches) to accumulate gradients before taking an optimizer step. - accumulation_steps = micro_batches_per_mini_batch all_metrics = defaultdict(list) - critic_update_steps = 0 + num_minibatches = len(minibatch_iterator) + for epoch in range(self.cfg.trainer.update_epochs_per_batch): - pbar = tqdm( - dataloader, + minibatch_pbar = tqdm( + minibatch_iterator, desc=f"Critic Train epoch [{epoch + 1}/{self.cfg.trainer.update_epochs_per_batch}]", disable=not self.strategy.is_rank_0(), ) - for local_step, experience in enumerate(pbar): - status = self.forward_backward(experience, accumulation_steps) + for local_step, minibatch in enumerate(minibatch_pbar): + microbatch_iterator = BatchIterator( + minibatch, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False + ) + num_microbatches = len(microbatch_iterator) + microbatch_weight = 1.0 / num_microbatches + + for microbatch in microbatch_iterator: + microbatch_experience = BatchIterator.batch_to_experience(microbatch) + status = self.forward_backward(microbatch_experience, microbatch_weight=microbatch_weight) - if (local_step + 1) % accumulation_steps == 0: - grad_norm = self.optim_step() - status["raw_grad_norm"] = grad_norm + grad_norm = self.optim_step() + status["raw_grad_norm"] = grad_norm + + if self.record_memory: + self.save_memory_snapshot(global_step, local_step) status["critic_lr"] = self.scheduler.get_last_lr()[0] - critic_update_steps += 1 # for DP # TODO (sumanthrh): this assumes all workers are data parallel. @@ -1027,28 +1030,29 @@ def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch: for k, v in status.items(): all_metrics[k].append(v) - pbar.set_postfix(status) + minibatch_pbar.set_postfix(status) torch.distributed.barrier() status_mean = reduce_metrics(all_metrics) - status_mean["critic_update_steps"] = critic_update_steps / accumulation_steps + status_mean["critic_update_steps"] = num_minibatches output = TrainingOutputBatch() output.metadata = {"train_status": status_mean} return output - def training_step(self, experience: Experience, global_step, local_step, accumulation_steps) -> Dict[str, float]: + def training_step(self, experience: Experience, global_step, local_step, microbatch_weight) -> Dict[str, float]: """ - Perform one micro-batch of training, accumulate gradients, and step the optimizer only after `accumulation_steps` micro-batches. + Perform one micro-batch of training, accumulate gradients, and step the optimizer only after all micro-batches. """ - status = self.forward_backward(experience, accumulation_steps) + # TODO: Is this method actually used? + status = self.forward_backward(experience, microbatch_weight) - if (local_step + 1) % accumulation_steps == 0: - grad_norm = self.optim_step() - status["raw_grad_norm"] = grad_norm + if self.record_memory: + self.save_memory_snapshot(global_step, local_step) status["critic_lr"] = self.scheduler.get_last_lr()[0] + return status def save_checkpoint(self, ckpt_dir: str, tokenizer=None): From 94b59838f777bc1d9a969c509a70a8352cd44553 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Tue, 30 Dec 2025 16:00:10 -0800 Subject: [PATCH 03/11] fix training_step Signed-off-by: Justin Yu --- skyrl-train/skyrl_train/workers/worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index 50a597108..c43a552e7 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -1041,12 +1041,12 @@ def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch: output.metadata = {"train_status": status_mean} return output - def training_step(self, experience: Experience, global_step, local_step, microbatch_weight) -> Dict[str, float]: + def training_step(self, experience: Experience, global_step, local_step, accumulation_steps) -> Dict[str, float]: """ Perform one micro-batch of training, accumulate gradients, and step the optimizer only after all micro-batches. """ # TODO: Is this method actually used? - status = self.forward_backward(experience, microbatch_weight) + status = self.forward_backward(experience, microbatch_weight=1.0 / accumulation_steps) if self.record_memory: self.save_memory_snapshot(global_step, local_step) From 108d96066bdc719f1d3728262131d6f58545e150 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Tue, 30 Dec 2025 16:01:02 -0800 Subject: [PATCH 04/11] remove comment Signed-off-by: Justin Yu --- skyrl-train/skyrl_train/workers/worker.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index c43a552e7..9a726147e 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -718,12 +718,6 @@ def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch: train_data, sample_batch_size=self.policy_mini_batch_size_per_gpu, drop_last=False ) - # micro_batches_per_mini_batch = ( - # self.policy_mini_batch_size_per_gpu // self.cfg.trainer.micro_train_batch_size_per_gpu - # ) - # # The number of steps (over micro batches) to accumulate gradients before taking an optimizer step. - # accumulation_steps = micro_batches_per_mini_batch - status_list = [] all_metrics = defaultdict(list) num_minibatches = len(minibatch_iterator) From d301ebf55c3bf7369ddf595c806fe461f6788b16 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Tue, 30 Dec 2025 16:12:15 -0800 Subject: [PATCH 05/11] add back optim step Signed-off-by: Justin Yu --- skyrl-train/skyrl_train/workers/worker.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index 9a726147e..f69c36e2a 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -1042,6 +1042,10 @@ def training_step(self, experience: Experience, global_step, local_step, accumul # TODO: Is this method actually used? status = self.forward_backward(experience, microbatch_weight=1.0 / accumulation_steps) + if (local_step + 1) % accumulation_steps == 0: + grad_norm = self.optim_step() + status["raw_grad_norm"] = grad_norm + if self.record_memory: self.save_memory_snapshot(global_step, local_step) From e3b9b0873eeb5ab0a2503993e0652feca33ec330 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Wed, 31 Dec 2025 16:34:25 -0800 Subject: [PATCH 06/11] add back status updates per microbatch Signed-off-by: Justin Yu --- skyrl-train/skyrl_train/workers/worker.py | 108 ++++++++++++---------- 1 file changed, 60 insertions(+), 48 deletions(-) diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index ed90af3fe..87eebe2a3 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -721,6 +721,51 @@ def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch: status_list = [] all_metrics = defaultdict(list) num_minibatches = len(minibatch_iterator) + local_step = 0 + + def record_status(status: Dict[str, float]): + """Record the aggregated (all-reduced) training status for the latest microbatch. + Also, update the progress bar with the latest status.""" + status["policy_lr"] = self.scheduler.get_last_lr()[0] + + # for DP + # TODO (sumanthrh): this assumes all workers are data parallel. + # We assume that outputs are replicated within tp or sp group, otherwise this is not correct. + status = self.strategy.all_reduce(status) + + # weighted mean for kl + # TODO (sumanthrh): this weighted mean is no longer correct since we use the max response length in the batch. + # we can log this in the driver + # if "kl" in status: + # status["kl"] *= status["response_length"] + # status["kl"] /= status["response_length"] + + short_status = {} + + if "policy_loss" in status: + short_status = { + "pg": status["policy_loss"], + "glen": status["response_length"], + "policy_lr": status["policy_lr"], + "ent": status["policy_entropy"], + } + if "raw_grad_norm" in status: + short_status["grad_norm"] = status["raw_grad_norm"] + if "reward" in status: + short_status["rm"] = status["reward"] + + if "critic_loss" in status: + short_status["cri"] = status["critic_loss"] + short_status["vals"] = status["values"] + short_status["cri_lr"] = status["critic_lr"] + + if "ptx_loss" in status: + short_status["ptx"] = status["ptx_loss"] + + status_list.append(status) + for k, v in status.items(): + all_metrics[k].append(v) + minibatch_pbar.set_postfix(short_status) for epoch in range(self.cfg.trainer.update_epochs_per_batch): minibatch_pbar = tqdm( @@ -728,67 +773,34 @@ def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch: desc=f"Policy Train epoch [{epoch + 1}/{self.cfg.trainer.update_epochs_per_batch}]", disable=not self.strategy.is_rank_0(), ) - for local_step, minibatch in enumerate(minibatch_pbar): + for minibatch in minibatch_pbar: microbatch_iterator = BatchIterator( minibatch, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False ) num_microbatches = len(microbatch_iterator) microbatch_weight = 1.0 / num_microbatches - for microbatch in microbatch_iterator: + for microbatch_idx, microbatch in enumerate(microbatch_iterator): microbatch_experience = BatchIterator.batch_to_experience(microbatch) status = self.forward_backward(microbatch_experience, microbatch_weight=microbatch_weight) - grad_norm = self.optim_step() - if grad_norm is not None: - status["raw_grad_norm"] = grad_norm - - if self.record_memory: - # NOTE: local_step == minibatch index now instead of microbatch index - self.save_memory_snapshot(global_step, local_step) + if self.record_memory: + self.save_memory_snapshot(global_step, local_step) - status["policy_lr"] = self.scheduler.get_last_lr()[0] + # Record status for all but the last microbatch in the minibatch. + # The last microbatch should be recorded after the optimizer step. + if microbatch_idx < num_microbatches - 1: + record_status(status) - # TODO: Move all the progress bar stuff into a utility function, and then add it back in the inner loop. + # Local step counts the number of processed microbatches. + local_step += 1 - # for DP - # TODO (sumanthrh): this assumes all workers are data parallel. - # We assume that outputs are replicated within tp or sp group, otherwise this is not correct. - status = self.strategy.all_reduce(status) + grad_norm = self.optim_step() + if grad_norm is not None: + status["raw_grad_norm"] = grad_norm - # weighted mean for kl - # TODO (sumanthrh): this weighted mean is no longer correct since we use the max response length in the batch. - # we can log this in the driver - # if "kl" in status: - # status["kl"] *= status["response_length"] - # status["kl"] /= status["response_length"] - - short_status = {} - - if "policy_loss" in status: - short_status = { - "pg": status["policy_loss"], - "glen": status["response_length"], - "policy_lr": status["policy_lr"], - "ent": status["policy_entropy"], - } - if "raw_grad_norm" in status: - short_status["grad_norm"] = status["raw_grad_norm"] - if "reward" in status: - short_status["rm"] = status["reward"] - - if "critic_loss" in status: - short_status["cri"] = status["critic_loss"] - short_status["vals"] = status["values"] - short_status["cri_lr"] = status["critic_lr"] - - if "ptx_loss" in status: - short_status["ptx"] = status["ptx_loss"] - - status_list.append(status) - for k, v in status.items(): - all_metrics[k].append(v) - minibatch_pbar.set_postfix(short_status) + # Record status for the last microbatch in the minibatch. + record_status(status) torch.distributed.barrier() # not needed beyond status logging From 5038324ca632fd5687151718bdf62e20f69dd868 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Wed, 31 Dec 2025 16:53:55 -0800 Subject: [PATCH 07/11] fix for critic as well Signed-off-by: Justin Yu --- skyrl-train/skyrl_train/workers/worker.py | 46 ++++++++++++++--------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index 87eebe2a3..116f7d2a7 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -784,12 +784,11 @@ def record_status(status: Dict[str, float]): microbatch_experience = BatchIterator.batch_to_experience(microbatch) status = self.forward_backward(microbatch_experience, microbatch_weight=microbatch_weight) - if self.record_memory: - self.save_memory_snapshot(global_step, local_step) - # Record status for all but the last microbatch in the minibatch. # The last microbatch should be recorded after the optimizer step. if microbatch_idx < num_microbatches - 1: + if self.record_memory: + self.save_memory_snapshot(global_step, local_step) record_status(status) # Local step counts the number of processed microbatches. @@ -799,6 +798,9 @@ def record_status(status: Dict[str, float]): if grad_norm is not None: status["raw_grad_norm"] = grad_norm + if self.record_memory: + self.save_memory_snapshot(global_step, local_step) + # Record status for the last microbatch in the minibatch. record_status(status) @@ -1004,6 +1006,20 @@ def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch: all_metrics = defaultdict(list) num_minibatches = len(minibatch_iterator) + local_step = 0 + + def record_status(status: Dict[str, float]): + status["critic_lr"] = self.scheduler.get_last_lr()[0] + + # for DP + # TODO (sumanthrh): this assumes all workers are data parallel. + # We should get more accurate metrics with seq parallel or TP. + # There are metrics like entropy where we get average over local data size + status = self.strategy.all_reduce(status) + + for k, v in status.items(): + all_metrics[k].append(v) + minibatch_pbar.set_postfix(status) for epoch in range(self.cfg.trainer.update_epochs_per_batch): minibatch_pbar = tqdm( @@ -1011,35 +1027,31 @@ def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch: desc=f"Critic Train epoch [{epoch + 1}/{self.cfg.trainer.update_epochs_per_batch}]", disable=not self.strategy.is_rank_0(), ) - for local_step, minibatch in enumerate(minibatch_pbar): + for minibatch in minibatch_pbar: microbatch_iterator = BatchIterator( minibatch, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False ) num_microbatches = len(microbatch_iterator) microbatch_weight = 1.0 / num_microbatches - for microbatch in microbatch_iterator: + for microbatch_idx, microbatch in enumerate(microbatch_iterator): microbatch_experience = BatchIterator.batch_to_experience(microbatch) status = self.forward_backward(microbatch_experience, microbatch_weight=microbatch_weight) + if microbatch_idx < num_microbatches - 1: + if self.record_memory: + self.save_memory_snapshot(global_step, local_step) + record_status(status) + + local_step += 1 + grad_norm = self.optim_step() if grad_norm is not None: status["raw_grad_norm"] = grad_norm if self.record_memory: self.save_memory_snapshot(global_step, local_step) - - status["critic_lr"] = self.scheduler.get_last_lr()[0] - - # for DP - # TODO (sumanthrh): this assumes all workers are data parallel. - # We should get more accurate metrics with seq parallel or TP. - # There are metrics like entropy where we get average over local data size - status = self.strategy.all_reduce(status) - - for k, v in status.items(): - all_metrics[k].append(v) - minibatch_pbar.set_postfix(status) + record_status(status) torch.distributed.barrier() From a31d8a4cffb6acbac0863204610e4201c32779d3 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Wed, 31 Dec 2025 16:56:42 -0800 Subject: [PATCH 08/11] update megatron Signed-off-by: Justin Yu --- skyrl-train/skyrl_train/workers/megatron/megatron_worker.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py index d8926c4d6..1e2025a48 100644 --- a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py +++ b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py @@ -534,8 +534,10 @@ def ppo_train(self, train_data) -> "TrainingOutputBatch": disable=not self.strategy.is_rank_0(), ) + # TODO: Convert this into 2 loops for minibatches and microbatches. micro_buffer = [] - for local_step, experience in enumerate(pbar): + for local_step, microbatch in enumerate(pbar): + experience = BatchIterator.batch_to_experience(microbatch) experience.to_device(torch.cuda.current_device()) sequences = experience.sequences attention_mask = experience.attention_mask From c5fbc2545aeed8982a624ae3ae8406d593500734 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Mon, 5 Jan 2026 15:49:08 -0800 Subject: [PATCH 09/11] fix total policy update steps calculation Signed-off-by: Justin Yu --- skyrl-train/skyrl_train/workers/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index 707efdff0..02675d477 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -810,7 +810,7 @@ def record_status(status: Dict[str, float]): all_metrics.pop("response_length", None) status_mean = reduce_metrics(all_metrics) - status_mean["policy_update_steps"] = num_minibatches + status_mean["policy_update_steps"] = num_minibatches * self.cfg.trainer.update_epochs_per_batch # should return an `TrainingOutputBatch` output = TrainingOutputBatch() From 74dade46dccca37e0dd89c4535ca74117fe452ce Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Tue, 6 Jan 2026 10:04:39 -0800 Subject: [PATCH 10/11] fix critic update steps Signed-off-by: Justin Yu --- skyrl-train/skyrl_train/workers/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index 02675d477..e7f4070c3 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -1057,7 +1057,7 @@ def record_status(status: Dict[str, float]): torch.distributed.barrier() status_mean = reduce_metrics(all_metrics) - status_mean["critic_update_steps"] = num_minibatches + status_mean["critic_update_steps"] = num_minibatches * self.cfg.trainer.update_epochs_per_batch output = TrainingOutputBatch() output.metadata = {"train_status": status_mean} From 6c8212d7bd84c8c4d808611cda13383bd2b6e711 Mon Sep 17 00:00:00 2001 From: Justin Yu Date: Tue, 6 Jan 2026 10:14:09 -0800 Subject: [PATCH 11/11] fix text + add some docstring Signed-off-by: Justin Yu --- skyrl-train/skyrl_train/workers/worker.py | 13 +++++++++++-- skyrl-train/tests/cpu/test_trainer.py | 17 +++++++++-------- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py index e7f4070c3..2d56a6679 100644 --- a/skyrl-train/skyrl_train/workers/worker.py +++ b/skyrl-train/skyrl_train/workers/worker.py @@ -623,8 +623,17 @@ def _normalize_mini_batch_size(self): ) def forward_backward(self, experience: Experience, microbatch_weight: float) -> Dict[str, float]: - """ - Perform the forward and backward pass for one micro-batch. + """Perform the forward and backward pass for one micro-batch. + + Args: + experience: The microbatch data to run the forward and backward pass on. + microbatch_weight: Weight of the microbatch, used to scale the loss contribution + for the microbatch. For example, if you accumulate gradients over 2 microbatches, + then each microbatch should have a weight of 1/2. + + Returns: + Dict containing the status (including loss and some other metrics) + for the microbatch. """ self.model.train() experience.to_device(torch.cuda.current_device()) diff --git a/skyrl-train/tests/cpu/test_trainer.py b/skyrl-train/tests/cpu/test_trainer.py index a1bdfcf3e..599f8a3f4 100644 --- a/skyrl-train/tests/cpu/test_trainer.py +++ b/skyrl-train/tests/cpu/test_trainer.py @@ -548,8 +548,8 @@ def create_test_worker(worker_class): # Mock forward_backward and optim_step to track calls and verify accumulation behavior policy_forward_backward_calls = [] - def mock_policy_forward_backward(experience, accumulation_steps): - policy_forward_backward_calls.append({"accumulation_steps": accumulation_steps}) + def mock_policy_forward_backward(experience, microbatch_weight): + policy_forward_backward_calls.append({"microbatch_weight": microbatch_weight}) return {"policy_loss": 0.5, "ppo_clip_ratio": 0.1, "policy_entropy": 2.0, "response_length": response_length} policy_worker.forward_backward = mock_policy_forward_backward @@ -568,6 +568,7 @@ def mock_policy_forward_backward(experience, accumulation_steps): ) # 6 // 2 = 3 # New logic: accumulation_steps = micro_batches_per_mini_batch (accumulate within mini-batch) expected_accumulation_steps = micro_batches_per_mini_batch # Should be 3 + expected_microbatch_weight = 1.0 / expected_accumulation_steps # Run policy ppo_train with minimal mocking with ( @@ -584,8 +585,8 @@ def mock_policy_forward_backward(experience, accumulation_steps): # Verify accumulation_steps are consistent (should equal micro_batches_per_mini_batch) for call in policy_forward_backward_calls: assert ( - call["accumulation_steps"] == expected_accumulation_steps - ), f"PolicyWorker: Expected accumulation_steps={expected_accumulation_steps}, got {call['accumulation_steps']}" + call["microbatch_weight"] == expected_microbatch_weight + ), f"PolicyWorker: Expected microbatch_weight={expected_microbatch_weight}, got {call['microbatch_weight']}" # Verify result structure assert "train_status" in result.metadata @@ -602,8 +603,8 @@ def mock_policy_forward_backward(experience, accumulation_steps): # Mock forward_backward and optim_step for critic critic_forward_backward_calls = [] - def mock_critic_forward_backward(experience, accumulation_steps): - critic_forward_backward_calls.append({"accumulation_steps": accumulation_steps}) + def mock_critic_forward_backward(experience, microbatch_weight): + critic_forward_backward_calls.append({"microbatch_weight": microbatch_weight}) return {"critic_loss": 0.3, "values": 1.0} critic_worker.forward_backward = mock_critic_forward_backward @@ -627,8 +628,8 @@ def mock_critic_forward_backward(experience, accumulation_steps): # Verify accumulation_steps are consistent for critic (should equal micro_batches_per_mini_batch) for call in critic_forward_backward_calls: assert ( - call["accumulation_steps"] == expected_accumulation_steps - ), f"CriticWorker: Expected accumulation_steps={expected_accumulation_steps}, got {call['accumulation_steps']}" + call["microbatch_weight"] == expected_microbatch_weight + ), f"CriticWorker: Expected microbatch_weight={expected_microbatch_weight}, got {call['microbatch_weight']}" # Verify result structure for critic assert "train_status" in result.metadata