4 changes: 3 additions & 1 deletion skyrl-train/skyrl_train/workers/megatron/megatron_worker.py
@@ -535,8 +535,10 @@ def ppo_train(self, train_data) -> "TrainingOutputBatch":
disable=not self.strategy.is_rank_0(),
)

# TODO: Convert this into 2 loops for minibatches and microbatches.
micro_buffer = []
for local_step, experience in enumerate(pbar):
for local_step, microbatch in enumerate(pbar):
experience = BatchIterator.batch_to_experience(microbatch)
experience.to_device(torch.cuda.current_device())
sequences = experience.sequences
attention_mask = experience.attention_mask
232 changes: 135 additions & 97 deletions skyrl-train/skyrl_train/workers/worker.py
@@ -622,9 +622,18 @@ def _normalize_mini_batch_size(self):
self.cfg.trainer.policy_mini_batch_size * self.cfg.generator.n_samples_per_prompt // dp_size
)

def forward_backward(self, experience: Experience, accumulation_steps: int) -> Dict[str, float]:
"""
Perform the forward and backward pass for one micro-batch.
def forward_backward(self, experience: Experience, microbatch_weight: float) -> Dict[str, float]:
"""Perform the forward and backward pass for one micro-batch.

Args:
experience: The microbatch data to run the forward and backward pass on.
microbatch_weight: Weight of the microbatch, used to scale the loss contribution
for the microbatch. For example, if you accumulate gradients over 2 microbatches,
then each microbatch should have a weight of 1/2.

Returns:
Dict containing the status (including loss and some other metrics)
for the microbatch.
"""
self.model.train()
experience.to_device(torch.cuda.current_device())
@@ -689,7 +698,7 @@ def forward_backward(self, experience: Experience, accumulation_steps: int) -> D
kl_loss_term = kl_loss * self.cfg.trainer.algorithm.kl_loss_coef

loss = policy_loss + kl_loss_term - entropy_loss_term
loss = loss / accumulation_steps
loss = loss * microbatch_weight
self.strategy.backward(loss, self.model, self.optimizer)

status = {
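
The only functional change in the hunk above is the loss scaling. As a quick sanity check, here is a self-contained toy example (plain PyTorch, not the repo's code) showing that weighting each microbatch loss by 1/num_microbatches and accumulating gradients reproduces the gradient of one mean-reduced loss over the whole minibatch, assuming equal-size microbatches:

```python
# Toy check: per-microbatch scaling vs. a single full-minibatch backward pass.
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)
x = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 1.0]])  # one minibatch of 4
y = torch.tensor([1.0, 2.0, 3.0, 4.0])

# Reference gradient from the full minibatch.
loss_full = ((x @ w - y) ** 2).mean()
(g_full,) = torch.autograd.grad(loss_full, w)

# Two microbatches of 2, each weighted by 1/2 (the microbatch_weight).
g_accum = torch.zeros_like(w)
for micro_x, micro_y in ((x[:2], y[:2]), (x[2:], y[2:])):
    micro_loss = ((micro_x @ w - micro_y) ** 2).mean()
    (g,) = torch.autograd.grad(micro_loss, w)
    g_accum += 0.5 * g

print(torch.allclose(g_full, g_accum))  # True
```

With equal microbatch sizes this is identical to the old `loss / accumulation_steps`; a per-minibatch weight additionally handles a trailing minibatch that splits into fewer microbatches when `drop_last=False`, which appears to be the motivation for the change.
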
@@ -715,86 +724,102 @@ def optim_step(self) -> float:

def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch:
global_step = train_data.metadata["global_step"]
dataloader = BatchIterator(
train_data, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
minibatch_iterator = BatchIterator(
train_data, sample_batch_size=self.policy_mini_batch_size_per_gpu, drop_last=False
)

micro_batches_per_mini_batch = (
self.policy_mini_batch_size_per_gpu // self.cfg.trainer.micro_train_batch_size_per_gpu
)
# The number of steps (over micro batches) to accumulate gradients before taking an optimizer step.
accumulation_steps = micro_batches_per_mini_batch

status_list = []
all_metrics = defaultdict(list)
policy_update_steps = 0
num_minibatches = len(minibatch_iterator)
local_step = 0

def record_status(status: Dict[str, float]):
"""Record the aggregated (all-reduced) training status for the latest microbatch.
Also, update the progress bar with the latest status."""
status["policy_lr"] = self.scheduler.get_last_lr()[0]

# for DP
# TODO (sumanthrh): this assumes all workers are data parallel.
# We assume that outputs are replicated within tp or sp group, otherwise this is not correct.
status = self.strategy.all_reduce(status)

# weighted mean for kl
# TODO (sumanthrh): this weighted mean is no longer correct since we use the max response length in the batch.
# we can log this in the driver
# if "kl" in status:
# status["kl"] *= status["response_length"]
# status["kl"] /= status["response_length"]

short_status = {}

if "policy_loss" in status:
short_status = {
"pg": status["policy_loss"],
"glen": status["response_length"],
"policy_lr": status["policy_lr"],
"ent": status["policy_entropy"],
}
if "raw_grad_norm" in status:
short_status["grad_norm"] = status["raw_grad_norm"]
if "reward" in status:
short_status["rm"] = status["reward"]

if "critic_loss" in status:
short_status["cri"] = status["critic_loss"]
short_status["vals"] = status["values"]
short_status["cri_lr"] = status["critic_lr"]

if "ptx_loss" in status:
short_status["ptx"] = status["ptx_loss"]

status_list.append(status)
for k, v in status.items():
all_metrics[k].append(v)
minibatch_pbar.set_postfix(short_status)

for epoch in range(self.cfg.trainer.update_epochs_per_batch):
pbar = tqdm(
dataloader,
minibatch_pbar = tqdm(
minibatch_iterator,
desc=f"Policy Train epoch [{epoch + 1}/{self.cfg.trainer.update_epochs_per_batch}]",
disable=not self.strategy.is_rank_0(),
)
for local_step, experience in enumerate(pbar):
status = self.forward_backward(experience, accumulation_steps)

if (local_step + 1) % accumulation_steps == 0:
grad_norm = self.optim_step()
if grad_norm is not None:
status["raw_grad_norm"] = grad_norm
for minibatch in minibatch_pbar:
microbatch_iterator = BatchIterator(
minibatch, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
)
num_microbatches = len(microbatch_iterator)
microbatch_weight = 1.0 / num_microbatches

for microbatch_idx, microbatch in enumerate(microbatch_iterator):
microbatch_experience = BatchIterator.batch_to_experience(microbatch)
status = self.forward_backward(microbatch_experience, microbatch_weight=microbatch_weight)

# Record status for all but the last microbatch in the minibatch.
# The last microbatch should be recorded after the optimizer step.
if microbatch_idx < num_microbatches - 1:
if self.record_memory:
self.save_memory_snapshot(global_step, local_step)
record_status(status)

# Local step counts the number of processed microbatches.
local_step += 1

grad_norm = self.optim_step()
if grad_norm is not None:
status["raw_grad_norm"] = grad_norm

if self.record_memory:
self.save_memory_snapshot(global_step, local_step)

status["policy_lr"] = self.scheduler.get_last_lr()[0]

policy_update_steps += 1

# for DP
# TODO (sumanthrh): this assumes all workers are data parallel.
# We assume that outputs are replicated within tp or sp group, otherwise this is not correct.
status = self.strategy.all_reduce(status)

# weighted mean for kl
# TODO (sumanthrh): this weighted mean is no longer correct since we use the max response length in the batch.
# we can log this in the driver
# if "kl" in status:
# status["kl"] *= status["response_length"]
# status["kl"] /= status["response_length"]

short_status = {}

if "policy_loss" in status:
short_status = {
"pg": status["policy_loss"],
"glen": status["response_length"],
"policy_lr": status["policy_lr"],
"ent": status["policy_entropy"],
}
if "raw_grad_norm" in status:
short_status["grad_norm"] = status["raw_grad_norm"]
if "reward" in status:
short_status["rm"] = status["reward"]

if "critic_loss" in status:
short_status["cri"] = status["critic_loss"]
short_status["vals"] = status["values"]
short_status["cri_lr"] = status["critic_lr"]

if "ptx_loss" in status:
short_status["ptx"] = status["ptx_loss"]

status_list.append(status)
for k, v in status.items():
all_metrics[k].append(v)
pbar.set_postfix(short_status)
# Record status for the last microbatch in the minibatch.
record_status(status)

torch.distributed.barrier()
# not needed beyond status logging
all_metrics.pop("response_length", None)

status_mean = reduce_metrics(all_metrics)
status_mean["policy_update_steps"] = policy_update_steps / accumulation_steps
status_mean["policy_update_steps"] = num_minibatches * self.cfg.trainer.update_epochs_per_batch
Collaborator: good catch...

# should return an `TrainingOutputBatch`
output = TrainingOutputBatch()
@@ -882,10 +907,11 @@ def _normalize_mini_batch_size(self):
self.cfg.trainer.critic_mini_batch_size * self.cfg.generator.n_samples_per_prompt // dp_size
)

def forward_backward(self, experience: Experience, accumulation_steps: int) -> Dict[str, float]:
def forward_backward(self, experience: Experience, microbatch_weight: float) -> Dict[str, float]:
"""
Perform the forward and backward pass for one micro-batch.
"""
self.model.train()
experience.to_device(torch.cuda.current_device())

sequences = experience.sequences
@@ -911,7 +937,7 @@ def forward_backward(self, experience: Experience, accumulation_steps: int) -> D
config=self.cfg.trainer.algorithm,
loss_mask=loss_mask,
)
loss = loss / accumulation_steps
loss = loss * microbatch_weight
self.strategy.backward(loss, self.model, self.optimizer)

status = {
@@ -964,52 +990,64 @@ def save_hf_model(self, export_dir: str, tokenizer):
)

def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch:
dataloader = BatchIterator(
train_data, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
global_step = train_data.metadata["global_step"]
minibatch_iterator = BatchIterator(
train_data, sample_batch_size=self.critic_mini_batch_size_per_gpu, drop_last=False
)

torch.cuda.empty_cache()
self.model.train()
all_metrics = defaultdict(list)
num_minibatches = len(minibatch_iterator)
local_step = 0

micro_batches_per_mini_batch = (
self.critic_mini_batch_size_per_gpu // self.cfg.trainer.micro_train_batch_size_per_gpu
)
# The number of steps (over micro batches) to accumulate gradients before taking an optimizer step.
accumulation_steps = micro_batches_per_mini_batch
def record_status(status: Dict[str, float]):
status["critic_lr"] = self.scheduler.get_last_lr()[0]

# for DP
# TODO (sumanthrh): this assumes all workers are data parallel.
# We should get more accurate metrics with seq parallel or TP.
# There are metrics like entropy where we get average over local data size
status = self.strategy.all_reduce(status)

for k, v in status.items():
all_metrics[k].append(v)
minibatch_pbar.set_postfix(status)

all_metrics = defaultdict(list)
critic_update_steps = 0
for epoch in range(self.cfg.trainer.update_epochs_per_batch):
pbar = tqdm(
dataloader,
minibatch_pbar = tqdm(
minibatch_iterator,
desc=f"Critic Train epoch [{epoch + 1}/{self.cfg.trainer.update_epochs_per_batch}]",
disable=not self.strategy.is_rank_0(),
)
for local_step, experience in enumerate(pbar):
status = self.forward_backward(experience, accumulation_steps)
for minibatch in minibatch_pbar:
microbatch_iterator = BatchIterator(
minibatch, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
)
num_microbatches = len(microbatch_iterator)
microbatch_weight = 1.0 / num_microbatches

for microbatch_idx, microbatch in enumerate(microbatch_iterator):
microbatch_experience = BatchIterator.batch_to_experience(microbatch)
status = self.forward_backward(microbatch_experience, microbatch_weight=microbatch_weight)

if (local_step + 1) % accumulation_steps == 0:
grad_norm = self.optim_step()
if grad_norm is not None:
status["raw_grad_norm"] = grad_norm
if microbatch_idx < num_microbatches - 1:
if self.record_memory:
self.save_memory_snapshot(global_step, local_step)
record_status(status)

status["critic_lr"] = self.scheduler.get_last_lr()[0]
critic_update_steps += 1
local_step += 1

# for DP
# TODO (sumanthrh): this assumes all workers are data parallel.
# We should get more accurate metrics with seq parallel or TP.
# There are metrics like entropy where we get average over local data size
status = self.strategy.all_reduce(status)
grad_norm = self.optim_step()
if grad_norm is not None:
status["raw_grad_norm"] = grad_norm

for k, v in status.items():
all_metrics[k].append(v)
pbar.set_postfix(status)
if self.record_memory:
self.save_memory_snapshot(global_step, local_step)
record_status(status)

torch.distributed.barrier()

status_mean = reduce_metrics(all_metrics)
status_mean["critic_update_steps"] = critic_update_steps / accumulation_steps
status_mean["critic_update_steps"] = num_minibatches * self.cfg.trainer.update_epochs_per_batch

output = TrainingOutputBatch()
output.metadata = {"train_status": status_mean}
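
To make the new control flow concrete, here is a minimal, self-contained sketch (plain PyTorch on toy data; none of the names are the repo's APIs) of the pattern both `ppo_train` implementations now follow: iterate minibatches, split each minibatch into microbatches, scale each microbatch loss by 1/num_microbatches, and take exactly one optimizer step per minibatch.

```python
# Sketch of the minibatch/microbatch accumulation pattern (toy model and data).
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

data = torch.randn(64, 8)       # stand-in for one training batch on this rank
targets = torch.randn(64, 1)
mini_batch_size = 16            # granularity of optimizer steps
micro_batch_size = 4            # granularity of forward/backward passes

for start in range(0, len(data), mini_batch_size):                 # minibatch loop
    mini_x = data[start:start + mini_batch_size]
    mini_y = targets[start:start + mini_batch_size]

    micro_chunks = list(zip(mini_x.split(micro_batch_size),
                            mini_y.split(micro_batch_size)))
    microbatch_weight = 1.0 / len(micro_chunks)                     # 1/4 here

    optimizer.zero_grad()
    for micro_x, micro_y in micro_chunks:                           # microbatch loop
        loss = nn.functional.mse_loss(model(micro_x), micro_y)
        (loss * microbatch_weight).backward()                       # accumulate scaled grads
    optimizer.step()                                                # one step per minibatch
```

When every minibatch splits evenly into microbatches, this scaling matches the previous `loss / accumulation_steps` behavior exactly.
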
6 changes: 2 additions & 4 deletions skyrl-train/skyrl_train/workers/worker_utils.py
@@ -37,11 +37,9 @@ def __len__(self):
def __iter__(self):
return self

def __next__(self) -> Experience:
def __next__(self) -> TrainingInputBatch:
@erictang000 (Collaborator), Dec 31, 2025:

This change also affects the use of BatchIterator for the megatron backend, which implements ppo_train differently from FSDP/DeepSpeed.

Could you make sure the conversion to experience is also handled correctly for the megatron code path? Making sure one of these tests passes:

async def test_megatron_train(

is probably a good way to check this.

Contributor Author:

Sounds good. I added a follow-up TODO to update the megatron worker's ppo loop as well. Made a minimal change for now to prevent this PR from getting too large.

try:
batch = next(self._iter)
exp = self.batch_to_experience(batch)
return exp
return next(self._iter)
except StopIteration:
self._iter = iter(self._chunks)
raise StopIteration
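
For illustration, a self-contained sketch of the new contract with stand-in classes (the real `TrainingInputBatch` and `Experience` carry tensors and metadata, and `batch_to_experience` does real conversion work): the iterator now yields raw batch chunks, and the caller decides when to build an `Experience`, as the FSDP/DeepSpeed and megatron loops above now do.

```python
# Stand-in sketch: BatchIterator yields batches; callers convert explicitly.
from dataclasses import dataclass
from typing import Iterator, List


@dataclass
class Batch:            # stand-in for TrainingInputBatch
    data: List[int]


@dataclass
class Experience:       # stand-in for the real Experience
    data: List[int]


class BatchIterator:
    def __init__(self, items: List[int], sample_batch_size: int):
        self._chunks = [Batch(items[i:i + sample_batch_size])
                        for i in range(0, len(items), sample_batch_size)]
        self._iter: Iterator[Batch] = iter(self._chunks)

    def __len__(self) -> int:
        return len(self._chunks)

    def __iter__(self):
        return self

    def __next__(self) -> Batch:              # yields batches, not experiences
        try:
            return next(self._iter)
        except StopIteration:
            self._iter = iter(self._chunks)   # reset so the iterator can be reused
            raise

    @staticmethod
    def batch_to_experience(batch: Batch) -> Experience:
        return Experience(batch.data)


# Caller-side conversion, mirroring the updated worker loops:
for batch in BatchIterator(list(range(10)), sample_batch_size=4):
    experience = BatchIterator.batch_to_experience(batch)
```

This keeps `BatchIterator` usable at both levels: the minibatch level, where chunks are re-chunked into microbatches, and the microbatch level, where chunks are converted to experiences before the forward/backward pass.
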
17 changes: 9 additions & 8 deletions skyrl-train/tests/cpu/test_trainer.py
@@ -548,8 +548,8 @@ def create_test_worker(worker_class):
# Mock forward_backward and optim_step to track calls and verify accumulation behavior
policy_forward_backward_calls = []

def mock_policy_forward_backward(experience, accumulation_steps):
policy_forward_backward_calls.append({"accumulation_steps": accumulation_steps})
def mock_policy_forward_backward(experience, microbatch_weight):
policy_forward_backward_calls.append({"microbatch_weight": microbatch_weight})
return {"policy_loss": 0.5, "ppo_clip_ratio": 0.1, "policy_entropy": 2.0, "response_length": response_length}

policy_worker.forward_backward = mock_policy_forward_backward
@@ -568,6 +568,7 @@ def mock_policy_forward_backward(experience, accumulation_steps):
) # 6 // 2 = 3
# New logic: accumulation_steps = micro_batches_per_mini_batch (accumulate within mini-batch)
expected_accumulation_steps = micro_batches_per_mini_batch # Should be 3
expected_microbatch_weight = 1.0 / expected_accumulation_steps

# Run policy ppo_train with minimal mocking
with (
@@ -584,8 +584,8 @@ def mock_policy_forward_backward(experience, accumulation_steps):
# Verify accumulation_steps are consistent (should equal micro_batches_per_mini_batch)
for call in policy_forward_backward_calls:
assert (
call["accumulation_steps"] == expected_accumulation_steps
), f"PolicyWorker: Expected accumulation_steps={expected_accumulation_steps}, got {call['accumulation_steps']}"
call["microbatch_weight"] == expected_microbatch_weight
), f"PolicyWorker: Expected microbatch_weight={expected_microbatch_weight}, got {call['microbatch_weight']}"

# Verify result structure
assert "train_status" in result.metadata
@@ -602,8 +602,8 @@ def mock_policy_forward_backward(experience, accumulation_steps):
# Mock forward_backward and optim_step for critic
critic_forward_backward_calls = []

def mock_critic_forward_backward(experience, accumulation_steps):
critic_forward_backward_calls.append({"accumulation_steps": accumulation_steps})
def mock_critic_forward_backward(experience, microbatch_weight):
critic_forward_backward_calls.append({"microbatch_weight": microbatch_weight})
return {"critic_loss": 0.3, "values": 1.0}

critic_worker.forward_backward = mock_critic_forward_backward
@@ -627,8 +627,8 @@ def mock_critic_forward_backward(experience, accumulation_steps):
# Verify accumulation_steps are consistent for critic (should equal micro_batches_per_mini_batch)
for call in critic_forward_backward_calls:
assert (
call["accumulation_steps"] == expected_accumulation_steps
), f"CriticWorker: Expected accumulation_steps={expected_accumulation_steps}, got {call['accumulation_steps']}"
call["microbatch_weight"] == expected_microbatch_weight
), f"CriticWorker: Expected microbatch_weight={expected_microbatch_weight}, got {call['microbatch_weight']}"

# Verify result structure for critic
assert "train_status" in result.metadata
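
For reference, the expected weight the mocked `forward_backward` calls are checked against follows directly from the test's sizes (the 6 and 2 mirror the `6 // 2 = 3` comment above; the variable names here are only illustrative):

```python
# Arithmetic behind expected_microbatch_weight in the test above.
policy_mini_batch_size_per_gpu = 6     # per-GPU mini-batch size after normalization
micro_train_batch_size_per_gpu = 2
micro_batches_per_mini_batch = policy_mini_batch_size_per_gpu // micro_train_batch_size_per_gpu  # 3
expected_microbatch_weight = 1.0 / micro_batches_per_mini_batch

assert abs(expected_microbatch_weight - 1 / 3) < 1e-12
```
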