diff --git a/README.md b/README.md index 81497586da1..016ec2ff1eb 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ Explore how to seamlessly integrate TRL with OpenEnv in our [dedicated documenta ## Overview -TRL is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO). Built on top of the [🤗 Transformers](https://github.com/huggingface/transformers) ecosystem, TRL supports a variety of model architectures and modalities, and can be scaled-up across various hardware setups. +TRL is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), and Direct Preference Optimization (DPO). Built on top of the [🤗 Transformers](https://github.com/huggingface/transformers) ecosystem, TRL supports a variety of model architectures and modalities, and can be scaled-up across various hardware setups. ## Highlights diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 6fc8e2ea783..a4ca28675bc 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -66,8 +66,6 @@ title: KTO - local: orpo_trainer title: ORPO - - local: ppo_trainer - title: PPO - local: prm_trainer title: PRM - local: reward_trainer @@ -119,6 +117,8 @@ title: Nash-MD - local: papo_trainer title: PAPO + - local: ppo_trainer + title: PPO - local: xpo_trainer title: XPO - local: openenv diff --git a/docs/source/dataset_formats.md b/docs/source/dataset_formats.md index 3444742987f..8faf2f11fc7 100644 --- a/docs/source/dataset_formats.md +++ b/docs/source/dataset_formats.md @@ -396,7 +396,7 @@ Choosing the right dataset type depends on the task you are working on and the s | [`experimental.nash_md.NashMDTrainer`] | [Prompt-only](#prompt-only) | | [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) | | [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) | -| [`PPOTrainer`] | Tokenized language modeling | +| [`experimental.ppo.PPOTrainer`] | Tokenized language modeling | | [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) | | [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) | | [`RLOOTrainer`] | [Prompt-only](#prompt-only) | diff --git a/docs/source/example_overview.md b/docs/source/example_overview.md index 2e1ea944187..8d4ee5b2400 100644 --- a/docs/source/example_overview.md +++ b/docs/source/example_overview.md @@ -37,7 +37,7 @@ These notebooks are easier to run and are designed for quick experimentation wit Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl/blob/main/trl/scripts) and [`examples/scripts`](https://github.com/huggingface/trl/blob/main/examples/scripts) directories. They show how to use different trainers such as `SFTTrainer`, `PPOTrainer`, `DPOTrainer`, `GRPOTrainer`, and more. - File | Description | +| File | Description | | --- | --- | | [`examples/scripts/bco.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/bco.py) | This script shows how to use the [`KTOTrainer`] with the BCO loss to fine-tune a model to increase instruction-following, truthfulness, honesty, and helpfulness using the [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. 
| | [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`experimental.cpo.CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. | @@ -55,8 +55,8 @@ Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl | [`examples/scripts/online_dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a model. | | [`examples/scripts/online_dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo_vlm.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a Vision Language Model. | | [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. | -| [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language. | -| [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. | +| [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`experimental.ppo.PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language. | +| [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`experimental.ppo.PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. | | [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py) | This script shows how to use the [`PRMTrainer`] to fine-tune a Process-supervised Reward Model (PRM). | | [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train an Outcome Reward Model (ORM) on your own dataset. | | [`examples/scripts/rloo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo.py) | This script shows how to use the [`RLOOTrainer`] to fine-tune a model to improve its ability to solve math questions. 
| diff --git a/docs/source/index.md b/docs/source/index.md index 1ac0d5d7321..95f964b671a 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -25,8 +25,8 @@ Below is the current list of TRL trainers, organized by method type (⚡️ = vL - [`GRPOTrainer`] ⚡️ - [`RLOOTrainer`] ⚡️ - [`OnlineDPOTrainer`] ⚡️ -- [`PPOTrainer`] - [`experimental.nash_md.NashMDTrainer`] 🧪 ⚡️ +- [`experimental.ppo.PPOTrainer`] 🧪 - [`experimental.xpo.XPOTrainer`] 🧪 ⚡️ ### Reward modeling diff --git a/docs/source/peft_integration.md b/docs/source/peft_integration.md index bd196dd99bf..221d9b7071b 100644 --- a/docs/source/peft_integration.md +++ b/docs/source/peft_integration.md @@ -146,7 +146,8 @@ After training your reward adapter and pushing it to the Hub: ```python from peft import LoraConfig -from trl import AutoModelForCausalLMWithValueHead, PPOTrainer +from trl import AutoModelForCausalLMWithValueHead +from trl.experimental.ppo import PPOTrainer model_name = "huggyllama/llama-7b" rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter" diff --git a/docs/source/ppo_trainer.md b/docs/source/ppo_trainer.md index 1dabbc4177c..3f7ea2ee73f 100644 --- a/docs/source/ppo_trainer.md +++ b/docs/source/ppo_trainer.md @@ -1,5 +1,11 @@ # PPO Trainer + + +**Deprecation Notice**: PPOTrainer and PPOConfig have been moved to `trl.experimental.ppo` and will be removed from `trl.trainer` in TRL 0.29.0. Please update your imports to use `from trl.experimental.ppo import PPOConfig, PPOTrainer` instead. See [issue #4466](https://github.com/huggingface/trl/issues/4466) for more information. + + + [![model badge](https://img.shields.io/badge/All_models-PPO-blue)](https://huggingface.co/models?other=ppo,trl) TRL supports training LLMs with [Proximal Policy Optimization (PPO)](https://huggingface.co/papers/1707.06347). 
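To make the import migration concrete, a minimal sketch based on the deprecation notice above; the `output_dir` value is illustrative and not part of this diff:

```python
# Deprecated path, scheduled for removal from `trl.trainer` in TRL 0.29.0:
# from trl import PPOConfig, PPOTrainer

# New path introduced by this PR:
from trl.experimental.ppo import PPOConfig, PPOTrainer

# The classes themselves are unchanged; only the module path moves.
training_args = PPOConfig(output_dir="ppo-example")  # illustrative output directory
```

The compatibility shims added to `trl/trainer/ppo_config.py` and `trl/trainer/ppo_trainer.py` later in this diff keep the old import location working until TRL 0.29 while emitting a warning that points to the new module.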
@@ -228,11 +234,11 @@ python -m openrlbenchmark.rlops_multi_metrics \ ## PPOTrainer -[[autodoc]] PPOTrainer +[[autodoc]] experimental.ppo.PPOTrainer - train - save_model - push_to_hub ## PPOConfig -[[autodoc]] PPOConfig +[[autodoc]] experimental.ppo.PPOConfig diff --git a/docs/source/reducing_memory_usage.md b/docs/source/reducing_memory_usage.md index 4cc8626ccf2..f92ebb29edb 100644 --- a/docs/source/reducing_memory_usage.md +++ b/docs/source/reducing_memory_usage.md @@ -274,7 +274,7 @@ training_args = OnlineDPOConfig(..., ds3_gather_for_generation=False) ```python -from trl import PPOConfig +from trl.experimental.ppo import PPOConfig training_args = PPOConfig(..., ds3_gather_for_generation=False) ``` diff --git a/examples/scripts/ppo/ppo.py b/examples/scripts/ppo/ppo.py index 2f5471996c2..b77f30ad457 100644 --- a/examples/scripts/ppo/ppo.py +++ b/examples/scripts/ppo/ppo.py @@ -34,15 +34,8 @@ HfArgumentParser, ) -from trl import ( - ModelConfig, - PPOConfig, - PPOTrainer, - ScriptArguments, - get_kbit_device_map, - get_peft_config, - get_quantization_config, -) +from trl import ModelConfig, ScriptArguments, get_kbit_device_map, get_peft_config, get_quantization_config +from trl.experimental.ppo import PPOConfig, PPOTrainer # Enable logging in a Hugging Face Space diff --git a/examples/scripts/ppo/ppo_tldr.py b/examples/scripts/ppo/ppo_tldr.py index 7962758ec40..bf4f487823b 100644 --- a/examples/scripts/ppo/ppo_tldr.py +++ b/examples/scripts/ppo/ppo_tldr.py @@ -34,15 +34,8 @@ HfArgumentParser, ) -from trl import ( - ModelConfig, - PPOConfig, - PPOTrainer, - ScriptArguments, - get_kbit_device_map, - get_peft_config, - get_quantization_config, -) +from trl import ModelConfig, ScriptArguments, get_kbit_device_map, get_peft_config, get_quantization_config +from trl.experimental.ppo import PPOConfig, PPOTrainer # Enable logging in a Hugging Face Space diff --git a/tests/test_ppo_trainer.py b/tests/experimental/test_ppo_trainer.py similarity index 97% rename from tests/test_ppo_trainer.py rename to tests/experimental/test_ppo_trainer.py index 78531316440..979d80e518b 100644 --- a/tests/test_ppo_trainer.py +++ b/tests/experimental/test_ppo_trainer.py @@ -17,10 +17,10 @@ from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer from transformers.utils import is_peft_available -from trl import PPOConfig, PPOTrainer -from trl.trainer.ppo_trainer import masked_mean, masked_var, masked_whiten +from trl.experimental.ppo import PPOConfig, PPOTrainer +from trl.experimental.ppo.ppo_trainer import masked_mean, masked_var, masked_whiten -from .testing_utils import TrlTestCase, require_peft +from ..testing_utils import TrlTestCase, require_peft if is_peft_available(): diff --git a/trl/experimental/ppo/__init__.py b/trl/experimental/ppo/__init__.py new file mode 100644 index 00000000000..6a58ea42975 --- /dev/null +++ b/trl/experimental/ppo/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from .ppo_config import PPOConfig +from .ppo_trainer import PPOTrainer + + +__all__ = ["PPOConfig", "PPOTrainer"] diff --git a/trl/experimental/ppo/ppo_config.py b/trl/experimental/ppo/ppo_config.py new file mode 100644 index 00000000000..0d24617cff5 --- /dev/null +++ b/trl/experimental/ppo/ppo_config.py @@ -0,0 +1,135 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from dataclasses import dataclass, field +from typing import Literal + +from ...trainer.utils import OnPolicyConfig + + +@dataclass +class PPOConfig(OnPolicyConfig): + r""" + Configuration class for the [`experimental.ppo.PPOTrainer`]. + + This class includes only the parameters that are specific to PPO training. For a full list of training arguments, + please refer to the [`~transformers.TrainingArguments`] and [`OnPolicyConfig`] documentation. Note that default + values in this class may differ from those in [`~transformers.TrainingArguments`]. + + Using [`~transformers.HfArgumentParser`] we can turn this class into + [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the + command line. + + Parameters: + exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`): + Name of this experiment. + reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`): + Path to the reward model. + model_adapter_name (`str`, *optional*): + Name of the train target PEFT adapter, when using LoRA with multiple adapters. + ref_adapter_name (`str`, *optional*): + Name of the reference PEFT adapter, when using LoRA with multiple adapters. + num_ppo_epochs (`int`, *optional*, defaults to `4`): + Number of epochs to train. + whiten_rewards (`bool`, *optional*, defaults to `False`): + Whether to whiten the rewards. + kl_coef (`float`, *optional*, defaults to `0.05`): + KL coefficient. + kl_estimator (`Literal["k1", "k3"]`, *optional*, defaults to `"k1"`): + Which estimator for KL-Divergence to use from [Approximating KL + Divergence](http://joschu.net/blog/kl-approx.html). Defaults to "k1", a straightforward, unbiased + estimator. Can be set to "k3", an unbiased estimator with lower variance which "appears to be a strictly + better estimator". Cannot be set to "k2", as it is used for logging purposes. + cliprange (`float`, *optional*, defaults to `0.2`): + Clip range. + vf_coef (`float`, *optional*, defaults to `0.1`): + Value function coefficient. + cliprange_value (`float`, *optional*, defaults to `0.2`): + Clip range for the value function. + gamma (`float`, *optional*, defaults to `1.0`): + Discount factor. + lam (`float`, *optional*, defaults to `0.95`): + Lambda value for GAE. + ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): + This setting applies to DeepSpeed ZeRO-3. 
If enabled, the policy model weights are gathered for generation, + improving generation speed. However, disabling this option allows training models that exceed the VRAM + capacity of a single GPU, albeit at the cost of slower generation. + """ + + exp_name: str = field( + default=os.path.basename(__file__)[:-3], + metadata={"help": "Name of this experiment."}, + ) + reward_model_path: str = field( + default="EleutherAI/pythia-160m", + metadata={"help": "Path to the reward model."}, + ) + model_adapter_name: str | None = field( + default=None, + metadata={"help": "Name of the train target PEFT adapter, when using LoRA with multiple adapters."}, + ) + ref_adapter_name: str | None = field( + default=None, + metadata={"help": "Name of the reference PEFT adapter, when using LoRA with multiple adapters."}, + ) + num_ppo_epochs: int = field( + default=4, + metadata={"help": "Number of epochs to train."}, + ) + whiten_rewards: bool = field( + default=False, + metadata={"help": "Whether to whiten the rewards."}, + ) + kl_coef: float = field( + default=0.05, + metadata={"help": "KL coefficient."}, + ) + kl_estimator: Literal["k1", "k3"] = field( + default="k1", + metadata={ + "help": "Which estimator for KL-Divergence to use from Approximating KL Divergence " + "(http://joschu.net/blog/kl-approx.html). Defaults to 'k1', a straightforward, unbiased estimator. Can be " + "set to 'k3', an unbiased estimator with lower variance which 'appears to be a strictly better " + "estimator'. Cannot be set to 'k2', as it is used for logging purposes." + }, + ) + cliprange: float = field( + default=0.2, + metadata={"help": "Clip range."}, + ) + vf_coef: float = field( + default=0.1, + metadata={"help": "Value function coefficient."}, + ) + cliprange_value: float = field( + default=0.2, + metadata={"help": "Clip range for the value function."}, + ) + gamma: float = field( + default=1.0, + metadata={"help": "Discount factor."}, + ) + lam: float = field( + default=0.95, + metadata={"help": "Lambda value for GAE."}, + ) + ds3_gather_for_generation: bool = field( + default=True, + metadata={ + "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for " + "generation, improving generation speed. However, disabling this option allows training models that " + "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation." + }, + ) diff --git a/trl/experimental/ppo/ppo_trainer.py b/trl/experimental/ppo/ppo_trainer.py new file mode 100644 index 00000000000..b11a245582f --- /dev/null +++ b/trl/experimental/ppo/ppo_trainer.py @@ -0,0 +1,836 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
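+# See the License for the specific language governing permissions and +# limitations under the License.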
+ +import gc +import math +import os +import textwrap +import time +from collections import defaultdict +from contextlib import contextmanager, nullcontext +from pathlib import Path + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +from accelerate import Accelerator, logging +from accelerate.utils import broadcast, gather_object +from datasets import Dataset +from torch.utils.data import DataLoader +from transformers import ( + BaseImageProcessor, + DataCollatorWithPadding, + FeatureExtractionMixin, + GenerationConfig, + PreTrainedTokenizerBase, + ProcessorMixin, + TrainerCallback, + TrainerControl, +) +from transformers.integrations import get_reporting_integration_callbacks +from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK +from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback +from transformers.utils import is_peft_available, is_rich_available + +from ...models import create_reference_model +from ...models.utils import unwrap_model_for_generation +from ...trainer.base_trainer import BaseTrainer +from ...trainer.utils import ( + OnlineTrainerState, + batch_generation, + disable_dropout_in_model, + empty_cache, + exact_div, + first_true_indices, + forward, + get_reward, + log_table_to_comet_experiment, + peft_module_casting_to_bf16, + prepare_deepspeed, + print_rich_table, + selective_log_softmax, + truncate_response, +) +from .ppo_config import PPOConfig + + +logger = logging.get_logger(__name__) + +if is_peft_available(): + from peft import PeftConfig, PeftModel, get_peft_model + + +INVALID_LOGPROB = 1.0 + + +def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: bool | None = None) -> torch.Tensor: + """Compute mean of tensor with a masked values.""" + if axis is not None: + return (values * mask).sum(axis=axis) / mask.sum(axis=axis) + else: + return (values * mask).sum() / mask.sum() + + +def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor: + """Compute variance of tensor with masked values.""" + mean = masked_mean(values, mask) + centered_values = values - mean + variance = masked_mean(centered_values**2, mask) + if unbiased: + mask_sum = mask.sum() + if mask_sum == 0: + raise ValueError( + "The sum of the mask is zero, which can happen when `mini_batch_size=1`;" + "try increase the `mini_batch_size` or `gradient_accumulation_steps`" + ) + # note that if mask_sum == 1, then there is a division by zero issue + # to avoid it you just need to use a larger minibatch_size + bessel_correction = mask_sum / (mask_sum - 1) + variance = variance * bessel_correction + return variance + + +def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor: + """Whiten values with masked values.""" + mean, var = masked_mean(values, mask), masked_var(values, mask) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +# taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29 +# we did this we can do a single `model = accelerator.prepare(model)` +class PolicyAndValueWrapper(nn.Module): + def __init__(self, policy, value_model) -> None: + super().__init__() + self.policy = policy + self.value_model = value_model + self.critic_backbone = getattr(value_model, value_model.base_model_prefix) + self.is_gradient_checkpointing = policy.is_gradient_checkpointing + + def forward(self, **kwargs): + 
output = self.critic_backbone(**kwargs) + logits = self.value_model.score(output.hidden_states[-1]) + return self.policy(**kwargs), logits + + +class PPOTrainer(BaseTrainer): + """Trainer for Proximal Policy Optimization (PPO). + + For details on PPO, see the paper: [Proximal Policy Optimization + Algorithms](https://huggingface.co/papers/1707.06347). + + Args: + args ([`experimental.ppo.PPOConfig`]): + Training arguments. + processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`]): + Class to process the data. + model (`torch.nn.Module`): + Model to be trained. This is the policy model. + ref_model (`torch.nn.Module`, *optional*): + Reference model used to compute the KL divergence. If `None`, a copy of the policy model is created. + reward_model (`torch.nn.Module`): + Reward model used to compute the rewards. + train_dataset ([`~datasets.Dataset`]): + Dataset for training. + value_model (`torch.nn.Module`): + Value model used to predict the value of a state. + data_collator ([`~transformers.DataCollatorWithPadding`], *optional*): + Data collator to batch and pad samples from the dataset. If `None`, a default data collator is created + using the `processing_class`. + eval_dataset ([`~datasets.Dataset`] or `dict` of [`~datasets.Dataset`], *optional*): + Dataset for evaluation. + optimizers (`tuple` of `torch.optim.Optimizer` and `torch.optim.lr_scheduler.LambdaLR`, *optional*, defaults to `(None, None)`): + Tuple containing the optimizer and the learning rate scheduler to use for training. If `None`, the + optimizer and the learning rate scheduler are created using the + [`~transformers.Trainer.create_optimizer_and_scheduler`] method. + callbacks (`list` of [`~transformers.TrainerCallback`], *optional*): + Callbacks to use during training. + peft_config ([`~peft.PeftConfig`], *optional*): + PEFT configuration to use PEFT for training. If `None`, PEFT is not used. If provided, the policy `model` + will be wrapped with the specified PEFT adapter. + """ + + _tag_names = ["trl", "ppo"] + _name = "PPO" + _paper = { + "title": "Fine-Tuning Language Models from Human Preferences", + "id": "1909.08593", + # docstyle-ignore + "citation": textwrap.dedent("""\ + @article{mziegler2019fine-tuning, + title = {{Fine-Tuning Language Models from Human Preferences}}, + author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving}, + year = 2019, + eprint = {arXiv:1909.08593} + }"""), + } + + def __init__( + self, + args: PPOConfig, + processing_class: PreTrainedTokenizerBase | BaseImageProcessor | FeatureExtractionMixin | ProcessorMixin, + model: nn.Module, + ref_model: nn.Module | None, + reward_model: nn.Module, + train_dataset: Dataset, + value_model: nn.Module, + data_collator: DataCollatorWithPadding | None = None, + eval_dataset: Dataset | dict[str, Dataset] | None = None, + # less commonly used + optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + callbacks: list[TrainerCallback] | None = None, + peft_config: "PeftConfig | None" = None, + ) -> None: + if ref_model is model: + raise ValueError( + "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " + "same as `model`, you must make a copy of it, or `None` if you use peft." 
+ ) + + self.args = args + self.processing_class = processing_class + self.policy_model = model + + # Define the collator if not provided + if data_collator is None: + data_collator = DataCollatorWithPadding(self.processing_class) + + # Handle stop token settings: update policy model's generation_config to use provided stop token + if args.stop_token and args.stop_token_id: + raise ValueError("You cannot set both `stop_token` and `stop_token_id`.") + elif args.stop_token: + if args.stop_token == "eos": + self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id + else: + raise ValueError( + f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)." + ) + else: + self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id # None or int + + # Check that the kl estimator is valid + if self.args.kl_estimator not in {"k1", "k3"}: + raise ValueError( + "kl_estimator must be either 'k1' (straightforward, unbiased) or 'k3' (lower variance, unbiased, " + "appears to be a strictly better estimator). See " + "[Approximating KL Divergence](http://joschu.net/blog/kl-approx.html) for details." + ) + + # peft support + if not is_peft_available() and peft_config is not None: + raise ImportError( + "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" + ) + elif is_peft_available() and peft_config is not None: + # if model is a peft model and we have a peft_confg, we merge and unload it first + if isinstance(self.policy_model, PeftModel): + self.policy_model = self.policy_model.merge_and_unload() + + # get peft model with the given config + self.policy_model = get_peft_model(self.policy_model, peft_config) + if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False): + peft_module_casting_to_bf16(self.policy_model) + + self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel) + self.model_adapter_name = args.model_adapter_name + self.ref_adapter_name = args.ref_adapter_name + + if ref_model: + self.ref_model = ref_model + elif self.is_peft_model: + self.ref_model = None + else: + self.ref_model = create_reference_model(self.policy_model) + + self.reward_model = reward_model + self.train_dataset = train_dataset + self.train_dataset_len = len(train_dataset) + self.value_model = value_model + self.data_collator = data_collator + self.eval_dataset = eval_dataset + self.optimizer, self.lr_scheduler = optimizers + self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47 + + ######### + # calculate various batch sizes + ######### + if args.total_episodes is None: # allow the users to define episodes in terms of epochs. 
+ args.total_episodes = int(args.num_train_epochs * self.train_dataset_len) + accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) + self.accelerator = accelerator + args.world_size = accelerator.num_processes + args.local_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps + args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size) + args.batch_size = int(args.local_batch_size * args.world_size) + args.mini_batch_size = exact_div( + args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`" + ) + args.local_mini_batch_size = exact_div( + args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`" + ) + if args.whiten_rewards: + assert args.local_mini_batch_size >= 8, ( + f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" + ) + # `per_rank_rollout_batch_size` is our `args.local_batch_size` + # `per_rank_minibatch_size` is our `args.local_mini_batch_size` + args.num_total_batches = math.ceil( + args.total_episodes / args.batch_size + ) # we may train for more than `total_episodes` + time_tensor = torch.tensor(int(time.time()), device=accelerator.device) + time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes + args.run_name = f"{args.exp_name}__{args.seed}__{time_int}" + self.local_seed = args.seed + accelerator.process_index * 100003 # Prime + if args.num_sample_generations > 0: + self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations) + self.local_dataloader_batch_size = args.local_batch_size + + ######### + # setup model, optimizer, and others + ######### + for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]: + if module is not None: + disable_dropout_in_model(module) + self.model = PolicyAndValueWrapper(self.policy_model, self.value_model) + self.model.config = self.policy_model.config # needed for pushing to hub + self.create_optimizer_and_scheduler( + num_training_steps=args.num_total_batches + ) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level + + ######### + # trainer specifics + ######### + default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) + self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks + self.callback_handler = CallbackHandler( + self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler + ) + self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) + self.control = TrainerControl() + self.state = OnlineTrainerState( + is_local_process_zero=self.is_local_process_zero(), + is_world_process_zero=self.is_world_process_zero(), + stateful_callbacks=[ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ], + ) + self.current_flos = 0 + self.hp_search_backend = None + self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None + self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None + # Create distant repo and output directory if needed + self.hub_model_id = None + if self.args.push_to_hub: + self.init_hf_repo() + if self.args.should_save: + os.makedirs(self.args.output_dir, exist_ok=True) + + # Add tags for models that have been loaded with the correct transformers version + 
if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + + ######### + # setup dataloader + ######### + self.dataloader = DataLoader( + self.train_dataset, + batch_size=self.local_dataloader_batch_size, + shuffle=True, + collate_fn=self.data_collator, + drop_last=True, # needed; otherwise the last batch will be of ragged shape + ) + # sync random states for DataLoader(shuffle=True) before `accelerator.prepare` + # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c + torch.manual_seed(args.seed) + self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader) + torch.manual_seed(self.local_seed) # reset the local seed again + + self.eval_dataloader = DataLoader( + self.eval_dataset, + batch_size=args.per_device_eval_batch_size, + collate_fn=self.data_collator, + drop_last=True, + ) # no need to shuffle eval dataset + self.eval_dataloader = accelerator.prepare(self.eval_dataloader) + + if self.is_deepspeed_enabled: + self.reward_model = prepare_deepspeed( + self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16 + ) + + if self.ref_model is None: + if not self.is_peft_model: + raise ValueError("No reference model and model is not a Peft model.") + else: + self.ref_model = prepare_deepspeed( + self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16 + ) + else: + if self.ref_model is None: + if not self.is_peft_model: + raise ValueError("No reference model and model is not a Peft model.") + else: + self.ref_model = self.ref_model.to(self.accelerator.device) + self.reward_model = self.reward_model.to(self.accelerator.device) + + def get_train_dataloader(self) -> DataLoader: + return self.dataloader + + def get_eval_dataloader(self) -> DataLoader: + return self.eval_dataloader + + @contextmanager + def null_ref_context(self): + """Context manager for handling null reference model (that is, peft adapter manipulation).""" + with ( + self.accelerator.unwrap_model(self.model.policy).disable_adapter() + if self.is_peft_model and not self.ref_adapter_name + else nullcontext() + ): + if self.ref_adapter_name: + self.model.policy.set_adapter(self.ref_adapter_name) + yield + if self.ref_adapter_name: + self.model.policy.set_adapter(self.model_adapter_name or "default") + + def save_model(self, output_dir: str | None = None, _internal_call: bool = False): + backup_model = self.model + self.model = self.model.policy # save only the policy + + if self.is_deepspeed_enabled: + backup_deepspeed = self.deepspeed + self.deepspeed = self.model + + super().save_model(output_dir, _internal_call) + + self.model = backup_model + + if self.is_deepspeed_enabled: + self.deepspeed = backup_deepspeed + + def train(self): + args = self.args + accelerator = self.accelerator + optimizer = self.optimizer + model = self.model + ref_policy = self.ref_model + reward_model = self.reward_model + processing_class = self.processing_class + dataloader = self.dataloader + device = accelerator.device + + def repeat_generator(): + while True: + yield from dataloader + + iter_dataloader = iter(repeat_generator()) + generation_config = GenerationConfig( + max_new_tokens=args.response_length, + temperature=(args.temperature + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + accelerator.print("===training policy===") + start_time = time.time() + stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps) + approxkl_stats = torch.zeros(stats_shape, device=device) + 
pg_clipfrac_stats = torch.zeros(stats_shape, device=device) + pg_loss_stats = torch.zeros(stats_shape, device=device) + vf_loss_stats = torch.zeros(stats_shape, device=device) + vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + entropy_stats = torch.zeros(stats_shape, device=device) + ratio_stats = torch.zeros(stats_shape, device=device) + model.train() + + # trainer state initialization + self.state.global_step = 0 + self.state.episode = 0 + self.state.max_steps = args.num_total_batches + self.state.num_train_epochs = args.total_episodes / self.train_dataset_len + # Compute absolute values for logging, eval, and save if given as ratio + if args.logging_steps is not None: + if args.logging_steps < 1: + self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps) + else: + self.state.logging_steps = args.logging_steps + if args.eval_steps is not None: + if args.eval_steps < 1: + self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps) + else: + self.state.eval_steps = args.eval_steps + if args.save_steps is not None: + if args.save_steps < 1: + self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps) + else: + self.state.save_steps = args.save_steps + self.control = self.callback_handler.on_train_begin(args, self.state, self.control) + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model + self.model_wrapped = self.model + + for update in range(1, args.num_total_batches + 1): + self.state.episode += 1 * args.batch_size + data = next(iter_dataloader) + with torch.no_grad(): + queries = data["input_ids"].to(device) + context_length = queries.shape[1] + responses = [] + postprocessed_responses = [] + logprobs = [] + ref_logprobs = [] + scores = [] + sequence_lengths = [] + values = [] + with unwrap_model_for_generation( + self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model: + query_responses, logitss = batch_generation( + unwrapped_model.policy, + queries, + args.local_rollout_forward_batch_size, + processing_class.pad_token_id, + generation_config, + ) + + for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): + query = queries[i : i + args.local_rollout_forward_batch_size] + query_response = query_responses[i : i + args.local_rollout_forward_batch_size] + response = query_response[:, context_length:] + logits = logitss[i : i + args.local_rollout_forward_batch_size] + logprob = selective_log_softmax(logits, response) + del logits + empty_cache() + + if ref_policy is None: + with self.null_ref_context(): + ref_output = forward(model.policy, query_response, processing_class.pad_token_id) + else: + ref_output = forward(ref_policy, query_response, processing_class.pad_token_id) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= args.temperature + 1e-7 + ref_logprob = selective_log_softmax(ref_logits, response) + del ref_output, ref_logits + empty_cache() + + # Response Processing 1. truncate response after the first occurrence of `stop_token_id` + postprocessed_response = response + if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + self.stop_token_id, processing_class.pad_token_id, response + ) + + # Response Processing 2. 
run reward model on the truncated responses + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1 + unwrapped_value_model = accelerator.unwrap_model(model).value_model + full_value, _, _ = get_reward( + unwrapped_value_model, query_response, processing_class.pad_token_id, context_length + ) + value = full_value[:, context_length - 1 : -1].squeeze(-1) + _, score, _ = get_reward( + reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length + ) + + responses.append(response) + postprocessed_responses.append(postprocessed_response) + logprobs.append(logprob) + ref_logprobs.append(ref_logprob) + sequence_lengths.append(sequence_length) + scores.append(score) + values.append(value) + responses = torch.cat(responses, 0) + postprocessed_responses = torch.cat(postprocessed_responses, 0) + logprobs = torch.cat(logprobs, 0) + ref_logprobs = torch.cat(ref_logprobs, 0) + sequence_lengths = torch.cat(sequence_lengths, 0) + scores = torch.cat(scores, 0) + values = torch.cat(values, 0) + del (logprob, ref_logprob, full_value, value, score, unwrapped_model) + empty_cache() + gc.collect() + + # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id + # Completions not passing that filter will receive a lower score. + contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1) + if self.args.missing_eos_penalty is not None: + scores[~contain_eos_token] -= self.args.missing_eos_penalty + # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}") + + # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw + response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1) + padding_mask = response_idxs > sequence_lengths.unsqueeze(1) + logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) + ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) + sequence_lengths_p1 = sequence_lengths + 1 + padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1)) + values = torch.masked_fill(values, padding_mask_p1, 0) + + # 4. compute rewards + # Formula used by http://joschu.net/blog/kl-approx.html for the k1 and k3 estimators + logr = ref_logprobs - logprobs + kl = -logr if args.kl_estimator == "k1" else (logr.exp() - 1) - logr # Else statement is k3 + non_score_reward = -args.kl_coef * kl + rewards = non_score_reward.clone() + actual_start = torch.arange(rewards.size(0), device=rewards.device) + actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths) + rewards[actual_start, actual_end] += scores + + # 5. whiten rewards + if args.whiten_rewards: + rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False) + rewards = torch.masked_fill(rewards, padding_mask_p1, 0) + + # 6. 
compute advantages and returns + lastgaelam = 0 + advantages_reversed = [] + gen_length = responses.shape[1] + for t in reversed(range(gen_length)): + nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 + delta = rewards[:, t] + args.gamma * nextvalues - values[:, t] + lastgaelam = delta + args.gamma * args.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], axis=1) + returns = advantages + values + advantages = masked_whiten(advantages, ~padding_mask) + advantages = torch.masked_fill(advantages, padding_mask, 0) + empty_cache() + + # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch + for ppo_epoch_idx in range(args.num_ppo_epochs): + b_inds = np.random.permutation(args.local_batch_size) + minibatch_idx = 0 + for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size): + with accelerator.accumulate(model): + micro_batch_end = micro_batch_start + args.per_device_train_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_advantage = advantages[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + mb_return = returns[micro_batch_inds] + mb_values = values[micro_batch_inds] + + output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id) + logits = output.logits[:, context_length - 1 : -1] + logits /= args.temperature + 1e-7 + new_logprobs = selective_log_softmax(logits, mb_responses) + new_logprobs = torch.masked_fill( + new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB + ) + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0) + vpredclipped = torch.clamp( + vpred, + mb_values - args.cliprange_value, + mb_values + args.cliprange_value, + ) + vf_losses1 = torch.square(vpred - mb_return) + vf_losses2 = torch.square(vpredclipped - mb_return) + vf_loss_max = torch.max(vf_losses1, vf_losses2) + vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds]) + vf_clipfrac = masked_mean( + (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds] + ) + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange) + pg_loss_max = torch.max(pg_losses, pg_losses2) + pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds]) + loss = pg_loss + args.vf_coef * vf_loss + accelerator.backward(loss) + optimizer.step() + optimizer.zero_grad() + with torch.no_grad(): + pg_clipfrac = masked_mean( + (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds] + ) + prob_dist = torch.nn.functional.softmax(logits, dim=-1) + entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( + pg_clipfrac + ) + pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = 
pg_loss + vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( + vf_clipfrac + ) + entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() + gradient_accumulation_idx += 1 + minibatch_idx += 1 + # del everything and empty cache + # fmt: off + del ( + output, vpred_temp, logits, new_logprobs, vpred, vpredclipped, + vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max, + pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return, + mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs, + ) + # fmt: on + empty_cache() + with torch.no_grad(): + mean_kl = kl.sum(1).mean() + mean_entropy = (-logprobs).sum(1).mean() + mean_non_score_reward = non_score_reward.sum(1).mean() + rlhf_reward = mean_non_score_reward + scores.mean() + eps = int(self.state.episode / (time.time() - start_time)) + metrics = {} + metrics["eps"] = eps + metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item() + metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item() + metrics["objective/non_score_reward"] = ( + self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item() + ) + metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item() + metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item() + metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item() + metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item() + metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item() + metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item() + metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item() + metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item() + metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item() + metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item() + metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item() + metrics["lr"] = self.lr_scheduler.get_last_lr()[0] + metrics["episode"] = self.state.episode + self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log + self.state.global_step += 1 + self.log(metrics) + + self.lr_scheduler.step() + self.control = self.callback_handler.on_step_end(args, self.state, self.control) + if self.control.should_save: + self._save_checkpoint(model, trial=None) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward + empty_cache() + gc.collect() + + if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0: + self.generate_completions(sampling=True) + empty_cache() + del ( + query_responses, + responses, + postprocessed_responses, + logprobs, + ref_logprobs, + values, + sequence_lengths, + contain_eos_token, + sequence_lengths_p1, + response_idxs, + padding_mask, + padding_mask_p1, + rewards, + actual_start, + actual_end, + advantages, + 
returns, + ) + empty_cache() + + # HF trainer specifics + self.control = self.callback_handler.on_train_end(args, self.state, self.control) + if self.control.should_save: + self._save_checkpoint(model, trial=None) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + + def generate_completions(self, sampling: bool = False): + args = self.args + processing_class = self.processing_class + generation_config = GenerationConfig( + max_new_tokens=self.args.response_length, + temperature=(0.01 + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + table = defaultdict(list) + with unwrap_model_for_generation( + self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model: + for batch in self.eval_dataloader: + query = batch["input_ids"] + with torch.no_grad(): + context_length = query.shape[1] + query_response, _ = batch_generation( + unwrapped_model.policy, + query, + query.shape[0], + processing_class.pad_token_id, + generation_config, + ) + response = query_response[:, context_length:] + postprocessed_response = response + if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + self.stop_token_id, processing_class.pad_token_id, response + ) + table["query"].extend( + gather_object(processing_class.batch_decode(query, skip_special_tokens=True)) + ) + table["model response"].extend( + gather_object(processing_class.batch_decode(postprocessed_response)) + ) + + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + _, score, _ = get_reward( + self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length + ) + table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy()) + + if sampling: + break + df = pd.DataFrame(table) + + if self.accelerator.is_main_process: + if is_rich_available(): + print_rich_table(df.iloc[0 : 0 + 5]) + if "wandb" in args.report_to: + import wandb + + if wandb.run is not None: + wandb.log({"completions": wandb.Table(dataframe=df)}) + + if "comet_ml" in args.report_to: + log_table_to_comet_experiment( + name="completions.csv", + table=df, + ) + + # Ensure the model card is saved along with the checkpoint + def _save_checkpoint(self, model, trial): + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + self.create_model_card(model_name=model_name) + super()._save_checkpoint(model, trial) diff --git a/trl/trainer/ppo_config.py b/trl/trainer/ppo_config.py index 40d48b82dbf..e38cd190fe8 100644 --- a/trl/trainer/ppo_config.py +++ b/trl/trainer/ppo_config.py @@ -12,124 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -from dataclasses import dataclass, field -from typing import Literal +import warnings +from dataclasses import dataclass -from ..trainer.utils import OnPolicyConfig +from ..experimental.ppo import PPOConfig as _PPOConfig @dataclass -class PPOConfig(OnPolicyConfig): - r""" - Configuration class for the [`PPOTrainer`]. - - This class includes only the parameters that are specific to PPO training. For a full list of training arguments, - please refer to the [`~transformers.TrainingArguments`] and [`OnPolicyConfig`] documentation. Note that default - values in this class may differ from those in [`~transformers.TrainingArguments`]. 
- - Using [`~transformers.HfArgumentParser`] we can turn this class into - [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the - command line. - - Parameters: - exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`): - Name of this experiment. - reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`): - Path to the reward model. - model_adapter_name (`str`, *optional*): - Name of the train target PEFT adapter, when using LoRA with multiple adapters. - ref_adapter_name (`str`, *optional*): - Name of the reference PEFT adapter, when using LoRA with multiple adapters. - num_ppo_epochs (`int`, *optional*, defaults to `4`): - Number of epochs to train. - whiten_rewards (`bool`, *optional*, defaults to `False`): - Whether to whiten the rewards. - kl_coef (`float`, *optional*, defaults to `0.05`): - KL coefficient. - kl_estimator (`Literal["k1", "k3"]`, *optional*, defaults to `"k1"`): - Which estimator for KL-Divergence to use from [Approximating KL - Divergence](http://joschu.net/blog/kl-approx.html). Defaults to "k1", a straightforward, unbiased - estimator. Can be set to "k3", an unbiased estimator with lower variance which "appears to be a strictly - better estimator". Cannot be set to "k2", as it is used for logging purposes. - cliprange (`float`, *optional*, defaults to `0.2`): - Clip range. - vf_coef (`float`, *optional*, defaults to `0.1`): - Value function coefficient. - cliprange_value (`float`, *optional*, defaults to `0.2`): - Clip range for the value function. - gamma (`float`, *optional*, defaults to `1.0`): - Discount factor. - lam (`float`, *optional*, defaults to `0.95`): - Lambda value for GAE. - ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): - This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, - improving generation speed. However, disabling this option allows training models that exceed the VRAM - capacity of a single GPU, albeit at the cost of slower generation. - """ - - exp_name: str = field( - default=os.path.basename(__file__)[:-3], - metadata={"help": "Name of this experiment."}, - ) - reward_model_path: str = field( - default="EleutherAI/pythia-160m", - metadata={"help": "Path to the reward model."}, - ) - model_adapter_name: str | None = field( - default=None, - metadata={"help": "Name of the train target PEFT adapter, when using LoRA with multiple adapters."}, - ) - ref_adapter_name: str | None = field( - default=None, - metadata={"help": "Name of the reference PEFT adapter, when using LoRA with multiple adapters."}, - ) - num_ppo_epochs: int = field( - default=4, - metadata={"help": "Number of epochs to train."}, - ) - whiten_rewards: bool = field( - default=False, - metadata={"help": "Whether to whiten the rewards."}, - ) - kl_coef: float = field( - default=0.05, - metadata={"help": "KL coefficient."}, - ) - kl_estimator: Literal["k1", "k3"] = field( - default="k1", - metadata={ - "help": "Which estimator for KL-Divergence to use from Approximating KL Divergence " - "(http://joschu.net/blog/kl-approx.html). Defaults to 'k1', a straightforward, unbiased estimator. Can be " - "set to 'k3', an unbiased estimator with lower variance which 'appears to be a strictly better " - "estimator'. Cannot be set to 'k2', as it is used for logging purposes." 
- }, - ) - cliprange: float = field( - default=0.2, - metadata={"help": "Clip range."}, - ) - vf_coef: float = field( - default=0.1, - metadata={"help": "Value function coefficient."}, - ) - cliprange_value: float = field( - default=0.2, - metadata={"help": "Clip range for the value function."}, - ) - gamma: float = field( - default=1.0, - metadata={"help": "Discount factor."}, - ) - lam: float = field( - default=0.95, - metadata={"help": "Lambda value for GAE."}, - ) - ds3_gather_for_generation: bool = field( - default=True, - metadata={ - "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for " - "generation, improving generation speed. However, disabling this option allows training models that " - "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation." - }, - ) +class PPOConfig(_PPOConfig): + def __post_init__(self): + warnings.warn( + "The `PPOConfig` is now located in `trl.experimental`. Please update your imports to " + "`from trl.experimental.ppo import PPOConfig`. The current import path will be removed and no longer " + "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223." + ) + super().__post_init__() diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index c4e5182d48f..fb2b94dfa37 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -12,833 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import gc -import math -import os -import textwrap -import time import warnings -from collections import defaultdict -from contextlib import contextmanager, nullcontext -from pathlib import Path +from dataclasses import dataclass -import numpy as np -import pandas as pd -import torch -import torch.nn as nn -from accelerate import Accelerator, logging -from accelerate.utils import broadcast, gather_object -from datasets import Dataset -from torch.utils.data import DataLoader -from transformers import ( - BaseImageProcessor, - DataCollatorWithPadding, - FeatureExtractionMixin, - GenerationConfig, - PreTrainedTokenizerBase, - ProcessorMixin, - TrainerCallback, - TrainerControl, -) -from transformers.integrations import get_reporting_integration_callbacks -from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK -from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback -from transformers.utils import is_peft_available, is_rich_available +from ..experimental.ppo import PPOTrainer as _PPOTrainer -from ..models import create_reference_model -from ..models.utils import unwrap_model_for_generation -from .base_trainer import BaseTrainer -from .ppo_config import PPOConfig -from .utils import ( - OnlineTrainerState, - batch_generation, - disable_dropout_in_model, - empty_cache, - exact_div, - first_true_indices, - forward, - get_reward, - log_table_to_comet_experiment, - peft_module_casting_to_bf16, - prepare_deepspeed, - print_rich_table, - selective_log_softmax, - truncate_response, -) - -logger = logging.get_logger(__name__) - -if is_peft_available(): - from peft import PeftConfig, PeftModel, get_peft_model - - -INVALID_LOGPROB = 1.0 - - -def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: bool | None = None) -> torch.Tensor: - """Compute mean of tensor with a masked values.""" - if axis is not None: - return (values * mask).sum(axis=axis) / mask.sum(axis=axis) - else: - return (values * mask).sum() / 
mask.sum() - - -def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor: - """Compute variance of tensor with masked values.""" - mean = masked_mean(values, mask) - centered_values = values - mean - variance = masked_mean(centered_values**2, mask) - if unbiased: - mask_sum = mask.sum() - if mask_sum == 0: - raise ValueError( - "The sum of the mask is zero, which can happen when `mini_batch_size=1`;" - "try increase the `mini_batch_size` or `gradient_accumulation_steps`" - ) - # note that if mask_sum == 1, then there is a division by zero issue - # to avoid it you just need to use a larger minibatch_size - bessel_correction = mask_sum / (mask_sum - 1) - variance = variance * bessel_correction - return variance - - -def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor: - """Whiten values with masked values.""" - mean, var = masked_mean(values, mask), masked_var(values, mask) - whitened = (values - mean) * torch.rsqrt(var + 1e-8) - if not shift_mean: - whitened += mean - return whitened - - -# taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29 -# we did this we can do a single `model = accelerator.prepare(model)` -class PolicyAndValueWrapper(nn.Module): - def __init__(self, policy, value_model) -> None: - super().__init__() - self.policy = policy - self.value_model = value_model - self.critic_backbone = getattr(value_model, value_model.base_model_prefix) - self.is_gradient_checkpointing = policy.is_gradient_checkpointing - - def forward(self, **kwargs): - output = self.critic_backbone(**kwargs) - logits = self.value_model.score(output.hidden_states[-1]) - return self.policy(**kwargs), logits - - -class PPOTrainer(BaseTrainer): - """Trainer for Proximal Policy Optimization (PPO). - - For details on PPO, see the paper: [Proximal Policy Optimization - Algorithms](https://huggingface.co/papers/1707.06347). - - Args: - args ([`PPOConfig`]): - Training arguments. - processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`]): - Class to process the data. - model (`torch.nn.Module`): - Model to be trained. This is the policy model. - ref_model (`torch.nn.Module`, *optional*): - Reference model used to compute the KL divergence. If `None`, a copy of the policy model is created. - reward_model (`torch.nn.Module`): - Reward model used to compute the rewards. - train_dataset ([`~datasets.Dataset`]): - Dataset for training. - value_model (`torch.nn.Module`): - Value model used to predict the value of a state. - data_collator ([`~transformers.DataCollatorWithPadding`], *optional*): - Data collator to batch and pad samples from the dataset. If `None`, a default data collator is created - using the `processing_class`. - eval_dataset ([`~datasets.Dataset`] or `dict` of [`~datasets.Dataset`], *optional*): - Dataset for evaluation. - optimizers (`tuple` of `torch.optim.Optimizer` and `torch.optim.lr_scheduler.LambdaLR`, *optional*, defaults to `(None, None)`): - Tuple containing the optimizer and the learning rate scheduler to use for training. If `None`, the - optimizer and the learning rate scheduler are created using the - [`~transformers.Trainer.create_optimizer_and_scheduler`] method. - callbacks (`list` of [`~transformers.TrainerCallback`], *optional*): - Callbacks to use during training. 
- peft_config ([`~peft.PeftConfig`], *optional*): - PEFT configuration to use PEFT for training. If `None`, PEFT is not used. If provided, the policy `model` - will be wrapped with the specified PEFT adapter. - """ - - _tag_names = ["trl", "ppo"] - _name = "PPO" - _paper = { - "title": "Fine-Tuning Language Models from Human Preferences", - "id": "1909.08593", - # docstyle-ignore - "citation": textwrap.dedent("""\ - @article{mziegler2019fine-tuning, - title = {{Fine-Tuning Language Models from Human Preferences}}, - author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving}, - year = 2019, - eprint = {arXiv:1909.08593} - }"""), - } - - def __init__( - self, - args: PPOConfig, - processing_class: PreTrainedTokenizerBase | BaseImageProcessor | FeatureExtractionMixin | ProcessorMixin, - model: nn.Module, - ref_model: nn.Module | None, - reward_model: nn.Module, - train_dataset: Dataset, - value_model: nn.Module, - data_collator: DataCollatorWithPadding | None = None, - eval_dataset: Dataset | dict[str, Dataset] | None = None, - # less commonly used - optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), - callbacks: list[TrainerCallback] | None = None, - peft_config: "PeftConfig | None" = None, - ) -> None: - if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"): - warnings.warn( - "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on " - "it and want it to remain, please share your comments here: " - "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable " - "TRL_EXPERIMENTAL_SILENCE=1." - ) - if ref_model is model: - raise ValueError( - "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the " - "same as `model`, you must make a copy of it, or `None` if you use peft." - ) - - self.args = args - self.processing_class = processing_class - self.policy_model = model - - # Define the collator if not provided - if data_collator is None: - data_collator = DataCollatorWithPadding(self.processing_class) - - # Handle stop token settings: update policy model's generation_config to use provided stop token - if args.stop_token and args.stop_token_id: - raise ValueError("You cannot set both `stop_token` and `stop_token_id`.") - elif args.stop_token: - if args.stop_token == "eos": - self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id - else: - raise ValueError( - f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)." - ) - else: - self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id # None or int - - # Check that the kl estimator is valid - if self.args.kl_estimator not in {"k1", "k3"}: - raise ValueError( - "kl_estimator must be either 'k1' (straightforward, unbiased) or 'k3' (lower variance, unbiased, " - "appears to be a strictly better estimator). See " - "[Approximating KL Divergence](http://joschu.net/blog/kl-approx.html) for details." 
- ) - - # peft support - if not is_peft_available() and peft_config is not None: - raise ImportError( - "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models" - ) - elif is_peft_available() and peft_config is not None: - # if model is a peft model and we have a peft_confg, we merge and unload it first - if isinstance(self.policy_model, PeftModel): - self.policy_model = self.policy_model.merge_and_unload() - - # get peft model with the given config - self.policy_model = get_peft_model(self.policy_model, peft_config) - if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False): - peft_module_casting_to_bf16(self.policy_model) - - self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel) - self.model_adapter_name = args.model_adapter_name - self.ref_adapter_name = args.ref_adapter_name - - if ref_model: - self.ref_model = ref_model - elif self.is_peft_model: - self.ref_model = None - else: - self.ref_model = create_reference_model(self.policy_model) - - self.reward_model = reward_model - self.train_dataset = train_dataset - self.train_dataset_len = len(train_dataset) - self.value_model = value_model - self.data_collator = data_collator - self.eval_dataset = eval_dataset - self.optimizer, self.lr_scheduler = optimizers - self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47 - - ######### - # calculate various batch sizes - ######### - if args.total_episodes is None: # allow the users to define episodes in terms of epochs. - args.total_episodes = int(args.num_train_epochs * self.train_dataset_len) - accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) - self.accelerator = accelerator - args.world_size = accelerator.num_processes - args.local_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps - args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size) - args.batch_size = int(args.local_batch_size * args.world_size) - args.mini_batch_size = exact_div( - args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`" - ) - args.local_mini_batch_size = exact_div( - args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`" - ) - if args.whiten_rewards: - assert args.local_mini_batch_size >= 8, ( - f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" - ) - # `per_rank_rollout_batch_size` is our `args.local_batch_size` - # `per_rank_minibatch_size` is our `args.local_mini_batch_size` - args.num_total_batches = math.ceil( - args.total_episodes / args.batch_size - ) # we may train for more than `total_episodes` - time_tensor = torch.tensor(int(time.time()), device=accelerator.device) - time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes - args.run_name = f"{args.exp_name}__{args.seed}__{time_int}" - self.local_seed = args.seed + accelerator.process_index * 100003 # Prime - if args.num_sample_generations > 0: - self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations) - self.local_dataloader_batch_size = args.local_batch_size - - ######### - # setup model, optimizer, and others - ######### - for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]: - if module is not None: - disable_dropout_in_model(module) - self.model = PolicyAndValueWrapper(self.policy_model, self.value_model) - 
self.model.config = self.policy_model.config # needed for pushing to hub - self.create_optimizer_and_scheduler( - num_training_steps=args.num_total_batches - ) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level - - ######### - # trainer specifics - ######### - default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) - self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks - self.callback_handler = CallbackHandler( - self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler - ) - self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) - self.control = TrainerControl() - self.state = OnlineTrainerState( - is_local_process_zero=self.is_local_process_zero(), - is_world_process_zero=self.is_world_process_zero(), - stateful_callbacks=[ - cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) - ], - ) - self.current_flos = 0 - self.hp_search_backend = None - self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None - self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None - # Create distant repo and output directory if needed - self.hub_model_id = None - if self.args.push_to_hub: - self.init_hf_repo() - if self.args.should_save: - os.makedirs(self.args.output_dir, exist_ok=True) - - # Add tags for models that have been loaded with the correct transformers version - if hasattr(self.model, "add_model_tags"): - self.model.add_model_tags(self._tag_names) - - ######### - # setup dataloader - ######### - self.dataloader = DataLoader( - self.train_dataset, - batch_size=self.local_dataloader_batch_size, - shuffle=True, - collate_fn=self.data_collator, - drop_last=True, # needed; otherwise the last batch will be of ragged shape - ) - # sync random states for DataLoader(shuffle=True) before `accelerator.prepare` - # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c - torch.manual_seed(args.seed) - self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader) - torch.manual_seed(self.local_seed) # reset the local seed again - - self.eval_dataloader = DataLoader( - self.eval_dataset, - batch_size=args.per_device_eval_batch_size, - collate_fn=self.data_collator, - drop_last=True, - ) # no need to shuffle eval dataset - self.eval_dataloader = accelerator.prepare(self.eval_dataloader) - - if self.is_deepspeed_enabled: - self.reward_model = prepare_deepspeed( - self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16 - ) - - if self.ref_model is None: - if not self.is_peft_model: - raise ValueError("No reference model and model is not a Peft model.") - else: - self.ref_model = prepare_deepspeed( - self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16 - ) - else: - if self.ref_model is None: - if not self.is_peft_model: - raise ValueError("No reference model and model is not a Peft model.") - else: - self.ref_model = self.ref_model.to(self.accelerator.device) - self.reward_model = self.reward_model.to(self.accelerator.device) - - def get_train_dataloader(self) -> DataLoader: - return self.dataloader - - def get_eval_dataloader(self) -> DataLoader: - return self.eval_dataloader - - @contextmanager - def null_ref_context(self): - """Context manager for handling null reference model (that is, peft adapter manipulation).""" - with ( 
- self.accelerator.unwrap_model(self.model.policy).disable_adapter() - if self.is_peft_model and not self.ref_adapter_name - else nullcontext() - ): - if self.ref_adapter_name: - self.model.policy.set_adapter(self.ref_adapter_name) - yield - if self.ref_adapter_name: - self.model.policy.set_adapter(self.model_adapter_name or "default") - - def save_model(self, output_dir: str | None = None, _internal_call: bool = False): - backup_model = self.model - self.model = self.model.policy # save only the policy - - if self.is_deepspeed_enabled: - backup_deepspeed = self.deepspeed - self.deepspeed = self.model - - super().save_model(output_dir, _internal_call) - - self.model = backup_model - - if self.is_deepspeed_enabled: - self.deepspeed = backup_deepspeed - - def train(self): - args = self.args - accelerator = self.accelerator - optimizer = self.optimizer - model = self.model - ref_policy = self.ref_model - reward_model = self.reward_model - processing_class = self.processing_class - dataloader = self.dataloader - device = accelerator.device - - def repeat_generator(): - while True: - yield from dataloader - - iter_dataloader = iter(repeat_generator()) - generation_config = GenerationConfig( - max_new_tokens=args.response_length, - temperature=(args.temperature + 1e-7), - top_k=0.0, - top_p=1.0, - do_sample=True, - ) - - accelerator.print("===training policy===") - start_time = time.time() - stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps) - approxkl_stats = torch.zeros(stats_shape, device=device) - pg_clipfrac_stats = torch.zeros(stats_shape, device=device) - pg_loss_stats = torch.zeros(stats_shape, device=device) - vf_loss_stats = torch.zeros(stats_shape, device=device) - vf_clipfrac_stats = torch.zeros(stats_shape, device=device) - entropy_stats = torch.zeros(stats_shape, device=device) - ratio_stats = torch.zeros(stats_shape, device=device) - model.train() - - # trainer state initialization - self.state.global_step = 0 - self.state.episode = 0 - self.state.max_steps = args.num_total_batches - self.state.num_train_epochs = args.total_episodes / self.train_dataset_len - # Compute absolute values for logging, eval, and save if given as ratio - if args.logging_steps is not None: - if args.logging_steps < 1: - self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps) - else: - self.state.logging_steps = args.logging_steps - if args.eval_steps is not None: - if args.eval_steps < 1: - self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps) - else: - self.state.eval_steps = args.eval_steps - if args.save_steps is not None: - if args.save_steps < 1: - self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps) - else: - self.state.save_steps = args.save_steps - self.control = self.callback_handler.on_train_begin(args, self.state, self.control) - - # backward compatibility - if self.is_deepspeed_enabled: - self.deepspeed = self.model - self.model_wrapped = self.model - - for update in range(1, args.num_total_batches + 1): - self.state.episode += 1 * args.batch_size - data = next(iter_dataloader) - with torch.no_grad(): - queries = data["input_ids"].to(device) - context_length = queries.shape[1] - responses = [] - postprocessed_responses = [] - logprobs = [] - ref_logprobs = [] - scores = [] - sequence_lengths = [] - values = [] - with unwrap_model_for_generation( - self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation - ) as unwrapped_model: - query_responses, logitss 
= batch_generation( - unwrapped_model.policy, - queries, - args.local_rollout_forward_batch_size, - processing_class.pad_token_id, - generation_config, - ) - - for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): - query = queries[i : i + args.local_rollout_forward_batch_size] - query_response = query_responses[i : i + args.local_rollout_forward_batch_size] - response = query_response[:, context_length:] - logits = logitss[i : i + args.local_rollout_forward_batch_size] - logprob = selective_log_softmax(logits, response) - del logits - empty_cache() - - if ref_policy is None: - with self.null_ref_context(): - ref_output = forward(model.policy, query_response, processing_class.pad_token_id) - else: - ref_output = forward(ref_policy, query_response, processing_class.pad_token_id) - ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= args.temperature + 1e-7 - ref_logprob = selective_log_softmax(ref_logits, response) - del ref_output, ref_logits - empty_cache() - - # Response Processing 1. truncate response after the first occurrence of `stop_token_id` - postprocessed_response = response - if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 - postprocessed_response = truncate_response( - self.stop_token_id, processing_class.pad_token_id, response - ) - - # Response Processing 2. run reward model on the truncated responses - postprocessed_query_response = torch.cat((query, postprocessed_response), 1) - sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1 - unwrapped_value_model = accelerator.unwrap_model(model).value_model - full_value, _, _ = get_reward( - unwrapped_value_model, query_response, processing_class.pad_token_id, context_length - ) - value = full_value[:, context_length - 1 : -1].squeeze(-1) - _, score, _ = get_reward( - reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length - ) - - responses.append(response) - postprocessed_responses.append(postprocessed_response) - logprobs.append(logprob) - ref_logprobs.append(ref_logprob) - sequence_lengths.append(sequence_length) - scores.append(score) - values.append(value) - responses = torch.cat(responses, 0) - postprocessed_responses = torch.cat(postprocessed_responses, 0) - logprobs = torch.cat(logprobs, 0) - ref_logprobs = torch.cat(ref_logprobs, 0) - sequence_lengths = torch.cat(sequence_lengths, 0) - scores = torch.cat(scores, 0) - values = torch.cat(values, 0) - del (logprob, ref_logprob, full_value, value, score, unwrapped_model) - empty_cache() - gc.collect() - - # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id - # Completions not passing that filter will receive a lower score. 
- contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1) - if self.args.missing_eos_penalty is not None: - scores[~contain_eos_token] -= self.args.missing_eos_penalty - # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}") - - # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw - response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1) - padding_mask = response_idxs > sequence_lengths.unsqueeze(1) - logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) - ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) - sequence_lengths_p1 = sequence_lengths + 1 - padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1)) - values = torch.masked_fill(values, padding_mask_p1, 0) - - # 4. compute rewards - # Formula used by http://joschu.net/blog/kl-approx.html for the k1 and k3 estimators - logr = ref_logprobs - logprobs - kl = -logr if args.kl_estimator == "k1" else (logr.exp() - 1) - logr # Else statement is k3 - non_score_reward = -args.kl_coef * kl - rewards = non_score_reward.clone() - actual_start = torch.arange(rewards.size(0), device=rewards.device) - actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths) - rewards[actual_start, actual_end] += scores - - # 5. whiten rewards - if args.whiten_rewards: - rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False) - rewards = torch.masked_fill(rewards, padding_mask_p1, 0) - - # 6. compute advantages and returns - lastgaelam = 0 - advantages_reversed = [] - gen_length = responses.shape[1] - for t in reversed(range(gen_length)): - nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 - delta = rewards[:, t] + args.gamma * nextvalues - values[:, t] - lastgaelam = delta + args.gamma * args.lam * lastgaelam - advantages_reversed.append(lastgaelam) - advantages = torch.stack(advantages_reversed[::-1], axis=1) - returns = advantages + values - advantages = masked_whiten(advantages, ~padding_mask) - advantages = torch.masked_fill(advantages, padding_mask, 0) - empty_cache() - - # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch - for ppo_epoch_idx in range(args.num_ppo_epochs): - b_inds = np.random.permutation(args.local_batch_size) - minibatch_idx = 0 - for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size): - mini_batch_end = mini_batch_start + args.local_mini_batch_size - mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] - gradient_accumulation_idx = 0 - for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size): - with accelerator.accumulate(model): - micro_batch_end = micro_batch_start + args.per_device_train_batch_size - micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] - mb_advantage = advantages[micro_batch_inds] - mb_responses = responses[micro_batch_inds] - mb_query_responses = query_responses[micro_batch_inds] - mb_logprobs = logprobs[micro_batch_inds] - mb_return = returns[micro_batch_inds] - mb_values = values[micro_batch_inds] - - output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id) - logits = output.logits[:, context_length - 1 : -1] - logits /= args.temperature + 1e-7 - new_logprobs = selective_log_softmax(logits, mb_responses) - new_logprobs = torch.masked_fill( - 
new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB - ) - vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) - vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0) - vpredclipped = torch.clamp( - vpred, - mb_values - args.cliprange_value, - mb_values + args.cliprange_value, - ) - vf_losses1 = torch.square(vpred - mb_return) - vf_losses2 = torch.square(vpredclipped - mb_return) - vf_loss_max = torch.max(vf_losses1, vf_losses2) - vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds]) - vf_clipfrac = masked_mean( - (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds] - ) - logprobs_diff = new_logprobs - mb_logprobs - ratio = torch.exp(logprobs_diff) - pg_losses = -mb_advantage * ratio - pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange) - pg_loss_max = torch.max(pg_losses, pg_losses2) - pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds]) - loss = pg_loss + args.vf_coef * vf_loss - accelerator.backward(loss) - optimizer.step() - optimizer.zero_grad() - with torch.no_grad(): - pg_clipfrac = masked_mean( - (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds] - ) - prob_dist = torch.nn.functional.softmax(logits, dim=-1) - entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) - approxkl = 0.5 * (logprobs_diff**2).mean() - approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl - pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( - pg_clipfrac - ) - pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss - vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss - vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ( - vf_clipfrac - ) - entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() - ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() - gradient_accumulation_idx += 1 - minibatch_idx += 1 - # del everything and empty cache - # fmt: off - del ( - output, vpred_temp, logits, new_logprobs, vpred, vpredclipped, - vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max, - pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return, - mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs, - ) - # fmt: on - empty_cache() - with torch.no_grad(): - mean_kl = kl.sum(1).mean() - mean_entropy = (-logprobs).sum(1).mean() - mean_non_score_reward = non_score_reward.sum(1).mean() - rlhf_reward = mean_non_score_reward + scores.mean() - eps = int(self.state.episode / (time.time() - start_time)) - metrics = {} - metrics["eps"] = eps - metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item() - metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item() - metrics["objective/non_score_reward"] = ( - self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item() - ) - metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item() - metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item() - metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item() - metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item() - metrics["loss/policy_avg"] = 
self.accelerator.gather_for_metrics(pg_loss_stats).mean().item() - metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item() - metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item() - metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item() - metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item() - metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item() - metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item() - metrics["lr"] = self.lr_scheduler.get_last_lr()[0] - metrics["episode"] = self.state.episode - self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log - self.state.global_step += 1 - self.log(metrics) - - self.lr_scheduler.step() - self.control = self.callback_handler.on_step_end(args, self.state, self.control) - if self.control.should_save: - self._save_checkpoint(model, trial=None) - self.control = self.callback_handler.on_save(self.args, self.state, self.control) - del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward - empty_cache() - gc.collect() - - if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0: - self.generate_completions(sampling=True) - empty_cache() - del ( - query_responses, - responses, - postprocessed_responses, - logprobs, - ref_logprobs, - values, - sequence_lengths, - contain_eos_token, - sequence_lengths_p1, - response_idxs, - padding_mask, - padding_mask_p1, - rewards, - actual_start, - actual_end, - advantages, - returns, - ) - empty_cache() - - # HF trainer specifics - self.control = self.callback_handler.on_train_end(args, self.state, self.control) - if self.control.should_save: - self._save_checkpoint(model, trial=None) - self.control = self.callback_handler.on_save(self.args, self.state, self.control) - - def generate_completions(self, sampling: bool = False): - args = self.args - processing_class = self.processing_class - generation_config = GenerationConfig( - max_new_tokens=self.args.response_length, - temperature=(0.01 + 1e-7), - top_k=0.0, - top_p=1.0, - do_sample=True, +@dataclass +class PPOTrainer(_PPOTrainer): + def __init__(self, *args, **kwargs): + warnings.warn( + "The `PPOTrainer` is now located in `trl.experimental`. Please update your imports to " + "`from trl.experimental.ppo import PPOTrainer`. The current import path will be removed and no longer " + "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223." 
) - - table = defaultdict(list) - with unwrap_model_for_generation( - self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation - ) as unwrapped_model: - for batch in self.eval_dataloader: - query = batch["input_ids"] - with torch.no_grad(): - context_length = query.shape[1] - query_response, _ = batch_generation( - unwrapped_model.policy, - query, - query.shape[0], - processing_class.pad_token_id, - generation_config, - ) - response = query_response[:, context_length:] - postprocessed_response = response - if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 - postprocessed_response = truncate_response( - self.stop_token_id, processing_class.pad_token_id, response - ) - table["query"].extend( - gather_object(processing_class.batch_decode(query, skip_special_tokens=True)) - ) - table["model response"].extend( - gather_object(processing_class.batch_decode(postprocessed_response)) - ) - - postprocessed_query_response = torch.cat((query, postprocessed_response), 1) - _, score, _ = get_reward( - self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length - ) - table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy()) - - if sampling: - break - df = pd.DataFrame(table) - - if self.accelerator.is_main_process: - if is_rich_available(): - print_rich_table(df.iloc[0 : 0 + 5]) - if "wandb" in args.report_to: - import wandb - - if wandb.run is not None: - wandb.log({"completions": wandb.Table(dataframe=df)}) - - if "comet_ml" in args.report_to: - log_table_to_comet_experiment( - name="completions.csv", - table=df, - ) - - # Ensure the model card is saved along with the checkpoint - def _save_checkpoint(self, model, trial): - if self.args.hub_model_id is None: - model_name = Path(self.args.output_dir).name - else: - model_name = self.args.hub_model_id.split("/")[-1] - self.create_model_card(model_name=model_name) - super()._save_checkpoint(model, trial) + super().__init__(*args, **kwargs)
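For downstream code, the net effect of the two compatibility shims above is an import-path change plus a deprecation warning: both classes still subclass their `trl.experimental.ppo` counterparts and delegate directly to them. A minimal migration sketch, assuming the shims land as written and that the long-standing top-level `trl` re-exports are kept until TRL 0.29:

```python
# Migration sketch only; class names and keyword arguments are unchanged.

# Legacy import path: keeps working until TRL 0.29, but now triggers a warning
# from PPOConfig.__post_init__ / PPOTrainer.__init__ (see the shims above).
# from trl import PPOConfig, PPOTrainer

# New import path: same classes, no deprecation warning.
from trl.experimental.ppo import PPOConfig, PPOTrainer

# Existing configuration code carries over as-is; these fields and defaults
# come from the PPOConfig shown in this patch.
training_args = PPOConfig(output_dir="ppo-model", num_ppo_epochs=4, kl_coef=0.05)
```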
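Until a codebase has migrated, the relocation notice (a plain `UserWarning`, since `warnings.warn` is called without a category in both shims) can be filtered with the standard library; a sketch matching the message text used above:

```python
import warnings

# Ignore only the PPO relocation notices while migrating; other warnings stay visible.
# The regex is a prefix of the message emitted by both shims in this patch.
warnings.filterwarnings(
    "ignore",
    message=r"The `PPO(Config|Trainer)` is now located in `trl\.experimental`",
)
```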