diff --git a/slime/backends/fsdp_utils/actor.py b/slime/backends/fsdp_utils/actor.py index 4f85a6a80..0618735ba 100644 --- a/slime/backends/fsdp_utils/actor.py +++ b/slime/backends/fsdp_utils/actor.py @@ -17,8 +17,15 @@ from slime.utils.logging_utils import init_tracking from slime.utils.memory_utils import clear_memory, print_memory from slime.utils.metric_utils import compute_rollout_step -from slime.utils.misc import Box -from slime.utils.ppo_utils import compute_approx_kl, compute_gspo_kl, compute_opsm_mask, compute_policy_loss +from slime.utils.misc import load_function, Box +from slime.utils.ppo_utils import ( + build_opsm_inputs_from_log_probs, + compute_approx_kl, + compute_gspo_kl, + compute_opsm_mask, + compute_policy_loss, + vanilla_tis_function, +) from slime.utils.processing_utils import load_processor, load_tokenizer from slime.utils.profile_utils import TrainProfiler from slime.utils.timer import Timer, inverse_timer, timer, with_defer @@ -582,13 +589,16 @@ def _train_step(self, packed_batch, reported_accum, mbs_id, grad_accum): ppo_kl = old_log_probs - log_probs if self.args.use_opsm: - opsm_mask, opsm_clipfrac = compute_opsm_mask( - args=self.args, + opsm_inputs = build_opsm_inputs_from_log_probs( full_log_probs=[batch["cur_log_probs"] for batch in unpacked_batches], full_old_log_probs=[batch[old_log_prob_key] for batch in unpacked_batches], - advantages=[batch["advantages"] for batch in unpacked_batches], loss_masks=loss_masks, ) + opsm_mask, opsm_clipfrac = compute_opsm_mask( + args=self.args, + advantages=[batch["advantages"] for batch in unpacked_batches], + opsm_inputs=opsm_inputs, + ) if self.args.advantage_estimator == "gspo": ppo_kl = compute_gspo_kl( diff --git a/slime/backends/megatron_utils/cp_utils.py b/slime/backends/megatron_utils/cp_utils.py index 2e795d3d3..26807b689 100644 --- a/slime/backends/megatron_utils/cp_utils.py +++ b/slime/backends/megatron_utils/cp_utils.py @@ -50,6 +50,48 @@ def get_logits_and_tokens_offset_with_cp( return chunk_size, (chunk_0, chunk_1), (logits_0, logits_1), (token_0, token_1) +def get_chunked_loss_masks( + total_lengths: list[int], + response_lengths: list[int], + loss_masks: list[torch.Tensor], + qkv_format: str = "thd", + max_seq_lens: list[int] | None = None, +) -> tuple[list[torch.Tensor], list[int]]: + """Slice loss masks to the local CP segments and return chunk lengths.""" + + cp_size = mpu.get_context_parallel_world_size() + if cp_size == 1: + return loss_masks, response_lengths + + chunked_loss_masks: list[torch.Tensor] = [] + chunk_lengths: list[int] = [] + for i, (total_length, response_length, loss_mask) in enumerate(zip(total_lengths, response_lengths, loss_masks, strict=False)): + max_seq_len = max_seq_lens[i] if max_seq_lens is not None else None + prompt_length = total_length - response_length + _, _, _, tokens_offset = get_logits_and_tokens_offset_with_cp( + total_length, + response_length, + qkv_format, + max_seq_len, + ) + + local_chunks: list[torch.Tensor] = [] + for start, end in tokens_offset: + local_start, local_end = start - prompt_length, end - prompt_length + if local_end > local_start: + local_chunks.append(loss_mask[local_start:local_end]) + + if local_chunks: + chunked_mask = torch.cat(local_chunks, dim=0) + else: + chunked_mask = loss_mask.new_zeros((0,)) + + chunked_loss_masks.append(chunked_mask) + chunk_lengths.append(chunked_mask.size(0)) + + return chunked_loss_masks, chunk_lengths + + def get_sum_of_sample_mean( total_lengths: list[int], response_lengths: list[int], @@ -62,59 +104,39 @@ def 
get_sum_of_sample_mean( Calculate correct sample mean for CP """ cp_size = mpu.get_context_parallel_world_size() + chunked_loss_masks, chunk_lengths = get_chunked_loss_masks( + total_lengths, + response_lengths, + loss_masks, + qkv_format, + max_seq_lens, + ) + if cp_size == 1: def sum_of_sample_mean(x: torch.Tensor) -> torch.Tensor: - return sum( - [ - (x_i * loss_mask_i).sum() / torch.clamp_min(loss_mask_i.sum(), 1) - for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks, strict=False) - ] - ) + return sum([(x_i * loss_mask_i).sum() / torch.clamp_min(loss_mask_i.sum(), 1) for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks, strict=False)]) def sum_of_token(x: torch.Tensor) -> torch.Tensor: - return sum( - [ - (x_i * loss_mask_i).sum() - for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks, strict=False) - ] - ) + return sum([(x_i * loss_mask_i).sum() for x_i, loss_mask_i in zip(x.split(response_lengths, dim=0), loss_masks, strict=False)]) else: - cp_chunk_lengths = [] - chunked_loss_masks = [] - for i, (total_length, response_length, loss_mask) in enumerate( - zip(total_lengths, response_lengths, loss_masks, strict=False) - ): - max_seq_len = max_seq_lens[i] if max_seq_lens is not None else None - prompt_length = total_length - response_length - _, _, _, tokens_offset = get_logits_and_tokens_offset_with_cp( - total_length, response_length, qkv_format, max_seq_len - ) - loss_mask_0 = loss_mask[tokens_offset[0][0] - prompt_length : tokens_offset[0][1] - prompt_length] - loss_mask_1 = loss_mask[tokens_offset[1][0] - prompt_length : tokens_offset[1][1] - prompt_length] - chunked_loss_masks.append(torch.cat([loss_mask_0, loss_mask_1], dim=0)) - cp_chunk_lengths.append(chunked_loss_masks[i].size(0)) def sum_of_sample_mean(x: torch.Tensor) -> torch.Tensor: return sum( [ (x_i * chunked_loss_mask).sum() / torch.clamp_min(loss_mask.sum(), 1) for x_i, chunked_loss_mask, loss_mask in zip( - x.split(cp_chunk_lengths, dim=0), chunked_loss_masks, loss_masks, strict=False + x.split(chunk_lengths, dim=0), + chunked_loss_masks, + loss_masks, + strict=False, ) ] ) def sum_of_token(x: torch.Tensor) -> torch.Tensor: - return sum( - [ - (x_i * chunked_loss_mask).sum() - for x_i, chunked_loss_mask in zip( - x.split(cp_chunk_lengths, dim=0), chunked_loss_masks, strict=False - ) - ] - ) + return sum([(x_i * chunked_loss_mask).sum() for x_i, chunked_loss_mask in zip(x.split(chunk_lengths, dim=0), chunked_loss_masks, strict=False)]) return sum_of_sample_mean if not calculate_per_token_loss else sum_of_token @@ -231,7 +253,10 @@ def slice_log_prob_with_cp( prompt_length = total_length - response_length _, _, logits_offset, _ = get_logits_and_tokens_offset_with_cp( - total_length, response_length, qkv_format, max_token_len + total_length, + response_length, + qkv_format, + max_token_len, ) chunk_1 = log_prob[logits_offset[0][0] - (prompt_length - 1) : logits_offset[0][1] - (prompt_length - 1)] diff --git a/slime/backends/megatron_utils/loss.py b/slime/backends/megatron_utils/loss.py index 3ed323847..91fa7da19 100644 --- a/slime/backends/megatron_utils/loss.py +++ b/slime/backends/megatron_utils/loss.py @@ -3,12 +3,17 @@ from typing import Any import torch +import torch.distributed as dist +import torch.distributed.nn.functional as distnnf from megatron.core import mpu from torch.utils.checkpoint import checkpoint from slime.utils.distributed_utils import distributed_masked_whiten from slime.utils.misc import load_function from slime.utils.ppo_utils import ( 
+ OpsmInputs, + build_opsm_inputs_from_log_probs, + build_opsm_inputs_from_seq_kls, calculate_log_probs_and_entropy, compute_approx_kl, compute_gspo_kl, @@ -19,9 +24,10 @@ get_reinforce_plus_plus_baseline_advantages, get_reinforce_plus_plus_returns, ) +from slime.utils.timer import timer from slime.utils.types import RolloutBatch -from .cp_utils import all_gather_with_cp, get_logits_and_tokens_offset_with_cp, get_sum_of_sample_mean +from .cp_utils import get_chunked_loss_masks, get_logits_and_tokens_offset_with_cp, get_sum_of_sample_mean def get_responses( @@ -71,13 +77,11 @@ def get_responses( cp_size = mpu.get_context_parallel_world_size() end = 0 - for i, (tokens, total_length, response_length) in enumerate( - zip(unconcat_tokens, total_lengths, response_lengths, strict=False) - ): + for i, (tokens, total_length, response_length) in enumerate(zip(unconcat_tokens, total_lengths, response_lengths, strict=False)): max_seq_len = max_seq_lens[i] if max_seq_lens is not None else None - if cp_size == 1: if qkv_format == "bshd": + assert max_seq_len is not None end = max_seq_len * i + total_length start = end - response_length else: @@ -88,7 +92,10 @@ def get_responses( else: # TODO: this is super ugly... do better abstraction. chunk_size, chunks_offset, logits_offset, tokens_offset = get_logits_and_tokens_offset_with_cp( - total_length, response_length, qkv_format, max_seq_len + total_length, + response_length, + qkv_format, + max_seq_len, ) logits_0, logits_1 = logits[end : end + chunk_size], logits[end + chunk_size : end + 2 * chunk_size] @@ -219,6 +226,54 @@ def get_values( } +def _compute_seq_kls_with_cp( + log_probs: list[torch.Tensor], + old_log_probs: list[torch.Tensor], + loss_masks: list[torch.Tensor], + total_lengths: list[int], + response_lengths: list[int], + *, + cp_group: dist.ProcessGroup, + track_grad: bool, + qkv_format: str = "thd", + max_seq_lens: list[int] | None = None, +) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + """Compute per-sequence KL scalars with minimal CP communication. + + The numerator is accumulated from the CP-local response segments and then + reduced across the context-parallel group. The denominator uses the full + (non-chunked) `loss_masks`, which are identical on all CP ranks. + """ + + chunked_loss_masks, _ = get_chunked_loss_masks( + total_lengths, + response_lengths, + loss_masks, + qkv_format=qkv_format, + max_seq_lens=max_seq_lens, + ) + + if not loss_masks: + return [], chunked_loss_masks + + def compute_local_numerators() -> torch.Tensor: + return torch.stack([((old_log_prob - log_prob) * chunked_loss_mask).sum() for log_prob, old_log_prob, chunked_loss_mask in zip(log_probs, old_log_probs, chunked_loss_masks, strict=False)]) + + if track_grad: + local_numerators = compute_local_numerators() + global_numerators = distnnf.all_reduce(local_numerators, group=cp_group) + else: + with torch.no_grad(): + local_numerators = compute_local_numerators() + dist.all_reduce(local_numerators, group=cp_group) + global_numerators = local_numerators + + denominators = torch.stack([torch.clamp_min(loss_mask.sum(), 1) for loss_mask in loss_masks]).to(global_numerators) + seq_kls = list((global_numerators / denominators).unbind(dim=0)) + + return seq_kls, chunked_loss_masks + + def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) -> None: """Compute advantages and returns in-place based on `args.advantage_estimator`. 
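
A minimal standalone sketch of the reduction `_compute_seq_kls_with_cp` performs, assuming generic tensors and a `torch.distributed` process group (function and argument names here are illustrative, not part of the diff): each CP rank sums `(old_log_prob - log_prob)` over its local response slice, the per-sequence numerators are all-reduced across the CP group, and the denominator comes from the full loss masks, which the diff notes are identical on every rank. The real helper additionally routes the gradient-tracking case through `torch.distributed.nn.functional.all_reduce`.

import torch
import torch.distributed as dist

def seq_kls_sketch(log_probs, old_log_probs, chunked_loss_masks, full_loss_masks, cp_group=None):
    # Per-sequence numerator from the CP-local response slice only.
    numerators = torch.stack(
        [
            ((old_lp - lp) * chunk_mask).sum()
            for lp, old_lp, chunk_mask in zip(log_probs, old_log_probs, chunked_loss_masks)
        ]
    )
    # Each CP rank owns a disjoint slice, so summing across the group yields the full numerator.
    if cp_group is not None and dist.get_world_size(cp_group) > 1:
        dist.all_reduce(numerators, group=cp_group)
    # The full (unchunked) masks are replicated on all CP ranks, so no reduction is needed here.
    denominators = torch.stack([torch.clamp_min(mask.sum(), 1) for mask in full_loss_masks]).to(numerators)
    return list((numerators / denominators).unbind(dim=0))
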
@@ -248,6 +303,8 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) loss_masks: list[torch.Tensor] = rollout_data.get("loss_masks") total_lengths: list[int] = rollout_data.get("total_lengths") max_seq_lens: list[int] | None = rollout_data.get("max_seq_lens", None) + cp_size = mpu.get_context_parallel_world_size() + cp_rank = mpu.get_context_parallel_rank() # return when not the last pp stage. if log_probs is None and values is None: @@ -277,15 +334,12 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) old_rewards = rewards rewards = [] kl_coef = -args.kl_coef - cp_rank = mpu.get_context_parallel_rank() for reward, k in zip(old_rewards, kl, strict=False): k *= kl_coef if cp_rank == 0: k[-1] += reward rewards.append(k) - advantages, returns = get_advantages_and_returns_batch( - total_lengths, response_lengths, values, rewards, args.gamma, args.lambd - ) + advantages, returns = get_advantages_and_returns_batch(total_lengths, response_lengths, values, rewards, args.gamma, args.lambd) elif args.advantage_estimator == "reinforce_plus_plus": rewards = torch.tensor(rewards, dtype=torch.float32, device=kl[0].device) @@ -316,14 +370,8 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) response_lengths = rollout_data.get("response_lengths") device = student_log_probs[0].device teacher_log_probs = [t_log_prob.to(device=device) for t_log_prob in teacher_log_probs] - teacher_log_probs = [ - t_log_prob[-response_length:] - for t_log_prob, response_length in zip(teacher_log_probs, response_lengths, strict=False) - ] - advantages = [ - teacher_log_prob - student_log_prob - for teacher_log_prob, student_log_prob in zip(teacher_log_probs, student_log_probs, strict=False) - ] + teacher_log_probs = [t_log_prob[-response_length:] for t_log_prob, response_length in zip(teacher_log_probs, response_lengths, strict=False)] + advantages = [teacher_log_prob - student_log_prob for teacher_log_prob, student_log_prob in zip(teacher_log_probs, student_log_probs, strict=False)] returns = advantages else: @@ -332,48 +380,20 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) # TODO: OpenRLHF always does advantages normalization but veRL doesn't seem to do it. 
if args.normalize_advantages: all_advs = torch.cat(advantages) - cp_size = mpu.get_context_parallel_world_size() if cp_size == 1: all_masks = torch.cat(loss_masks) else: - mask_chunks = [] - for i in range(len(advantages)): - total_len = total_lengths[i] - response_len = response_lengths[i] - prompt_len = total_len - response_len - max_seq_len = max_seq_lens[i] if max_seq_lens is not None else None - - _, _, _, token_offsets = get_logits_and_tokens_offset_with_cp( - total_len, response_len, args.qkv_format, max_seq_len - ) - - # Convert global offsets to response-space offsets - s0, e0 = token_offsets[0] - s1, e1 = token_offsets[1] - res_s0, res_e0 = max(0, s0 - prompt_len), max(0, e0 - prompt_len) - res_s1, res_e1 = max(0, s1 - prompt_len), max(0, e1 - prompt_len) - - local_mask_parts = [] - full_mask = loss_masks[i] - if res_e0 > res_s0: - local_mask_parts.append(full_mask[res_s0:res_e0]) - if res_e1 > res_s1: - local_mask_parts.append(full_mask[res_s1:res_e1]) - - # Concatenate the parts to form the final mask chunk for this rank and this sequence - local_mask_chunk = ( - torch.cat(local_mask_parts) - if local_mask_parts - else torch.tensor([], device=all_advs.device, dtype=full_mask.dtype) - ) - mask_chunks.append(local_mask_chunk) - + mask_chunks, _ = get_chunked_loss_masks( + total_lengths, + response_lengths, + loss_masks, + qkv_format=args.qkv_format, + max_seq_lens=max_seq_lens, + ) all_masks = torch.cat(mask_chunks) if all_masks.numel() > 0: - assert ( - all_advs.size() == all_masks.size() - ), f"Shape mismatch before whitening: advantages {all_advs.size()}, masks {all_masks.size()}" + assert all_advs.size() == all_masks.size(), f"Shape mismatch before whitening: advantages {all_advs.size()}, masks {all_masks.size()}" dp_group = mpu.get_data_parallel_group() whitened_advs_flat = distributed_masked_whiten( @@ -426,9 +446,7 @@ def icepop_function( old_log_probs = torch.cat(train_log_probs, dim=0) ice_ratio = torch.exp(old_log_probs - rollout_log_probs) ice_abs = (torch.exp(old_log_probs - rollout_log_probs) - 1).abs() - ice_weight = torch.where( - (ice_ratio >= args.tis_clip_low) & (ice_ratio <= args.tis_clip), ice_ratio, torch.zeros_like(ice_ratio) - ) + ice_weight = torch.where((ice_ratio >= args.tis_clip_low) & (ice_ratio <= args.tis_clip), ice_ratio, torch.zeros_like(ice_ratio)) ice_clipfrac = (ice_weight != ice_ratio).float() metrics = { "tis": ice_ratio.clone().detach(), @@ -448,10 +466,10 @@ def policy_loss_function( """Compute policy loss (PPO/GSPO) and metrics. Computes current log-probabilities and entropy from model logits, then - calculates PPO-style clipped policy gradient loss. For GSPO, gathers - full sequences via context-parallel all-gather before computing per-sample - KL. Optionally applies TIS (Truncated Importance Sampling) correction and - adds KL loss term if configured. + calculates PPO-style clipped policy gradient loss. For GSPO, computes + per-sample KL scalars and (when CP is enabled) avoids all-gathering full + sequences by reducing only the per-sample numerator. Optionally applies + TIS (importance sampling) correction and adds KL loss term if configured. 
Args: args: Configuration controlling advantage estimator, clipping thresholds, @@ -474,6 +492,7 @@ def policy_loss_function( response_lengths = batch["response_lengths"] total_lengths = batch["total_lengths"] + loss_masks = batch["loss_masks"] max_seq_lens = batch.get("max_seq_lens", None) log_probs_and_entropy = get_log_probs_and_entropy( @@ -488,45 +507,79 @@ def policy_loss_function( log_probs = log_probs_and_entropy["log_probs"] - # Pre-gather log probs if needed by OPSM or GSPO to avoid duplicate gathering + # Pre-compute sequence-level KL inputs for OPSM/GSPO (CP avoids full all-gather). need_full_log_probs = args.use_opsm or args.advantage_estimator == "gspo" full_log_probs = None full_old_log_probs = None + seq_kls: list[torch.Tensor] | None = None + chunked_loss_masks: list[torch.Tensor] | None = None + opsm_inputs: OpsmInputs | None = None + cp_size = mpu.get_context_parallel_world_size() + cp_group = mpu.get_context_parallel_group() if cp_size > 1 else None if need_full_log_probs: - full_log_probs = [ - all_gather_with_cp(log_prob, total_length, response_length) - for log_prob, total_length, response_length in zip( - log_probs, total_lengths, response_lengths, strict=False + with timer("cp_seq_kl_prep"): + if cp_size > 1: + assert cp_group is not None + seq_kls, chunked_loss_masks = _compute_seq_kls_with_cp( + log_probs, + old_log_probs, + loss_masks, + total_lengths, + response_lengths, + cp_group=cp_group, + track_grad=args.advantage_estimator == "gspo", + qkv_format=args.qkv_format, + max_seq_lens=max_seq_lens, + ) + else: + full_log_probs = log_probs + full_old_log_probs = old_log_probs + + if args.use_opsm: + if seq_kls is not None and chunked_loss_masks is not None: + opsm_inputs = build_opsm_inputs_from_seq_kls( + seq_kls=seq_kls, + loss_masks=loss_masks, + chunked_loss_masks=chunked_loss_masks, ) - ] - full_old_log_probs = [ - all_gather_with_cp(old_log_prob, total_length, response_length) - for old_log_prob, total_length, response_length in zip( - old_log_probs, total_lengths, response_lengths, strict=False + else: + assert full_log_probs is not None and full_old_log_probs is not None + opsm_inputs = build_opsm_inputs_from_log_probs( + full_log_probs=full_log_probs, + full_old_log_probs=full_old_log_probs, + loss_masks=loss_masks, ) - ] # Compute OPSM mask if enabled if args.use_opsm: + assert opsm_inputs is not None, "OPSM inputs must be built before masking" opsm_mask, opsm_clipfrac = compute_opsm_mask( args=args, - full_log_probs=full_log_probs, - full_old_log_probs=full_old_log_probs, advantages=batch["advantages"], - loss_masks=batch["loss_masks"], + opsm_inputs=opsm_inputs, + cp_group=cp_group, ) # Compute KL divergence (GSPO uses sequence-level KL, others use per-token KL) if args.advantage_estimator == "gspo": - ppo_kl = compute_gspo_kl( - full_log_probs=full_log_probs, - full_old_log_probs=full_old_log_probs, - local_log_probs=log_probs, - loss_masks=batch["loss_masks"], - ) - old_log_probs = torch.cat(old_log_probs, dim=0) - log_probs = torch.cat(log_probs, dim=0) + if cp_size > 1: + assert seq_kls is not None + ppo_kl = torch.cat( + [seq_kl.expand_as(local_log_prob) for seq_kl, local_log_prob in zip(seq_kls, log_probs, strict=False)], + dim=0, + ) + old_log_probs = torch.cat(old_log_probs, dim=0) + log_probs = torch.cat(log_probs, dim=0) + else: + ppo_kl = compute_gspo_kl( + full_log_probs=full_log_probs, + full_old_log_probs=full_old_log_probs, + local_log_probs=log_probs, + loss_masks=loss_masks, + ) + old_log_probs = torch.cat(old_log_probs, dim=0) + 
log_probs = torch.cat(log_probs, dim=0) else: old_log_probs = torch.cat(old_log_probs, dim=0) log_probs = torch.cat(log_probs, dim=0) @@ -557,7 +610,7 @@ def policy_loss_function( "pg_loss": pg_loss, "train_log_probs": batch["log_probs"], "rollout_log_probs": batch["rollout_log_probs"], - "loss_masks": batch["loss_masks"], + "loss_masks": loss_masks, "total_lengths": total_lengths, "response_lengths": response_lengths, } @@ -584,9 +637,7 @@ def policy_loss_function( custom_pg_loss_reducer_func = load_function(args.custom_pg_loss_reducer_function_path) # Determine which loss_masks to use for pg_loss reducer pg_loss_masks = modified_response_masks if (args.get_mismatch_metrics or args.use_tis) else batch["loss_masks"] - pg_loss_reducer = custom_pg_loss_reducer_func( - total_lengths, response_lengths, pg_loss_masks, args.calculate_per_token_loss - ) + pg_loss_reducer = custom_pg_loss_reducer_func(total_lengths, response_lengths, pg_loss_masks, args.calculate_per_token_loss) else: pg_loss_reducer = sum_of_sample_mean @@ -617,10 +668,6 @@ def policy_loss_function( loss = loss + args.kl_loss_coef * kl_loss - # make sure the gradient could backprop correctly. - if log_probs.numel() == 0: - loss += 0 * logits.sum() - train_rollout_logprob_abs_diff = None if "rollout_log_probs" in batch and batch["rollout_log_probs"]: rollout_log_probs = torch.cat(batch["rollout_log_probs"], dim=0) @@ -652,6 +699,10 @@ def policy_loss_function( if args.use_opsm: reported_loss["opsm_clipfrac"] = opsm_clipfrac + # Ensure gradients propagate even when local sequences are empty + if log_probs.numel() == 0: + loss = loss + 0 * logits.sum() + return loss, reported_loss @@ -701,7 +752,7 @@ def value_loss_function( loss = sum_of_sample_mean(loss) values_clipfrac = sum_of_sample_mean(values_clipfrac.float()) - # make sure the gradient could backprop correctly. + # Ensure gradients propagate even when local sequences are empty if values.numel() == 0: loss += 0 * values.sum() @@ -825,9 +876,7 @@ def loss_function( # Here we need to divide by cp_size because to cancel the multiply in Megatron. global_batch_size = batch.get("dynamic_global_batch_size", args.global_batch_size) if not args.calculate_per_token_loss: - loss = ( - loss * num_microbatches / global_batch_size * mpu.get_data_parallel_world_size(with_context_parallel=True) - ) + loss = loss * num_microbatches / global_batch_size * mpu.get_data_parallel_world_size(with_context_parallel=True) else: loss = loss * mpu.get_context_parallel_world_size() diff --git a/slime/rollout/data_source.py b/slime/rollout/data_source.py index fa14c65f0..56ee5db9c 100644 --- a/slime/rollout/data_source.py +++ b/slime/rollout/data_source.py @@ -188,16 +188,19 @@ def add_samples(self, samples: list[list[Sample]]): """ Add a sample group to buffer. 
""" + samples = [list(sample_group) for sample_group in samples] + if not samples: return - assert isinstance(samples, list), f"samples must be a list, got {type(samples)}" - assert isinstance(samples[0], list), f"the elements of samples must be list, got {type(samples[0])}" - for i in range(0, len(samples)): - assert ( - len(samples[i]) == self.args.n_samples_per_prompt - ), f"the length of the elements of samples must be equal to n_samples_per_prompt, got {len(samples[i])} != {self.args.n_samples_per_prompt}" - group = samples[i] # type: ignore - self.buffer.append(group) + + for i, sample_group in enumerate(samples): + if len(sample_group) != self.args.n_samples_per_prompt: + raise ValueError( + f"Sample group {i} has size {len(sample_group)}; expected {self.args.n_samples_per_prompt}." + ) + if not all(isinstance(sample, Sample) for sample in sample_group): + raise TypeError(f"Sample group {i} contains non-Sample entries: {sample_group}.") + self.buffer.extend(samples) # TODO remove def update_metadata(self, metadata: dict): diff --git a/slime/utils/ppo_utils.py b/slime/utils/ppo_utils.py index 2404883ab..637aea68d 100644 --- a/slime/utils/ppo_utils.py +++ b/slime/utils/ppo_utils.py @@ -2,6 +2,7 @@ # and https://github.com/OpenRLHF/OpenRLHF/blob/10c733694ed9fbb78a0a2ff6a05efc7401584d46/openrlhf/trainer/ppo_utils/experience_maker.py from argparse import Namespace +from dataclasses import dataclass import torch import torch.distributed as dist @@ -51,44 +52,103 @@ def compute_approx_kl( return kl -def compute_opsm_mask( - args: Namespace, +@dataclass +class OpsmInputs: + seq_kls: list[torch.Tensor] + loss_masks: list[torch.Tensor] + effective_loss_masks: list[torch.Tensor] + + +def build_opsm_inputs_from_log_probs( full_log_probs: list[torch.Tensor], full_old_log_probs: list[torch.Tensor], - advantages: list[torch.Tensor], loss_masks: list[torch.Tensor], +) -> OpsmInputs: + seq_kls = [ + ((full_old_log_prob - full_log_prob) * loss_mask).sum() / torch.clamp_min(loss_mask.sum(), 1) + for full_log_prob, full_old_log_prob, loss_mask in zip( + full_log_probs, full_old_log_probs, loss_masks, strict=False + ) + ] + + return OpsmInputs(seq_kls=seq_kls, loss_masks=loss_masks, effective_loss_masks=loss_masks) + + +def build_opsm_inputs_from_seq_kls( + seq_kls: list[torch.Tensor], + loss_masks: list[torch.Tensor], + chunked_loss_masks: list[torch.Tensor], +) -> OpsmInputs: + assert len(seq_kls) == len(loss_masks) == len(chunked_loss_masks), ( + "seq_kls, loss_masks, and chunked_loss_masks must have the same length" + ) + return OpsmInputs(seq_kls=seq_kls, loss_masks=loss_masks, effective_loss_masks=chunked_loss_masks) + + +def compute_opsm_mask( + args: Namespace, + advantages: list[torch.Tensor], + *, + opsm_inputs: OpsmInputs, + cp_group: dist.ProcessGroup | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Compute Off-Policy Sequence Masking (OPSM) mask. Args: args: Configuration containing `opsm_delta` threshold. - full_log_probs: Current policy log-probs per sample. - full_old_log_probs: Old policy log-probs per sample. advantages: Advantage values per sample. - loss_masks: Loss masks per sample. + opsm_inputs: Precomputed sequence-level KLs and aligned loss masks. + cp_group: Optional context-parallel process group for distributed + reductions when context parallelism is enabled. Returns: Tuple of `(opsm_mask, opsm_clipfrac)` where `opsm_mask` is a concatenated tensor of per-token masks and - `opsm_clipfrac` is the count of masked sequences. 
+ `opsm_clipfrac` is the sum of per-sequence masked-token fractions. """ - opsm_mask_list = [] + cp_size = dist.get_world_size(cp_group) if cp_group is not None else 1 + + opsm_mask_list: list[torch.Tensor] = [] device = advantages[0].device - opsm_clipfrac = torch.tensor(0.0, device=device) + opsm_delta = args.opsm_delta - for full_log_prob, full_old_log_prob, advantage, loss_mask in zip( - full_log_probs, full_old_log_probs, advantages, loss_masks, strict=False + masked_token_counts: list[torch.Tensor] = [] + total_token_counts: list[torch.Tensor] = [] + + for advantage, full_loss_mask, effective_loss_mask, seq_kl in zip( + advantages, + opsm_inputs.loss_masks, + opsm_inputs.effective_loss_masks, + opsm_inputs.seq_kls, + strict=False, ): - # Calculate sequence-level KL - seq_kl = ((full_old_log_prob - full_log_prob) * loss_mask).sum() / torch.clamp_min(loss_mask.sum(), 1) + if advantage.numel() != effective_loss_mask.numel(): + raise ValueError( + "OPSM requires `advantages` and `effective_loss_masks` to be aligned. " + f"Got {advantage.numel()} vs {effective_loss_mask.numel()}." + ) + + masked_tokens = ((advantage < 0) & (seq_kl > opsm_delta)).to(dtype=effective_loss_mask.dtype) + masked_tokens = masked_tokens * effective_loss_mask + opsm_mask_list.append(1 - masked_tokens) - # Create mask: 0 if (advantage < 0 and seq_kl > delta), else 1 - mask = ((advantage < 0) & (seq_kl > args.opsm_delta)).float() - opsm_clipfrac += mask.sum() / torch.clamp_min(loss_mask.sum(), 1) + masked_token_counts.append(masked_tokens.sum()) + total_token_counts.append(torch.clamp_min(full_loss_mask.sum().to(masked_tokens), 1)) - opsm_mask_list.append(1 - mask) + opsm_mask = torch.cat(opsm_mask_list, dim=0) if opsm_mask_list else torch.tensor([], device=device) + opsm_clipfrac = torch.tensor(0.0, device=device) + + if masked_token_counts: + masked_tokens = torch.stack(masked_token_counts) + total_tokens = torch.stack(total_token_counts) + + if cp_size > 1: + dist.all_reduce(masked_tokens, group=cp_group) + + opsm_clipfrac = (masked_tokens / total_tokens).sum() + if cp_size > 1: + opsm_clipfrac = opsm_clipfrac / cp_size - opsm_mask = torch.cat(opsm_mask_list, dim=0) return opsm_mask, opsm_clipfrac @@ -409,9 +469,7 @@ def get_advantages_and_returns_batch( full_values_list = [] full_rewards_list = [] - for total_len, resp_len, v, r in zip( - total_lengths, response_lengths, values_list, rewards_list, strict=False - ): + for total_len, resp_len, v, r in zip(total_lengths, response_lengths, values_list, rewards_list, strict=False): full_v = all_gather_with_cp(v, total_len, resp_len) full_r = all_gather_with_cp(r, total_len, resp_len) full_values_list.append(full_v) @@ -455,11 +513,7 @@ def get_advantages_and_returns_batch( from slime.backends.megatron_utils.cp_utils import slice_log_prob_with_cp for total_len, resp_len, adv_row, ret_row in zip( - total_lengths, - response_lengths, - full_advantages, - full_returns, - strict=False, + total_lengths, response_lengths, full_advantages, full_returns, strict=False ): adv_full = adv_row # shape = [resp_len_i padded to max_len] ret_full = ret_row @@ -658,13 +712,13 @@ def calculate_log_probs_and_entropy(logits, tokens, tp_group, with_entropy: bool tokens_chunks = tokens.chunk(num_chunks, dim=0) logits_chunks = logits.chunk(num_chunks, dim=0) log_probs = [] - for tokens_chunk, logits_chunk in zip(tokens_chunks, logits_chunks, strict=True): + for tokens_chunk, logits_chunk in zip(tokens_chunks, logits_chunks, strict=False): log_prob = compute_log_probs(logits_chunk.clone(), 
tokens_chunk, tp_group) log_probs.append(log_prob) log_prob = torch.cat(log_probs, dim=0) if with_entropy: entropys = [] - for _, logits_chunk in zip(tokens_chunks, logits_chunks, strict=True): + for _, logits_chunk in zip(tokens_chunks, logits_chunks, strict=False): entropy = compute_entropy_from_logits(logits_chunk.clone(), tp_group) entropys.append(entropy) entropy = torch.cat(entropys, dim=0)
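
For context on the reworked `compute_opsm_mask`, a simplified single-rank sketch of the masking rule (names and the per-sequence input layout are illustrative): every loss token of a sequence is dropped when its advantage is negative and its sequence-level KL exceeds `opsm_delta`, and the clipped fraction accumulates as the sum of per-sequence masked-token fractions. Under CP, the diff additionally all-reduces the masked-token counts over the CP group before forming the ratio.

import torch

def opsm_mask_sketch(advantages, loss_masks, seq_kls, opsm_delta):
    # advantages / loss_masks: one 1-D tensor per sequence; seq_kls: one scalar per sequence.
    mask_parts = []
    clipfrac = torch.tensor(0.0)
    for advantage, loss_mask, seq_kl in zip(advantages, loss_masks, seq_kls):
        # Drop tokens where the advantage is negative and the sequence KL exceeds the threshold.
        dropped = ((advantage < 0) & (seq_kl > opsm_delta)).to(loss_mask.dtype) * loss_mask
        mask_parts.append(1 - dropped)
        clipfrac = clipfrac + dropped.sum() / torch.clamp_min(loss_mask.sum(), 1)
    return torch.cat(mask_parts, dim=0), clipfrac
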
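
Relatedly, the reducer built by `get_sum_of_sample_mean` stays correct under context parallelism because each term pairs the CP-local chunked mask in the numerator with the full mask in the denominator; summing the resulting partial sample means across CP ranks then reproduces the global per-sample means. A condensed sketch under those assumptions (argument names are illustrative):

import torch

def sum_of_sample_mean_sketch(x, chunk_lengths, chunked_loss_masks, full_loss_masks):
    # x: flat CP-local per-token values; chunk_lengths: local response-token counts per sequence.
    return sum(
        (x_i * chunked_mask).sum() / torch.clamp_min(full_mask.sum(), 1)
        for x_i, chunked_mask, full_mask in zip(
            x.split(chunk_lengths, dim=0), chunked_loss_masks, full_loss_masks
        )
    )

Because the denominator counts all response tokens while the numerator only touches the local slice, the reducer itself needs no extra communication; the existing loss reduction over the CP group completes the mean.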