diff --git a/.github/workflows/docker/docker-compose.yaml b/.github/workflows/docker/docker-compose.yaml index c854771eb6..3ac2071f79 100644 --- a/.github/workflows/docker/docker-compose.yaml +++ b/.github/workflows/docker/docker-compose.yaml @@ -8,6 +8,7 @@ services: - RAY_ADDRESS=auto - TRINITY_CHECKPOINT_ROOT_DIR=/mnt/checkpoints - TRINITY_TASKSET_PATH=/mnt/data + - TRINITY_EVAL_TASKSET_PATH=/mnt/data - TRINITY_SFT_DATASET_PATH=/mnt/data - TRINITY_MODEL_PATH=/mnt/models/Qwen3-0.6B - TRINITY_API_MODEL_PATH=/mnt/models/Qwen3-1.7B diff --git a/examples/rec_gsm8k/README.md b/examples/rec_gsm8k/README.md new file mode 100644 index 0000000000..c0d9376db1 --- /dev/null +++ b/examples/rec_gsm8k/README.md @@ -0,0 +1,221 @@ +# Example: REC on GSM8k dataset + +This example shows the usage of REC on the [GSM8k dataset](https://huggingface.co/datasets/openai/gsm8k). + +For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_reasoning_basic.md). + +The config file is located in [`gsm8k.yaml`](gsm8k.yaml). + +# Group-relative REINFORCE Families +This folder provides **example configurations** for running different group-relative REINFORCE families within Trinity-RFT. + +It includes three major families: + +- **REC family** (clipping + importance sampling) +- **REP family** (regularization-based variants) +- **RED family** (data-distribution shaping strategies) + +We also provide baseline implementations such as **Vanilla REINFORCE** and **GRPO**. + +All algorithms are instantiated through modular YAML configs for easy reproduction and extension. + +# Summary Table 📝 + +| Family | Variants | Key Idea | +| ------------- | ----------------------------------------------- | ----------------------------------- | +| **Baselines** | REINFORCE, GRPO | Standard references | +| **REC** | OneSide-NoIS, OneSide-IS, TwoSide-IS, Ring-NoIS | Clipping + importance sampling | +| **REP** | AsymRE, OPMD | Regularization | +| **RED** | Drop, Weight | Data-distribution shaping | + + + +# Instantiations + +## Baselines + +### REINFORCE +Vanilla REINFORCE with group mean as baseline. + +``` +algorithm: + algorithm_type: rec + policy_loss_fn_args: + epsilon_low: 0.2 + epsilon_high: 0.2 + clip_mode: "none" # no clipping + weight: "none" # uniform weighting for samples + temp: 1.0 + regularizer: "none" # no regularizer + regularizer_coef: 0.0 + advantage_fn_args: + std_normalize: false +``` + +### GRPO +GRPO implemented with zero KL regularizer. Regularization can be enabled via `kl_loss_fn` and `kl_loss_fn_args`. + +``` +algorithm: + algorithm_type: rec + policy_loss_fn_args: + epsilon_low: 0.2 + epsilon_high: 0.2 + clip_mode: "one-side" + weight: "importance_sampling" + temp: 1.0 + regularizer: "none" + regularizer_coef: 0.0 + advantage_fn_args: + std_normalize: true + kl_loss_fn: 'k2' + kl_loss_fn_args: + kl_coef: 0.0 + +``` + +## REC family +Variants of clipping and importance-sampling strategies. 
+- REC-OneSide-NoIS +- REC-OneSide-IS +- REC-TwoSide-IS +- REC-Ring-NoIS + +### REC-OneSide-NoIS + +``` +algorithm: + algorithm_type: rec + policy_loss_fn_args: + epsilon_low: 0.2 + epsilon_high: 0.2 + clip_mode: "one-side" + weight: "none" + temp: 1.0 + regularizer: "none" + regularizer_coef: 0.0 + advantage_fn_args: + std_normalize: false +``` + +### REC-OneSide-IS + +``` +algorithm: + algorithm_type: rec + policy_loss_fn_args: + epsilon_low: 0.2 + epsilon_high: 0.2 + clip_mode: "one-side" + weight: "importance_sampling" + temp: 1.0 + regularizer: "none" + regularizer_coef: 0.0 + advantage_fn_args: + std_normalize: false +``` + +### REC-TwoSide-IS + +``` +algorithm: + algorithm_type: rec + policy_loss_fn_args: + epsilon_low: 0.2 + epsilon_high: 0.2 + clip_mode: "two-side" + weight: "importance_sampling" + temp: 1.0 + regularizer: "none" + regularizer_coef: 0.0 + advantage_fn_args: + std_normalize: false +``` +### REC-Ring-NoIS + +``` +algorithm: + algorithm_type: rec + policy_loss_fn_args: + epsilon_low: 0.2 + epsilon_high: 0.2 + epsilon_low_prime: 0.6 + epsilon_high_prime: 2.0 + clip_mode: "ring" + weight: "none" + temp: 1.0 + regularizer: "none" + regularizer_coef: 0.0 + advantage_fn_args: + std_normalize: false +``` + +## REP family + +Regularization-based algorithms. +- AsymRE (forward KL regularization) +- Kimi’s OPMD (k2 regularizer) + +### AsymRE + +``` +algorithm: + algorithm_type: rec + policy_loss_fn_args: + clip_mode: "none" + weight: "none" + temp: 1.0 + regularizer: "forward-kl" + regularizer_coef: 0.1 + advantage_fn_args: + std_normalize: false +``` + + +### Kimi's OPMD + +``` +algorithm: + algorithm_type: rec + policy_loss_fn_args: + clip_mode: "none" + weight: "none" + regularizer: "k2" + regularizer_coef: 0.1 + advantage_fn_args: + std_normalize: false +``` + +## RED family +Data-distribution shaping variants. +- RED-Drop (drop extra negative examples to balance the positive examples v.s. negative examples) +- RED-Weight (advantage-weighting strategy) + +### RED-Drop + +``` +algorithm: + algorithm_type: rec + policy_loss_fn_args: + clip_mode: "none" + weight: "none" + regularizer: "none" + advantage_fn_args: + std_normalize: false + drop: "balance" +``` + + +### RED-Weight + +``` +algorithm: + algorithm_type: rec + policy_loss_fn_args: + clip_mode: "none" + weight: "advantage" + regularizer: "none" + temp: 1.0 + advantage_fn_args: + std_normalize: false +``` diff --git a/examples/rec_gsm8k/gsm8k.yaml b/examples/rec_gsm8k/gsm8k.yaml new file mode 100644 index 0000000000..98abf6b90e --- /dev/null +++ b/examples/rec_gsm8k/gsm8k.yaml @@ -0,0 +1,85 @@ +# Configuration file for the REC GSM8k project. 
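+# Model, dataset, and checkpoint paths below are resolved from environment
+# variables via OmegaConf-style ${oc.env:VAR,default} interpolation; entries
+# with a default after the comma (e.g. TRINITY_MODEL_PATH) fall back to it when
+# the variable is unset, while TRINITY_TASKSET_PATH and TRINITY_EVAL_TASKSET_PATH
+# must be set (see .github/workflows/docker/docker-compose.yaml) before launching.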
+project: "Trinity-RFT-GSM8K" +name: rec_gsm8k +checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} +mode: both +model: + model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-3B-Instruct} + max_response_tokens: 1024 + max_model_len: 1280 +algorithm: + algorithm_type: rec + policy_loss_fn_args: + epsilon_low: 0.2 + epsilon_high: 0.2 + clip_mode: "none" + weight: "none" + temp: 1.0 + regularizer: "none" + regularizer_coef: 0.0 + advantage_fn_args: + std_normalize: false + repeat_times: 8 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + total_steps: 100 + batch_size: 96 + explorer_input: + taskset: + name: gsm8k + storage_type: file + path: ${oc.env:TRINITY_TASKSET_PATH} + split: train + format: + prompt_key: question + response_key: answer + rollout_args: + temperature: 1.0 + eval_tasksets: + - name: gsm8k-eval + storage_type: file + path: ${oc.env:TRINITY_EVAL_TASKSET_PATH} + split: test + format: + prompt_key: question + response_key: answer + default_workflow_type: math_workflow + trainer_input: + experience_buffer: + name: gsm8k_buffer + storage_type: queue +explorer: + eval_interval: 20 + runner_num: 64 + rollout_model: + engine_type: vllm_async + engine_num: 4 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 +synchronizer: + sync_method: nccl + sync_interval: 20 + sync_timeout: 1200 + sync_offset: 0 +trainer: + trainer_type: verl + save_interval: 100 + trainer_config: + actor_rollout_ref: + model: + use_remove_padding: true + actor: + use_dynamic_bsz: true + ppo_max_token_len_per_gpu: 16384 + ulysses_sequence_parallel_size: 1 + optim: + lr: 1e-6 + ref: + log_prob_use_dynamic_bsz: ${trainer.trainer_config.actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${trainer.trainer_config.actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${trainer.trainer_config.actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size diff --git a/examples/rec_math/README.md b/examples/rec_math/README.md new file mode 100644 index 0000000000..8cc79050b8 --- /dev/null +++ b/examples/rec_math/README.md @@ -0,0 +1,221 @@ +# Example: REC on MATH dataset + +This example shows the usage of REC on the [MATH dataset](https://huggingface.co/datasets/nlile/hendrycks-MATH-benchmark). + +For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_reasoning_basic.md). + +The config file is located in [`math.yaml`](math.yaml). + +# Group-relative REINFORCE Families +This folder provides **example configurations** for running different group-relative REINFORCE families within Trinity-RFT. + +It includes three major families: + +- **REC family** (clipping + importance sampling) +- **REP family** (regularization-based variants) +- **RED family** (data-distribution shaping strategies) + +We also provide baseline implementations such as **Vanilla REINFORCE** and **GRPO**. + +All algorithms are instantiated through modular YAML configs for easy reproduction and extension. 
+ +# Summary Table 📝 + +| Family | Variants | Key Idea | +| ------------- | ----------------------------------------------- | ----------------------------------- | +| **Baselines** | REINFORCE, GRPO | Standard references | +| **REC** | OneSide-NoIS, OneSide-IS, TwoSide-IS, Ring-NoIS | Clipping + importance sampling | +| **REP** | AsymRE, OPMD | Regularization | +| **RED** | Drop, Weight | Data-distribution shaping | + + + +# Instantiations + +## Baselines + +### REINFORCE +Vanilla REINFORCE with group mean as baseline. + +``` +algorithm: + algorithm_type: rec + policy_loss_fn_args: + epsilon_low: 0.2 + epsilon_high: 0.2 + clip_mode: "none" # no clipping + weight: "none" # uniform weighting for samples + temp: 1.0 + regularizer: "none" # no regularizer + regularizer_coef: 0.0 + advantage_fn_args: + std_normalize: false +``` + +### GRPO +GRPO implemented with zero KL regularizer. Regularization can be enabled via `kl_loss_fn` and `kl_loss_fn_args`. + +``` +algorithm: + algorithm_type: rec + policy_loss_fn_args: + epsilon_low: 0.2 + epsilon_high: 0.2 + clip_mode: "one-side" + weight: "importance_sampling" + temp: 1.0 + regularizer: "none" + regularizer_coef: 0.0 + advantage_fn_args: + std_normalize: true + kl_loss_fn: 'k2' + kl_loss_fn_args: + kl_coef: 0.0 + +``` + +## REC family +Variants of clipping and importance-sampling strategies. +- REC-OneSide-NoIS +- REC-OneSide-IS +- REC-TwoSide-IS +- REC-Ring-NoIS + +### REC-OneSide-NoIS + +``` +algorithm: + algorithm_type: rec + policy_loss_fn_args: + epsilon_low: 0.2 + epsilon_high: 0.2 + clip_mode: "one-side" + weight: "none" + temp: 1.0 + regularizer: "none" + regularizer_coef: 0.0 + advantage_fn_args: + std_normalize: false +``` + +### REC-OneSide-IS + +``` +algorithm: + algorithm_type: rec + policy_loss_fn_args: + epsilon_low: 0.2 + epsilon_high: 0.2 + clip_mode: "one-side" + weight: "importance_sampling" + temp: 1.0 + regularizer: "none" + regularizer_coef: 0.0 + advantage_fn_args: + std_normalize: false +``` + +### REC-TwoSide-IS + +``` +algorithm: + algorithm_type: rec + policy_loss_fn_args: + epsilon_low: 0.2 + epsilon_high: 0.2 + clip_mode: "two-side" + weight: "importance_sampling" + temp: 1.0 + regularizer: "none" + regularizer_coef: 0.0 + advantage_fn_args: + std_normalize: false +``` +### REC-Ring-NoIS + +``` +algorithm: + algorithm_type: rec + policy_loss_fn_args: + epsilon_low: 0.2 + epsilon_high: 0.2 + epsilon_low_prime: 0.6 + epsilon_high_prime: 2.0 + clip_mode: "ring" + weight: "none" + temp: 1.0 + regularizer: "none" + regularizer_coef: 0.0 + advantage_fn_args: + std_normalize: false +``` + +## REP family + +Regularization-based algorithms. +- AsymRE (forward KL regularization) +- Kimi’s OPMD (k2 regularizer) + +### AsymRE + +``` +algorithm: + algorithm_type: rec + policy_loss_fn_args: + clip_mode: "none" + weight: "none" + temp: 1.0 + regularizer: "forward-kl" + regularizer_coef: 0.1 + advantage_fn_args: + std_normalize: false +``` + + +### Kimi's OPMD + +``` +algorithm: + algorithm_type: rec + policy_loss_fn_args: + clip_mode: "none" + weight: "none" + regularizer: "k2" + regularizer_coef: 0.1 + advantage_fn_args: + std_normalize: false +``` + +## RED family +Data-distribution shaping variants. +- RED-Drop (drop extra negative examples to balance the positive examples v.s. 
negative examples) +- RED-Weight (advantage-weighting strategy) + +### RED-Drop + +``` +algorithm: + algorithm_type: rec + policy_loss_fn_args: + clip_mode: "none" + weight: "none" + regularizer: "none" + advantage_fn_args: + std_normalize: false + drop: "balance" +``` + + +### RED-Weight + +``` +algorithm: + algorithm_type: rec + policy_loss_fn_args: + clip_mode: "none" + weight: "advantage" + regularizer: "none" + temp: 1.0 + advantage_fn_args: + std_normalize: false +``` diff --git a/examples/rec_math/math.yaml b/examples/rec_math/math.yaml new file mode 100644 index 0000000000..d8cca9328c --- /dev/null +++ b/examples/rec_math/math.yaml @@ -0,0 +1,90 @@ +project: Trinity-RFT-rec_math +name: rec_math +checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} +mode: both +model: + model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-3B-Instruct} + max_response_tokens: 2048 + max_model_len: 2048 +algorithm: + algorithm_type: rec + policy_loss_fn_args: + epsilon_low: 0.2 + epsilon_high: 0.2 + epsilon_high_prime: 0.4 + epsilon_low_prime: 0.4 + clip_mode: none + weight: none + advantage_fn_args: + std_normalize: false + repeat_times: 8 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + total_steps: 200 + batch_size: 16 + explorer_input: + taskset: + name: math + storage_type: file + path: ${oc.env:TRINITY_TASKSET_PATH} + format: + prompt_key: problem + response_key: solution + rollout_args: + temperature: 1.0 + top_p: 1.0 + logprobs: 0 + eval_tasksets: + - name: math + storage_type: file + path: ${oc.env:TRINITY_EVAL_TASKSET_PATH} + split: test + format: + prompt_key: problem + response_key: solution + rollout_args: + temperature: 0.1 + top_p: 0.95 + default_workflow_type: math_boxed_workflow + default_reward_fn_type: math_boxed_reward + trainer_input: + experience_buffer: + name: math_buffer + storage_type: queue +explorer: + eval_interval: 500 + runner_num: 64 + rollout_model: + engine_type: vllm_async + engine_num: 4 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + max_prompt_tokens: 1024 + max_response_tokens: 2048 + seed: 42 +synchronizer: + sync_method: nccl + sync_interval: 1 + sync_timeout: 3600 + sync_offset: 0 +trainer: + trainer_type: verl + save_interval: 100 + trainer_config: + actor_rollout_ref: + model: + use_remove_padding: true + actor: + use_dynamic_bsz: true + ppo_max_token_len_per_gpu: 16384 + ulysses_sequence_parallel_size: 1 + optim: + lr: 6e-8 + ref: + log_prob_use_dynamic_bsz: ${trainer.trainer_config.actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${trainer.trainer_config.actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${trainer.trainer_config.actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size diff --git a/trinity/algorithm/advantage_fn/__init__.py b/trinity/algorithm/advantage_fn/__init__.py index 52e316895a..f8349062f1 100644 --- a/trinity/algorithm/advantage_fn/__init__.py +++ b/trinity/algorithm/advantage_fn/__init__.py @@ -16,6 +16,7 @@ OPMDGroupAdvantage, ) from trinity.algorithm.advantage_fn.ppo_advantage import PPOAdvantageFn +from trinity.algorithm.advantage_fn.rec_advantage import RECGroupedAdvantage from trinity.algorithm.advantage_fn.reinforce_advantage import REINFORCEGroupAdvantage from trinity.algorithm.advantage_fn.reinforce_plus_plus_advantage import ( REINFORCEPLUSPLUSAdvantageFn, @@ -38,4 +39,5 @@ "OPMDGroupAdvantage", "REINFORCEGroupAdvantage", "ASYMREAdvantageFn", + "RECGroupedAdvantage", ] diff --git 
a/trinity/algorithm/advantage_fn/rec_advantage.py b/trinity/algorithm/advantage_fn/rec_advantage.py
new file mode 100644
index 0000000000..140e3975cb
--- /dev/null
+++ b/trinity/algorithm/advantage_fn/rec_advantage.py
@@ -0,0 +1,100 @@
+"""REC advantage computation
+"""
+
+from typing import Dict, List, Optional, Tuple
+
+import torch
+
+from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, GroupAdvantage
+from trinity.common.experience import Experience, group_by
+
+
+@ADVANTAGE_FN.register_module("rec")
+class RECGroupedAdvantage(GroupAdvantage):
+    """An advantage class that calculates REC advantages."""
+
+    def __init__(
+        self,
+        epsilon: float = 1e-6,
+        std_normalize: Optional[bool] = False,
+        drop: Optional[str] = None,
+    ) -> None:
+        """Initialize the REC advantage function.
+
+        Args:
+            epsilon (float): A small value to avoid division by zero.
+            std_normalize (Optional[bool]): If True, normalize the advantage with the group-level reward standard deviation.
+            drop (Optional[str]): Strategy to drop experiences. Options are "balance" or None.
+        """
+        self.epsilon = epsilon
+        self.std_normalize = std_normalize
+        self.drop = drop
+        assert self.drop in [None, "balance"], f"Invalid drop: {self.drop}"
+
+    def group_experiences(self, exps):
+        return group_by(exps, id_type="task")
+
+    def calculate_group_advantage(
+        self, group_id: str, exps: List[Experience]
+    ) -> Tuple[List[Experience], Dict]:
+        # Initialize masks and metrics
+        N = len(exps)
+        metrics = {}
+        with torch.no_grad():
+            # Compute rewards for the whole group up front so that `rewards` is
+            # defined even when the group contains a single experience.
+            rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32)
+            group_reward_mean = torch.mean(rewards)
+
+            # Split experiences into positive/negative relative to the full-group
+            # mean (used by the "balance" drop below); the mean and std used for
+            # the advantages themselves are recomputed after dropping.
+            is_pos = rewards >= group_reward_mean
+            pos_count = is_pos.sum().item()
+            neg_count = len(exps) - pos_count
+
+            drop_idx = torch.tensor([], dtype=torch.long)
+            drop_frac = 0.0
+            if self.drop == "balance" and neg_count > pos_count:
+                extra_neg = neg_count - pos_count
+                neg_idx = (~is_pos).nonzero(as_tuple=True)[0]
+                perm = torch.randperm(len(neg_idx))[:extra_neg]
+                drop_idx = neg_idx[perm]
+                drop_frac = float(extra_neg) / float(max(N, 1))
+            metrics["drop_balance"] = drop_frac
+            keep_mask = torch.ones(N, dtype=torch.bool)
+            if drop_idx.numel() > 0:
+                keep_mask[drop_idx] = False
+
+            if keep_mask.sum().item() <= 1:
+                group_reward_mean = torch.tensor(0.0)
+                group_reward_std = torch.tensor(1.0)  # avoid divide-by-zero
+            else:
+                sel_rewards = rewards[keep_mask]
+                group_reward_mean = sel_rewards.mean()
+                group_reward_std = sel_rewards.std(unbiased=False)
+
+            for i, exp in enumerate(exps):
+                if not keep_mask[i]:
+                    adv = torch.tensor(0.0)
+                else:
+                    if getattr(self, "std_normalize", False):
+                        adv = (rewards[i] - group_reward_mean) / (group_reward_std + self.epsilon)
+                    else:
+                        adv = rewards[i] - group_reward_mean
+
+                exp.advantages = adv * exp.action_mask
+                exp.returns = exp.advantages.clone()
+
+        metrics["reward_mean"] = group_reward_mean.item()
+        metrics["reward_std"] = group_reward_std.item()
+
+        return exps, metrics
+
+    @classmethod
+    def default_args(cls) -> dict:
+        return {
+            "epsilon": 1e-6,
+            "std_normalize": False,
+            "drop": None,
+        }
diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py
index b17ba7abec..f9ae14e8c9 100644
--- a/trinity/algorithm/algorithm.py
+++ b/trinity/algorithm/algorithm.py
@@ -338,3 +338,26 @@ def default_config(cls) -> Dict:
             "kl_loss_fn": "none",
"entropy_loss_fn": "none", } + + +@ALGORITHM_TYPE.register_module("rec") +class RECAlgorithm(AlgorithmType): + """REC Algorithm.""" + + use_critic: bool = False + use_reference: bool = True + compute_advantage_in_trainer: bool = False + can_balance_batch: bool = True + schema: str = "experience" + + @classmethod + def default_config(cls) -> Dict: + return { + "repeat_times": 2, + "sample_strategy": "warmup", + "policy_loss_fn": "rec", + "advantage_fn": "rec", + "kl_penalty_fn": "none", + "kl_loss_fn": "none", + "entropy_loss_fn": "none", + } diff --git a/trinity/algorithm/policy_loss_fn/__init__.py b/trinity/algorithm/policy_loss_fn/__init__.py index f66d7dbde4..4f7b70b917 100644 --- a/trinity/algorithm/policy_loss_fn/__init__.py +++ b/trinity/algorithm/policy_loss_fn/__init__.py @@ -10,6 +10,7 @@ from trinity.algorithm.policy_loss_fn.opmd_policy_loss import OPMDPolicyLossFn from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn +from trinity.algorithm.policy_loss_fn.rec_policy_loss import RECPolicyLossFn from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn from trinity.algorithm.policy_loss_fn.sppo_loss_fn import sPPOPolicyLossFn from trinity.algorithm.policy_loss_fn.topr_policy_loss import TOPRPolicyLossFn @@ -29,4 +30,5 @@ "SFTISLossFn", "SFTPhiLossFn", "sPPOPolicyLossFn", + "RECPolicyLossFn", ] diff --git a/trinity/algorithm/policy_loss_fn/rec_policy_loss.py b/trinity/algorithm/policy_loss_fn/rec_policy_loss.py new file mode 100644 index 0000000000..ec0e623c9f --- /dev/null +++ b/trinity/algorithm/policy_loss_fn/rec_policy_loss.py @@ -0,0 +1,132 @@ +"""REC-token policy loss function. +""" + +from typing import Dict, Tuple + +import torch + +from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.utils import masked_mean + + +@POLICY_LOSS_FN.register_module("rec") +class RECPolicyLossFn(PolicyLossFn): + def __init__( + self, + backend: str = "verl", + epsilon_low: float = 0.2, + epsilon_high: float = 0.2, + epsilon_low_prime: float = 0.4, + epsilon_high_prime: float = 0.4, + clip_mode: str = "none", + weight: str = "none", + regularizer: str = "none", + regularizer_coef: float = 0.0, + temp: float = 1.0, + ) -> None: + super().__init__(backend=backend) + + self.epsilon_low = epsilon_low + self.epsilon_high = epsilon_high + assert 0.0 < self.epsilon_low <= 1.0, f"Invalid epsilon_low: {self.epsilon_low}" + assert 0.0 < self.epsilon_high, f"Invalid epsilon_high: {self.epsilon_high}" + self.epsilon_low_prime = epsilon_low_prime + self.epsilon_high_prime = epsilon_high_prime + assert ( + 0.0 < self.epsilon_low_prime <= 1.0 + ), f"Invalid epsilon_low_prime: {self.epsilon_low_prime}" + assert ( + 0.0 < self.epsilon_high_prime + ), f"Invalid epsilon_high_prime: {self.epsilon_high_prime}" + self.clip_mode = clip_mode + assert self.clip_mode in [ + "none", + "one-side", + "two-side", + "ring", + ], f"Invalid clip_mode: {self.clip_mode}" + self.weight = weight + assert self.weight in [ + "none", + "importance_sampling", + "advantage", + ], f"Invalid weight: {self.weight}" + + self.regularizer = regularizer + assert self.regularizer in [ + "none", + "k2", + "forward-kl", + ], f"Invalid regularizer: {self.regularizer}" + self.regularizer_coef = regularizer_coef + assert self.regularizer_coef >= 0.0, f"Invalid regularizer_coef: {self.regularizer_coef}" + self.temp = temp + assert self.temp > 0.0, f"Invalid temp: {self.temp}" 
+ + def __call__( # type: ignore + self, + logprob: torch.Tensor, # [batch_size, seq_len] + old_logprob: torch.Tensor, # [batch_size, seq_len] + action_mask: torch.Tensor, # [batch_size, seq_len] + advantages: torch.Tensor, # [batch_size, seq_len] + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + """Calculate REC loss.""" + # token-wise + ratio = torch.exp(logprob - old_logprob).detach() + + # clipping + if self.clip_mode == "two-side": + is_in_range = (ratio >= (1 - self.epsilon_low)) * (ratio <= (1 + self.epsilon_high)) + elif self.clip_mode == "one-side": + is_in_range = (ratio <= (1 + self.epsilon_high)) * (advantages >= 0) + ( + advantages <= 0 + ) * (ratio >= (1 - self.epsilon_low)) + elif self.clip_mode == "ring": + is_in_range = ( + (ratio >= (1 - self.epsilon_low)) * (ratio <= (1 + self.epsilon_high)) + + (advantages >= 0) * (ratio <= 1 - self.epsilon_low_prime) + + (advantages <= 0) * (ratio >= 1 + self.epsilon_high_prime) + ) + else: # none + is_in_range = torch.ones_like(ratio).bool() + is_clipped_mask = ~is_in_range + + if self.weight == "importance_sampling": + advantages = advantages * ratio # importance sampling + elif self.weight == "advantage": + weight = torch.exp(advantages / self.temp) + advantages = advantages * weight # advantage weighting (unnormalized version) + + pg_losses = -advantages * logprob * is_in_range.float() + + if self.regularizer == "forward-kl": + regularizer_losses = self.regularizer_coef * logprob + pg_losses = pg_losses - regularizer_losses + elif self.regularizer == "k2": + # note that here we absorb the 1/2 in Kimi into \tau + regularizer_losses = self.regularizer_coef * (logprob - old_logprob).square() + pg_losses = pg_losses + regularizer_losses + + pg_loss = masked_mean(pg_losses, action_mask) + + pg_clipfrac = masked_mean(is_clipped_mask.float(), action_mask) + metrics = { + "pg_clipfrac": pg_clipfrac.item(), + "pg_loss": pg_loss.detach().item(), + } + return pg_loss, metrics + + @classmethod + def default_args(cls) -> Dict: + return { + "epsilon_low": 0.2, + "epsilon_high": 0.2, + "epsilon_low_prime": 0.6, + "epsilon_high_prime": 2, + "clip_mode": "none", + "weight": "none", + "regularizer": "none", + "regularizer_coef": 0.0, + "temp": 1.0, + }
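As a quick sanity check of the loss defined above, the following is a minimal, self-contained sketch (plain PyTorch, no Trinity-RFT imports) of the `clip_mode="one-side"`, `weight="importance_sampling"` branch of `RECPolicyLossFn` applied to dummy tensors. The local `masked_mean` helper and the driver at the bottom are illustrative assumptions standing in for `trinity.algorithm.utils.masked_mean` and the trainer, not part of this PR.

```
import torch


def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Mean over positions where mask == 1 (stand-in for trinity.algorithm.utils.masked_mean).
    return (values * mask).sum() / mask.sum().clamp(min=1.0)


def rec_one_side_is_loss(
    logprob: torch.Tensor,      # [batch, seq_len] log-probs under the current policy
    old_logprob: torch.Tensor,  # [batch, seq_len] log-probs under the rollout policy
    action_mask: torch.Tensor,  # [batch, seq_len] 1 for response tokens, 0 for padding
    advantages: torch.Tensor,   # [batch, seq_len] token-level advantages
    epsilon_low: float = 0.2,
    epsilon_high: float = 0.2,
) -> torch.Tensor:
    # Detached ratio: used only for clipping and weighting, never for gradients.
    ratio = torch.exp(logprob - old_logprob).detach()
    # One-sided clipping: positive advantages keep ratio <= 1 + epsilon_high,
    # negative advantages keep ratio >= 1 - epsilon_low.
    in_range = ((ratio <= 1 + epsilon_high) & (advantages >= 0)) | (
        (advantages <= 0) & (ratio >= 1 - epsilon_low)
    )
    # Importance-sampling weight, then a REINFORCE-style surrogate on the kept tokens.
    pg_losses = -(advantages * ratio) * logprob * in_range.float()
    return masked_mean(pg_losses, action_mask)


if __name__ == "__main__":
    torch.manual_seed(0)
    batch, seq_len = 2, 5
    logprob = torch.randn(batch, seq_len, requires_grad=True)
    old_logprob = logprob.detach() + 0.1 * torch.randn(batch, seq_len)
    action_mask = torch.ones(batch, seq_len)
    advantages = torch.randn(batch, seq_len)
    loss = rec_one_side_is_loss(logprob, old_logprob, action_mask, advantages)
    loss.backward()
    print("loss:", loss.item())
```

Dropping the ratio factor (i.e. `weight: "none"`) recovers the REC-OneSide-NoIS variant from the example READMEs; the full implementation in `rec_policy_loss.py` additionally supports the two-side and ring clip modes and the k2 / forward-KL regularizers.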