2 changes: 1 addition & 1 deletion README.md
@@ -25,7 +25,7 @@ Explore how to seamlessly integrate TRL with OpenEnv in our [dedicated documenta

## Overview

TRL is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO). Built on top of the [🤗 Transformers](https://github.com/huggingface/transformers) ecosystem, TRL supports a variety of model architectures and modalities, and can be scaled-up across various hardware setups.
TRL is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), and Direct Preference Optimization (DPO). Built on top of the [🤗 Transformers](https://github.com/huggingface/transformers) ecosystem, TRL supports a variety of model architectures and modalities, and can be scaled-up across various hardware setups.

## Highlights

4 changes: 2 additions & 2 deletions docs/source/_toctree.yml
@@ -66,8 +66,6 @@
title: KTO
- local: orpo_trainer
title: ORPO
- local: ppo_trainer
title: PPO
- local: prm_trainer
title: PRM
- local: reward_trainer
@@ -119,6 +117,8 @@
title: Nash-MD
- local: papo_trainer
title: PAPO
- local: ppo_trainer
title: PPO
- local: xpo_trainer
title: XPO
- local: openenv
2 changes: 1 addition & 1 deletion docs/source/dataset_formats.md
@@ -396,7 +396,7 @@ Choosing the right dataset type depends on the task you are working on and the s
| [`experimental.nash_md.NashMDTrainer`] | [Prompt-only](#prompt-only) |
| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`PPOTrainer`] | Tokenized language modeling |
| [`experimental.ppo.PPOTrainer`] | Tokenized language modeling |
| [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) |
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
| [`RLOOTrainer`] | [Prompt-only](#prompt-only) |
6 changes: 3 additions & 3 deletions docs/source/example_overview.md
@@ -37,7 +37,7 @@ These notebooks are easier to run and are designed for quick experimentation wit

Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl/blob/main/trl/scripts) and [`examples/scripts`](https://github.com/huggingface/trl/blob/main/examples/scripts) directories. They show how to use different trainers such as `SFTTrainer`, `PPOTrainer`, `DPOTrainer`, `GRPOTrainer`, and more.

File | Description |
| File | Description |
| --- | --- |
| [`examples/scripts/bco.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/bco.py) | This script shows how to use the [`KTOTrainer`] with the BCO loss to fine-tune a model to increase instruction-following, truthfulness, honesty, and helpfulness using the [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. |
| [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`experimental.cpo.CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
@@ -55,8 +55,8 @@ Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl
| [`examples/scripts/online_dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a model. |
| [`examples/scripts/online_dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo_vlm.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a Vision Language Model. |
| [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
| [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language. |
| [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
| [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`experimental.ppo.PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language. |
| [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`experimental.ppo.PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
| [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py) | This script shows how to use the [`PRMTrainer`] to fine-tune a Process-supervised Reward Model (PRM). |
| [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train an Outcome Reward Model (ORM) on your own dataset. |
| [`examples/scripts/rloo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo.py) | This script shows how to use the [`RLOOTrainer`] to fine-tune a model to improve its ability to solve math questions. |
2 changes: 1 addition & 1 deletion docs/source/index.md
@@ -25,8 +25,8 @@ Below is the current list of TRL trainers, organized by method type (⚡️ = vL
- [`GRPOTrainer`] ⚡️
- [`RLOOTrainer`] ⚡️
- [`OnlineDPOTrainer`] ⚡️
- [`PPOTrainer`]
- [`experimental.nash_md.NashMDTrainer`] 🧪 ⚡️
- [`experimental.ppo.PPOTrainer`] 🧪
- [`experimental.xpo.XPOTrainer`] 🧪 ⚡️

### Reward modeling
3 changes: 2 additions & 1 deletion docs/source/peft_integration.md
@@ -146,7 +146,8 @@ After training your reward adapter and pushing it to the Hub:

```python
from peft import LoraConfig
from trl import AutoModelForCausalLMWithValueHead, PPOTrainer
from trl import AutoModelForCausalLMWithValueHead
from trl.experimental.ppo import PPOTrainer

model_name = "huggyllama/llama-7b"
rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter"
10 changes: 8 additions & 2 deletions docs/source/ppo_trainer.md
@@ -1,5 +1,11 @@
# PPO Trainer

<Tip warning={true}>

**Deprecation Notice**: PPOTrainer and PPOConfig have been moved to `trl.experimental.ppo` and will be removed from `trl.trainer` in TRL 0.29.0. Please update your imports to use `from trl.experimental.ppo import PPOConfig, PPOTrainer` instead. See [issue #4466](https://github.com/huggingface/trl/issues/4466) for more information.

</Tip>
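
In practice the migration is a one-line import change; a minimal before/after sketch, mirroring the notice above (the old import path is the one removed elsewhere in this PR):

```python
# Before (deprecated; scheduled for removal from trl.trainer in TRL 0.29.0):
# from trl import PPOConfig, PPOTrainer

# After:
from trl.experimental.ppo import PPOConfig, PPOTrainer
```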

[![model badge](https://img.shields.io/badge/All_models-PPO-blue)](https://huggingface.co/models?other=ppo,trl)

TRL supports training LLMs with [Proximal Policy Optimization (PPO)](https://huggingface.co/papers/1707.06347).
@@ -228,11 +234,11 @@

## PPOTrainer

[[autodoc]] PPOTrainer
[[autodoc]] experimental.ppo.PPOTrainer
- train
- save_model
- push_to_hub

## PPOConfig

[[autodoc]] PPOConfig
[[autodoc]] experimental.ppo.PPOConfig
2 changes: 1 addition & 1 deletion docs/source/reducing_memory_usage.md
@@ -274,7 +274,7 @@ training_args = OnlineDPOConfig(..., ds3_gather_for_generation=False)
<hfoption id="PPO">

```python
from trl import PPOConfig
from trl.experimental.ppo import PPOConfig

training_args = PPOConfig(..., ds3_gather_for_generation=False)
```
11 changes: 2 additions & 9 deletions examples/scripts/ppo/ppo.py
@@ -34,15 +34,8 @@
    HfArgumentParser,
)

from trl import (
    ModelConfig,
    PPOConfig,
    PPOTrainer,
    ScriptArguments,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from trl import ModelConfig, ScriptArguments, get_kbit_device_map, get_peft_config, get_quantization_config
from trl.experimental.ppo import PPOConfig, PPOTrainer


# Enable logging in a Hugging Face Space
11 changes: 2 additions & 9 deletions examples/scripts/ppo/ppo_tldr.py
@@ -34,15 +34,8 @@
    HfArgumentParser,
)

from trl import (
    ModelConfig,
    PPOConfig,
    PPOTrainer,
    ScriptArguments,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)
from trl import ModelConfig, ScriptArguments, get_kbit_device_map, get_peft_config, get_quantization_config
from trl.experimental.ppo import PPOConfig, PPOTrainer


# Enable logging in a Hugging Face Space
@@ -17,10 +17,10 @@
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from transformers.utils import is_peft_available

from trl import PPOConfig, PPOTrainer
from trl.trainer.ppo_trainer import masked_mean, masked_var, masked_whiten
from trl.experimental.ppo import PPOConfig, PPOTrainer
from trl.experimental.ppo.ppo_trainer import masked_mean, masked_var, masked_whiten

from .testing_utils import TrlTestCase, require_peft
from ..testing_utils import TrlTestCase, require_peft


if is_peft_available():
19 changes: 19 additions & 0 deletions trl/experimental/ppo/__init__.py
@@ -0,0 +1,19 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .ppo_config import PPOConfig
from .ppo_trainer import PPOTrainer


__all__ = ["PPOConfig", "PPOTrainer"]
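
The new subpackage is a thin re-export layer. As a quick sketch of the resulting public surface (assuming a TRL install that includes the layout added in this PR):

```python
import trl.experimental.ppo as experimental_ppo

# The public names pinned by __all__ above.
assert experimental_ppo.__all__ == ["PPOConfig", "PPOTrainer"]

# The package-level classes are the same objects defined in the submodules,
# so either import path resolves to the same class.
from trl.experimental.ppo import PPOConfig
from trl.experimental.ppo.ppo_config import PPOConfig as PPOConfigDirect

assert PPOConfig is PPOConfigDirect
```
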
135 changes: 135 additions & 0 deletions trl/experimental/ppo/ppo_config.py
@@ -0,0 +1,135 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from dataclasses import dataclass, field
from typing import Literal

from ...trainer.utils import OnPolicyConfig


@dataclass
class PPOConfig(OnPolicyConfig):
r"""
Configuration class for the [`experimental.ppo.PPOTrainer`].

This class includes only the parameters that are specific to PPO training. For a full list of training arguments,
please refer to the [`~transformers.TrainingArguments`] and [`OnPolicyConfig`] documentation. Note that default
values in this class may differ from those in [`~transformers.TrainingArguments`].

Using [`~transformers.HfArgumentParser`] we can turn this class into
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
command line.

Parameters:
exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`):
Name of this experiment.
reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
Path to the reward model.
model_adapter_name (`str`, *optional*):
Name of the train target PEFT adapter, when using LoRA with multiple adapters.
ref_adapter_name (`str`, *optional*):
Name of the reference PEFT adapter, when using LoRA with multiple adapters.
num_ppo_epochs (`int`, *optional*, defaults to `4`):
Number of epochs to train.
whiten_rewards (`bool`, *optional*, defaults to `False`):
Whether to whiten the rewards.
kl_coef (`float`, *optional*, defaults to `0.05`):
KL coefficient.
kl_estimator (`Literal["k1", "k3"]`, *optional*, defaults to `"k1"`):
Which estimator for KL-Divergence to use from [Approximating KL
Divergence](http://joschu.net/blog/kl-approx.html). Defaults to "k1", a straightforward, unbiased
estimator. Can be set to "k3", an unbiased estimator with lower variance which "appears to be a strictly
better estimator". Cannot be set to "k2", as it is used for logging purposes.
cliprange (`float`, *optional*, defaults to `0.2`):
Clip range.
vf_coef (`float`, *optional*, defaults to `0.1`):
Value function coefficient.
cliprange_value (`float`, *optional*, defaults to `0.2`):
Clip range for the value function.
gamma (`float`, *optional*, defaults to `1.0`):
Discount factor.
lam (`float`, *optional*, defaults to `0.95`):
Lambda value for GAE.
ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
improving generation speed. However, disabling this option allows training models that exceed the VRAM
capacity of a single GPU, albeit at the cost of slower generation.
"""

    exp_name: str = field(
        default=os.path.basename(__file__)[:-3],
        metadata={"help": "Name of this experiment."},
    )
    reward_model_path: str = field(
        default="EleutherAI/pythia-160m",
        metadata={"help": "Path to the reward model."},
    )
    model_adapter_name: str | None = field(
        default=None,
        metadata={"help": "Name of the train target PEFT adapter, when using LoRA with multiple adapters."},
    )
    ref_adapter_name: str | None = field(
        default=None,
        metadata={"help": "Name of the reference PEFT adapter, when using LoRA with multiple adapters."},
    )
    num_ppo_epochs: int = field(
        default=4,
        metadata={"help": "Number of epochs to train."},
    )
    whiten_rewards: bool = field(
        default=False,
        metadata={"help": "Whether to whiten the rewards."},
    )
    kl_coef: float = field(
        default=0.05,
        metadata={"help": "KL coefficient."},
    )
    kl_estimator: Literal["k1", "k3"] = field(
        default="k1",
        metadata={
            "help": "Which estimator for KL-Divergence to use from Approximating KL Divergence "
            "(http://joschu.net/blog/kl-approx.html). Defaults to 'k1', a straightforward, unbiased estimator. Can be "
            "set to 'k3', an unbiased estimator with lower variance which 'appears to be a strictly better "
            "estimator'. Cannot be set to 'k2', as it is used for logging purposes."
        },
    )
    cliprange: float = field(
        default=0.2,
        metadata={"help": "Clip range."},
    )
    vf_coef: float = field(
        default=0.1,
        metadata={"help": "Value function coefficient."},
    )
    cliprange_value: float = field(
        default=0.2,
        metadata={"help": "Clip range for the value function."},
    )
    gamma: float = field(
        default=1.0,
        metadata={"help": "Discount factor."},
    )
    lam: float = field(
        default=0.95,
        metadata={"help": "Lambda value for GAE."},
    )
    ds3_gather_for_generation: bool = field(
        default=True,
        metadata={
            "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
            "generation, improving generation speed. However, disabling this option allows training models that "
            "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation."
        },
    )
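
For orientation, a hedged usage sketch constructing the new config with the fields documented above (values other than the documented defaults, and the output directory, are illustrative placeholders rather than part of this PR):

```python
from trl.experimental.ppo import PPOConfig

# Sketch only: every keyword below is a field defined in ppo_config.py above
# (output_dir comes from the inherited transformers.TrainingArguments).
training_args = PPOConfig(
    output_dir="ppo-example",
    reward_model_path="EleutherAI/pythia-160m",
    num_ppo_epochs=4,
    kl_coef=0.05,
    kl_estimator="k3",   # "k1" (default, unbiased) or "k3" (unbiased, lower variance)
    cliprange=0.2,
    cliprange_value=0.2,
    vf_coef=0.1,
    gamma=1.0,
    lam=0.95,
    whiten_rewards=False,
    ds3_gather_for_generation=True,  # only relevant under DeepSpeed ZeRO-3
)
```

For reference, the two `kl_estimator` options follow the linked blog post: for samples `x ~ q` with ratio `r = p(x) / q(x)`, `k1 = -log r` and `k3 = (r - 1) - log r`; both are unbiased estimators of the KL divergence, with `k3` having lower variance.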