diff --git a/README.md b/README.md
index 81497586da1..016ec2ff1eb 100644
--- a/README.md
+++ b/README.md
@@ -25,7 +25,7 @@ Explore how to seamlessly integrate TRL with OpenEnv in our [dedicated documenta
## Overview
-TRL is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO). Built on top of the [🤗 Transformers](https://github.com/huggingface/transformers) ecosystem, TRL supports a variety of model architectures and modalities, and can be scaled-up across various hardware setups.
+TRL is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), and Direct Preference Optimization (DPO). Built on top of the [🤗 Transformers](https://github.com/huggingface/transformers) ecosystem, TRL supports a variety of model architectures and modalities, and can be scaled-up across various hardware setups.
## Highlights
diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 6fc8e2ea783..a4ca28675bc 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -66,8 +66,6 @@
title: KTO
- local: orpo_trainer
title: ORPO
- - local: ppo_trainer
- title: PPO
- local: prm_trainer
title: PRM
- local: reward_trainer
@@ -119,6 +117,8 @@
title: Nash-MD
- local: papo_trainer
title: PAPO
+ - local: ppo_trainer
+ title: PPO
- local: xpo_trainer
title: XPO
- local: openenv
diff --git a/docs/source/dataset_formats.md b/docs/source/dataset_formats.md
index 3444742987f..8faf2f11fc7 100644
--- a/docs/source/dataset_formats.md
+++ b/docs/source/dataset_formats.md
@@ -396,7 +396,7 @@ Choosing the right dataset type depends on the task you are working on and the s
| [`experimental.nash_md.NashMDTrainer`] | [Prompt-only](#prompt-only) |
| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
-| [`PPOTrainer`] | Tokenized language modeling |
+| [`experimental.ppo.PPOTrainer`] | Tokenized language modeling |
| [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) |
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
| [`RLOOTrainer`] | [Prompt-only](#prompt-only) |
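
As a quick illustration of the "Tokenized language modeling" row for [`experimental.ppo.PPOTrainer`]: the PPO example scripts feed the trainer prompts that are already tokenized into an `input_ids` column, which the trainer then pads with `DataCollatorWithPadding`. A minimal sketch (the checkpoint and toy prompts below are placeholders, not part of this patch):

```python
from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")

# Toy prompts for illustration only; the real scripts load a dataset from the Hub.
raw = Dataset.from_dict({"prompt": ["The movie was", "I went to the park and"]})
train_dataset = raw.map(
    lambda row: {"input_ids": tokenizer(row["prompt"])["input_ids"]},
    remove_columns=["prompt"],
)
```
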
diff --git a/docs/source/example_overview.md b/docs/source/example_overview.md
index 2e1ea944187..8d4ee5b2400 100644
--- a/docs/source/example_overview.md
+++ b/docs/source/example_overview.md
@@ -37,7 +37,7 @@ These notebooks are easier to run and are designed for quick experimentation wit
Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl/blob/main/trl/scripts) and [`examples/scripts`](https://github.com/huggingface/trl/blob/main/examples/scripts) directories. They show how to use different trainers such as `SFTTrainer`, `PPOTrainer`, `DPOTrainer`, `GRPOTrainer`, and more.
- File | Description |
+| File | Description |
| --- | --- |
| [`examples/scripts/bco.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/bco.py) | This script shows how to use the [`KTOTrainer`] with the BCO loss to fine-tune a model to increase instruction-following, truthfulness, honesty, and helpfulness using the [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. |
| [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`experimental.cpo.CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
@@ -55,8 +55,8 @@ Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl
| [`examples/scripts/online_dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a model. |
| [`examples/scripts/online_dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo_vlm.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a Vision Language Model. |
| [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
-| [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language. |
-| [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
+| [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`experimental.ppo.PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language. |
+| [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`experimental.ppo.PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
| [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py) | This script shows how to use the [`PRMTrainer`] to fine-tune a Process-supervised Reward Model (PRM). |
| [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train an Outcome Reward Model (ORM) on your own dataset. |
| [`examples/scripts/rloo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo.py) | This script shows how to use the [`RLOOTrainer`] to fine-tune a model to improve its ability to solve math questions. |
diff --git a/docs/source/index.md b/docs/source/index.md
index 1ac0d5d7321..95f964b671a 100644
--- a/docs/source/index.md
+++ b/docs/source/index.md
@@ -25,8 +25,8 @@ Below is the current list of TRL trainers, organized by method type (⚡️ = vL
- [`GRPOTrainer`] ⚡️
- [`RLOOTrainer`] ⚡️
- [`OnlineDPOTrainer`] ⚡️
-- [`PPOTrainer`]
- [`experimental.nash_md.NashMDTrainer`] 🧪 ⚡️
+- [`experimental.ppo.PPOTrainer`] 🧪
- [`experimental.xpo.XPOTrainer`] 🧪 ⚡️
### Reward modeling
diff --git a/docs/source/peft_integration.md b/docs/source/peft_integration.md
index bd196dd99bf..221d9b7071b 100644
--- a/docs/source/peft_integration.md
+++ b/docs/source/peft_integration.md
@@ -146,7 +146,8 @@ After training your reward adapter and pushing it to the Hub:
```python
from peft import LoraConfig
-from trl import AutoModelForCausalLMWithValueHead, PPOTrainer
+from trl import AutoModelForCausalLMWithValueHead
+from trl.experimental.ppo import PPOTrainer
model_name = "huggyllama/llama-7b"
rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter"
diff --git a/docs/source/ppo_trainer.md b/docs/source/ppo_trainer.md
index 1dabbc4177c..3f7ea2ee73f 100644
--- a/docs/source/ppo_trainer.md
+++ b/docs/source/ppo_trainer.md
@@ -1,5 +1,11 @@
# PPO Trainer
+
+
+**Deprecation Notice**: `PPOTrainer` and `PPOConfig` have been moved to `trl.experimental.ppo` and will be removed from `trl.trainer` in TRL 0.29.0. Please update your imports to use `from trl.experimental.ppo import PPOConfig, PPOTrainer` instead. See [issue #4466](https://github.com/huggingface/trl/issues/4466) for more information.
+
+
+
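+For example, only the import path changes; the trainer and config APIs stay the same:
+
+```python
+# Before (works until TRL 0.29.0):
+# from trl import PPOConfig, PPOTrainer
+
+# After:
+from trl.experimental.ppo import PPOConfig, PPOTrainer
+```
+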
[](https://huggingface.co/models?other=ppo,trl)
TRL supports training LLMs with [Proximal Policy Optimization (PPO)](https://huggingface.co/papers/1707.06347).
@@ -228,11 +234,11 @@ python -m openrlbenchmark.rlops_multi_metrics \
## PPOTrainer
-[[autodoc]] PPOTrainer
+[[autodoc]] experimental.ppo.PPOTrainer
- train
- save_model
- push_to_hub
## PPOConfig
-[[autodoc]] PPOConfig
+[[autodoc]] experimental.ppo.PPOConfig
diff --git a/docs/source/reducing_memory_usage.md b/docs/source/reducing_memory_usage.md
index 4cc8626ccf2..f92ebb29edb 100644
--- a/docs/source/reducing_memory_usage.md
+++ b/docs/source/reducing_memory_usage.md
@@ -274,7 +274,7 @@ training_args = OnlineDPOConfig(..., ds3_gather_for_generation=False)
```python
-from trl import PPOConfig
+from trl.experimental.ppo import PPOConfig
training_args = PPOConfig(..., ds3_gather_for_generation=False)
```
diff --git a/examples/scripts/ppo/ppo.py b/examples/scripts/ppo/ppo.py
index 2f5471996c2..b77f30ad457 100644
--- a/examples/scripts/ppo/ppo.py
+++ b/examples/scripts/ppo/ppo.py
@@ -34,15 +34,8 @@
HfArgumentParser,
)
-from trl import (
- ModelConfig,
- PPOConfig,
- PPOTrainer,
- ScriptArguments,
- get_kbit_device_map,
- get_peft_config,
- get_quantization_config,
-)
+from trl import ModelConfig, ScriptArguments, get_kbit_device_map, get_peft_config, get_quantization_config
+from trl.experimental.ppo import PPOConfig, PPOTrainer
# Enable logging in a Hugging Face Space
diff --git a/examples/scripts/ppo/ppo_tldr.py b/examples/scripts/ppo/ppo_tldr.py
index 7962758ec40..bf4f487823b 100644
--- a/examples/scripts/ppo/ppo_tldr.py
+++ b/examples/scripts/ppo/ppo_tldr.py
@@ -34,15 +34,8 @@
HfArgumentParser,
)
-from trl import (
- ModelConfig,
- PPOConfig,
- PPOTrainer,
- ScriptArguments,
- get_kbit_device_map,
- get_peft_config,
- get_quantization_config,
-)
+from trl import ModelConfig, ScriptArguments, get_kbit_device_map, get_peft_config, get_quantization_config
+from trl.experimental.ppo import PPOConfig, PPOTrainer
# Enable logging in a Hugging Face Space
diff --git a/tests/test_ppo_trainer.py b/tests/experimental/test_ppo_trainer.py
similarity index 97%
rename from tests/test_ppo_trainer.py
rename to tests/experimental/test_ppo_trainer.py
index 78531316440..979d80e518b 100644
--- a/tests/test_ppo_trainer.py
+++ b/tests/experimental/test_ppo_trainer.py
@@ -17,10 +17,10 @@
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from transformers.utils import is_peft_available
-from trl import PPOConfig, PPOTrainer
-from trl.trainer.ppo_trainer import masked_mean, masked_var, masked_whiten
+from trl.experimental.ppo import PPOConfig, PPOTrainer
+from trl.experimental.ppo.ppo_trainer import masked_mean, masked_var, masked_whiten
-from .testing_utils import TrlTestCase, require_peft
+from ..testing_utils import TrlTestCase, require_peft
if is_peft_available():
diff --git a/trl/experimental/ppo/__init__.py b/trl/experimental/ppo/__init__.py
new file mode 100644
index 00000000000..6a58ea42975
--- /dev/null
+++ b/trl/experimental/ppo/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .ppo_config import PPOConfig
+from .ppo_trainer import PPOTrainer
+
+
+__all__ = ["PPOConfig", "PPOTrainer"]
diff --git a/trl/experimental/ppo/ppo_config.py b/trl/experimental/ppo/ppo_config.py
new file mode 100644
index 00000000000..0d24617cff5
--- /dev/null
+++ b/trl/experimental/ppo/ppo_config.py
@@ -0,0 +1,135 @@
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from dataclasses import dataclass, field
+from typing import Literal
+
+from ...trainer.utils import OnPolicyConfig
+
+
+@dataclass
+class PPOConfig(OnPolicyConfig):
+ r"""
+ Configuration class for the [`experimental.ppo.PPOTrainer`].
+
+ This class includes only the parameters that are specific to PPO training. For a full list of training arguments,
+ please refer to the [`~transformers.TrainingArguments`] and [`OnPolicyConfig`] documentation. Note that default
+ values in this class may differ from those in [`~transformers.TrainingArguments`].
+
+ Using [`~transformers.HfArgumentParser`] we can turn this class into
+ [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
+ command line.
+
+ Parameters:
+ exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`):
+ Name of this experiment.
+ reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
+ Path to the reward model.
+ model_adapter_name (`str`, *optional*):
+ Name of the train target PEFT adapter, when using LoRA with multiple adapters.
+ ref_adapter_name (`str`, *optional*):
+ Name of the reference PEFT adapter, when using LoRA with multiple adapters.
+ num_ppo_epochs (`int`, *optional*, defaults to `4`):
+ Number of epochs to train.
+ whiten_rewards (`bool`, *optional*, defaults to `False`):
+ Whether to whiten the rewards.
+ kl_coef (`float`, *optional*, defaults to `0.05`):
+ KL coefficient.
+ kl_estimator (`Literal["k1", "k3"]`, *optional*, defaults to `"k1"`):
+ Which estimator for KL-Divergence to use from [Approximating KL
+ Divergence](http://joschu.net/blog/kl-approx.html). Defaults to "k1", a straightforward, unbiased
+ estimator. Can be set to "k3", an unbiased estimator with lower variance which "appears to be a strictly
+ better estimator". Cannot be set to "k2", as it is used for logging purposes.
+ cliprange (`float`, *optional*, defaults to `0.2`):
+ Clip range.
+ vf_coef (`float`, *optional*, defaults to `0.1`):
+ Value function coefficient.
+ cliprange_value (`float`, *optional*, defaults to `0.2`):
+ Clip range for the value function.
+ gamma (`float`, *optional*, defaults to `1.0`):
+ Discount factor.
+ lam (`float`, *optional*, defaults to `0.95`):
+ Lambda value for GAE.
+ ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
+ This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
+ improving generation speed. However, disabling this option allows training models that exceed the VRAM
+ capacity of a single GPU, albeit at the cost of slower generation.
+ """
+
+ exp_name: str = field(
+ default=os.path.basename(__file__)[:-3],
+ metadata={"help": "Name of this experiment."},
+ )
+ reward_model_path: str = field(
+ default="EleutherAI/pythia-160m",
+ metadata={"help": "Path to the reward model."},
+ )
+ model_adapter_name: str | None = field(
+ default=None,
+ metadata={"help": "Name of the train target PEFT adapter, when using LoRA with multiple adapters."},
+ )
+ ref_adapter_name: str | None = field(
+ default=None,
+ metadata={"help": "Name of the reference PEFT adapter, when using LoRA with multiple adapters."},
+ )
+ num_ppo_epochs: int = field(
+ default=4,
+ metadata={"help": "Number of epochs to train."},
+ )
+ whiten_rewards: bool = field(
+ default=False,
+ metadata={"help": "Whether to whiten the rewards."},
+ )
+ kl_coef: float = field(
+ default=0.05,
+ metadata={"help": "KL coefficient."},
+ )
+ kl_estimator: Literal["k1", "k3"] = field(
+ default="k1",
+ metadata={
+ "help": "Which estimator for KL-Divergence to use from Approximating KL Divergence "
+ "(http://joschu.net/blog/kl-approx.html). Defaults to 'k1', a straightforward, unbiased estimator. Can be "
+ "set to 'k3', an unbiased estimator with lower variance which 'appears to be a strictly better "
+ "estimator'. Cannot be set to 'k2', as it is used for logging purposes."
+ },
+ )
+ cliprange: float = field(
+ default=0.2,
+ metadata={"help": "Clip range."},
+ )
+ vf_coef: float = field(
+ default=0.1,
+ metadata={"help": "Value function coefficient."},
+ )
+ cliprange_value: float = field(
+ default=0.2,
+ metadata={"help": "Clip range for the value function."},
+ )
+ gamma: float = field(
+ default=1.0,
+ metadata={"help": "Discount factor."},
+ )
+ lam: float = field(
+ default=0.95,
+ metadata={"help": "Lambda value for GAE."},
+ )
+ ds3_gather_for_generation: bool = field(
+ default=True,
+ metadata={
+ "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
+ "generation, improving generation speed. However, disabling this option allows training models that "
+ "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation."
+ },
+ )
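
For orientation, a hedged sketch of how this config might be constructed (not part of the patch; the values are illustrative, `output_dir` is a placeholder, and `total_episodes`/batch-size arguments come from `OnPolicyConfig`/`TrainingArguments`):

```python
from trl.experimental.ppo import PPOConfig

training_args = PPOConfig(
    output_dir="ppo-pythia-160m",              # TrainingArguments field (placeholder path)
    reward_model_path="EleutherAI/pythia-160m",
    total_episodes=10_000,                     # OnPolicyConfig: stop after this many episodes
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_ppo_epochs=4,                          # PPO epochs per rollout batch
    kl_coef=0.05,
    kl_estimator="k1",                         # or "k3" for the lower-variance estimator
    cliprange=0.2,
    vf_coef=0.1,
    gamma=1.0,
    lam=0.95,
)
```
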
diff --git a/trl/experimental/ppo/ppo_trainer.py b/trl/experimental/ppo/ppo_trainer.py
new file mode 100644
index 00000000000..b11a245582f
--- /dev/null
+++ b/trl/experimental/ppo/ppo_trainer.py
@@ -0,0 +1,836 @@
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import math
+import os
+import textwrap
+import time
+from collections import defaultdict
+from contextlib import contextmanager, nullcontext
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn as nn
+from accelerate import Accelerator, logging
+from accelerate.utils import broadcast, gather_object
+from datasets import Dataset
+from torch.utils.data import DataLoader
+from transformers import (
+ BaseImageProcessor,
+ DataCollatorWithPadding,
+ FeatureExtractionMixin,
+ GenerationConfig,
+ PreTrainedTokenizerBase,
+ ProcessorMixin,
+ TrainerCallback,
+ TrainerControl,
+)
+from transformers.integrations import get_reporting_integration_callbacks
+from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
+from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback
+from transformers.utils import is_peft_available, is_rich_available
+
+from ...models import create_reference_model
+from ...models.utils import unwrap_model_for_generation
+from ...trainer.base_trainer import BaseTrainer
+from ...trainer.utils import (
+ OnlineTrainerState,
+ batch_generation,
+ disable_dropout_in_model,
+ empty_cache,
+ exact_div,
+ first_true_indices,
+ forward,
+ get_reward,
+ log_table_to_comet_experiment,
+ peft_module_casting_to_bf16,
+ prepare_deepspeed,
+ print_rich_table,
+ selective_log_softmax,
+ truncate_response,
+)
+from .ppo_config import PPOConfig
+
+
+logger = logging.get_logger(__name__)
+
+if is_peft_available():
+ from peft import PeftConfig, PeftModel, get_peft_model
+
+
+INVALID_LOGPROB = 1.0
+
+
+def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: int | None = None) -> torch.Tensor:
+    """Compute the mean of a tensor over the positions selected by the mask."""
+ if axis is not None:
+ return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
+ else:
+ return (values * mask).sum() / mask.sum()
+
+
+def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor:
+ """Compute variance of tensor with masked values."""
+ mean = masked_mean(values, mask)
+ centered_values = values - mean
+ variance = masked_mean(centered_values**2, mask)
+ if unbiased:
+ mask_sum = mask.sum()
+ if mask_sum == 0:
+ raise ValueError(
+ "The sum of the mask is zero, which can happen when `mini_batch_size=1`;"
+ "try increase the `mini_batch_size` or `gradient_accumulation_steps`"
+ )
+ # note that if mask_sum == 1, then there is a division by zero issue
+ # to avoid it you just need to use a larger minibatch_size
+ bessel_correction = mask_sum / (mask_sum - 1)
+ variance = variance * bessel_correction
+ return variance
+
+
+def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
+ """Whiten values with masked values."""
+ mean, var = masked_mean(values, mask), masked_var(values, mask)
+ whitened = (values - mean) * torch.rsqrt(var + 1e-8)
+ if not shift_mean:
+ whitened += mean
+ return whitened
+
+
+# taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29
+# we do this so we can prepare everything with a single `model = accelerator.prepare(model)`
+class PolicyAndValueWrapper(nn.Module):
+ def __init__(self, policy, value_model) -> None:
+ super().__init__()
+ self.policy = policy
+ self.value_model = value_model
+ self.critic_backbone = getattr(value_model, value_model.base_model_prefix)
+ self.is_gradient_checkpointing = policy.is_gradient_checkpointing
+
+ def forward(self, **kwargs):
+ output = self.critic_backbone(**kwargs)
+ logits = self.value_model.score(output.hidden_states[-1])
+ return self.policy(**kwargs), logits
+
+
+class PPOTrainer(BaseTrainer):
+ """Trainer for Proximal Policy Optimization (PPO).
+
+ For details on PPO, see the paper: [Proximal Policy Optimization
+ Algorithms](https://huggingface.co/papers/1707.06347).
+
+ Args:
+ args ([`experimental.ppo.PPOConfig`]):
+ Training arguments.
+ processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`]):
+ Class to process the data.
+ model (`torch.nn.Module`):
+ Model to be trained. This is the policy model.
+ ref_model (`torch.nn.Module`, *optional*):
+ Reference model used to compute the KL divergence. If `None`, a copy of the policy model is created.
+ reward_model (`torch.nn.Module`):
+ Reward model used to compute the rewards.
+ train_dataset ([`~datasets.Dataset`]):
+ Dataset for training.
+ value_model (`torch.nn.Module`):
+ Value model used to predict the value of a state.
+ data_collator ([`~transformers.DataCollatorWithPadding`], *optional*):
+ Data collator to batch and pad samples from the dataset. If `None`, a default data collator is created
+ using the `processing_class`.
+ eval_dataset ([`~datasets.Dataset`] or `dict` of [`~datasets.Dataset`], *optional*):
+ Dataset for evaluation.
+ optimizers (`tuple` of `torch.optim.Optimizer` and `torch.optim.lr_scheduler.LambdaLR`, *optional*, defaults to `(None, None)`):
+ Tuple containing the optimizer and the learning rate scheduler to use for training. If `None`, the
+ optimizer and the learning rate scheduler are created using the
+ [`~transformers.Trainer.create_optimizer_and_scheduler`] method.
+ callbacks (`list` of [`~transformers.TrainerCallback`], *optional*):
+ Callbacks to use during training.
+ peft_config ([`~peft.PeftConfig`], *optional*):
+ PEFT configuration to use PEFT for training. If `None`, PEFT is not used. If provided, the policy `model`
+ will be wrapped with the specified PEFT adapter.
+ """
+
+ _tag_names = ["trl", "ppo"]
+ _name = "PPO"
+ _paper = {
+ "title": "Fine-Tuning Language Models from Human Preferences",
+ "id": "1909.08593",
+ # docstyle-ignore
+ "citation": textwrap.dedent("""\
+ @article{mziegler2019fine-tuning,
+ title = {{Fine-Tuning Language Models from Human Preferences}},
+ author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving},
+ year = 2019,
+ eprint = {arXiv:1909.08593}
+ }"""),
+ }
+
+ def __init__(
+ self,
+ args: PPOConfig,
+ processing_class: PreTrainedTokenizerBase | BaseImageProcessor | FeatureExtractionMixin | ProcessorMixin,
+ model: nn.Module,
+ ref_model: nn.Module | None,
+ reward_model: nn.Module,
+ train_dataset: Dataset,
+ value_model: nn.Module,
+ data_collator: DataCollatorWithPadding | None = None,
+ eval_dataset: Dataset | dict[str, Dataset] | None = None,
+ # less commonly used
+ optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
+ callbacks: list[TrainerCallback] | None = None,
+ peft_config: "PeftConfig | None" = None,
+ ) -> None:
+ if ref_model is model:
+ raise ValueError(
+ "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
+ "same as `model`, you must make a copy of it, or `None` if you use peft."
+ )
+
+ self.args = args
+ self.processing_class = processing_class
+ self.policy_model = model
+
+ # Define the collator if not provided
+ if data_collator is None:
+ data_collator = DataCollatorWithPadding(self.processing_class)
+
+ # Handle stop token settings: update policy model's generation_config to use provided stop token
+ if args.stop_token and args.stop_token_id:
+ raise ValueError("You cannot set both `stop_token` and `stop_token_id`.")
+ elif args.stop_token:
+ if args.stop_token == "eos":
+ self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id
+ else:
+ raise ValueError(
+ f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)."
+ )
+ else:
+ self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id # None or int
+
+ # Check that the kl estimator is valid
+ if self.args.kl_estimator not in {"k1", "k3"}:
+ raise ValueError(
+ "kl_estimator must be either 'k1' (straightforward, unbiased) or 'k3' (lower variance, unbiased, "
+ "appears to be a strictly better estimator). See "
+ "[Approximating KL Divergence](http://joschu.net/blog/kl-approx.html) for details."
+ )
+
+ # peft support
+ if not is_peft_available() and peft_config is not None:
+ raise ImportError(
+ "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
+ )
+ elif is_peft_available() and peft_config is not None:
+            # if model is a peft model and we have a peft_config, we merge and unload it first
+ if isinstance(self.policy_model, PeftModel):
+ self.policy_model = self.policy_model.merge_and_unload()
+
+ # get peft model with the given config
+ self.policy_model = get_peft_model(self.policy_model, peft_config)
+ if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False):
+ peft_module_casting_to_bf16(self.policy_model)
+
+ self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel)
+ self.model_adapter_name = args.model_adapter_name
+ self.ref_adapter_name = args.ref_adapter_name
+
+ if ref_model:
+ self.ref_model = ref_model
+ elif self.is_peft_model:
+ self.ref_model = None
+ else:
+ self.ref_model = create_reference_model(self.policy_model)
+
+ self.reward_model = reward_model
+ self.train_dataset = train_dataset
+ self.train_dataset_len = len(train_dataset)
+ self.value_model = value_model
+ self.data_collator = data_collator
+ self.eval_dataset = eval_dataset
+ self.optimizer, self.lr_scheduler = optimizers
+ self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47
+
+ #########
+ # calculate various batch sizes
+ #########
+ if args.total_episodes is None: # allow the users to define episodes in terms of epochs.
+ args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
+ accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
+ self.accelerator = accelerator
+ args.world_size = accelerator.num_processes
+ args.local_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps
+ args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
+ args.batch_size = int(args.local_batch_size * args.world_size)
+ args.mini_batch_size = exact_div(
+ args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
+ )
+ args.local_mini_batch_size = exact_div(
+ args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
+ )
+ if args.whiten_rewards:
+ assert args.local_mini_batch_size >= 8, (
+ f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening"
+ )
+ # `per_rank_rollout_batch_size` is our `args.local_batch_size`
+ # `per_rank_minibatch_size` is our `args.local_mini_batch_size`
+ args.num_total_batches = math.ceil(
+ args.total_episodes / args.batch_size
+ ) # we may train for more than `total_episodes`
+ time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
+ time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes
+ args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
+ self.local_seed = args.seed + accelerator.process_index * 100003 # Prime
+ if args.num_sample_generations > 0:
+ self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
+ self.local_dataloader_batch_size = args.local_batch_size
+
+ #########
+ # setup model, optimizer, and others
+ #########
+ for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
+ if module is not None:
+ disable_dropout_in_model(module)
+ self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
+ self.model.config = self.policy_model.config # needed for pushing to hub
+ self.create_optimizer_and_scheduler(
+ num_training_steps=args.num_total_batches
+ ) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level
+
+ #########
+ # trainer specifics
+ #########
+ default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
+ self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
+ self.callback_handler = CallbackHandler(
+ self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
+ )
+ self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
+ self.control = TrainerControl()
+ self.state = OnlineTrainerState(
+ is_local_process_zero=self.is_local_process_zero(),
+ is_world_process_zero=self.is_world_process_zero(),
+ stateful_callbacks=[
+ cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
+ ],
+ )
+ self.current_flos = 0
+ self.hp_search_backend = None
+ self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
+ self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
+ # Create distant repo and output directory if needed
+ self.hub_model_id = None
+ if self.args.push_to_hub:
+ self.init_hf_repo()
+ if self.args.should_save:
+ os.makedirs(self.args.output_dir, exist_ok=True)
+
+ # Add tags for models that have been loaded with the correct transformers version
+ if hasattr(self.model, "add_model_tags"):
+ self.model.add_model_tags(self._tag_names)
+
+ #########
+ # setup dataloader
+ #########
+ self.dataloader = DataLoader(
+ self.train_dataset,
+ batch_size=self.local_dataloader_batch_size,
+ shuffle=True,
+ collate_fn=self.data_collator,
+ drop_last=True, # needed; otherwise the last batch will be of ragged shape
+ )
+ # sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
+ # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
+ torch.manual_seed(args.seed)
+ self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
+ torch.manual_seed(self.local_seed) # reset the local seed again
+
+ self.eval_dataloader = DataLoader(
+ self.eval_dataset,
+ batch_size=args.per_device_eval_batch_size,
+ collate_fn=self.data_collator,
+ drop_last=True,
+ ) # no need to shuffle eval dataset
+ self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
+
+ if self.is_deepspeed_enabled:
+ self.reward_model = prepare_deepspeed(
+ self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
+ )
+
+ if self.ref_model is None:
+ if not self.is_peft_model:
+ raise ValueError("No reference model and model is not a Peft model.")
+ else:
+ self.ref_model = prepare_deepspeed(
+ self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
+ )
+ else:
+ if self.ref_model is None:
+ if not self.is_peft_model:
+ raise ValueError("No reference model and model is not a Peft model.")
+ else:
+ self.ref_model = self.ref_model.to(self.accelerator.device)
+ self.reward_model = self.reward_model.to(self.accelerator.device)
+
+ def get_train_dataloader(self) -> DataLoader:
+ return self.dataloader
+
+ def get_eval_dataloader(self) -> DataLoader:
+ return self.eval_dataloader
+
+ @contextmanager
+ def null_ref_context(self):
+ """Context manager for handling null reference model (that is, peft adapter manipulation)."""
+ with (
+ self.accelerator.unwrap_model(self.model.policy).disable_adapter()
+ if self.is_peft_model and not self.ref_adapter_name
+ else nullcontext()
+ ):
+ if self.ref_adapter_name:
+ self.model.policy.set_adapter(self.ref_adapter_name)
+ yield
+ if self.ref_adapter_name:
+ self.model.policy.set_adapter(self.model_adapter_name or "default")
+
+ def save_model(self, output_dir: str | None = None, _internal_call: bool = False):
+ backup_model = self.model
+ self.model = self.model.policy # save only the policy
+
+ if self.is_deepspeed_enabled:
+ backup_deepspeed = self.deepspeed
+ self.deepspeed = self.model
+
+ super().save_model(output_dir, _internal_call)
+
+ self.model = backup_model
+
+ if self.is_deepspeed_enabled:
+ self.deepspeed = backup_deepspeed
+
+ def train(self):
+ args = self.args
+ accelerator = self.accelerator
+ optimizer = self.optimizer
+ model = self.model
+ ref_policy = self.ref_model
+ reward_model = self.reward_model
+ processing_class = self.processing_class
+ dataloader = self.dataloader
+ device = accelerator.device
+
+ def repeat_generator():
+ while True:
+ yield from dataloader
+
+ iter_dataloader = iter(repeat_generator())
+ generation_config = GenerationConfig(
+ max_new_tokens=args.response_length,
+ temperature=(args.temperature + 1e-7),
+ top_k=0.0,
+ top_p=1.0,
+ do_sample=True,
+ )
+
+ accelerator.print("===training policy===")
+ start_time = time.time()
+ stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
+ approxkl_stats = torch.zeros(stats_shape, device=device)
+ pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
+ pg_loss_stats = torch.zeros(stats_shape, device=device)
+ vf_loss_stats = torch.zeros(stats_shape, device=device)
+ vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
+ entropy_stats = torch.zeros(stats_shape, device=device)
+ ratio_stats = torch.zeros(stats_shape, device=device)
+ model.train()
+
+ # trainer state initialization
+ self.state.global_step = 0
+ self.state.episode = 0
+ self.state.max_steps = args.num_total_batches
+ self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
+ # Compute absolute values for logging, eval, and save if given as ratio
+ if args.logging_steps is not None:
+ if args.logging_steps < 1:
+ self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
+ else:
+ self.state.logging_steps = args.logging_steps
+ if args.eval_steps is not None:
+ if args.eval_steps < 1:
+ self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
+ else:
+ self.state.eval_steps = args.eval_steps
+ if args.save_steps is not None:
+ if args.save_steps < 1:
+ self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
+ else:
+ self.state.save_steps = args.save_steps
+ self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
+
+ # backward compatibility
+ if self.is_deepspeed_enabled:
+ self.deepspeed = self.model
+ self.model_wrapped = self.model
+
+ for update in range(1, args.num_total_batches + 1):
+ self.state.episode += 1 * args.batch_size
+ data = next(iter_dataloader)
+ with torch.no_grad():
+ queries = data["input_ids"].to(device)
+ context_length = queries.shape[1]
+ responses = []
+ postprocessed_responses = []
+ logprobs = []
+ ref_logprobs = []
+ scores = []
+ sequence_lengths = []
+ values = []
+ with unwrap_model_for_generation(
+ self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
+ ) as unwrapped_model:
+ query_responses, logitss = batch_generation(
+ unwrapped_model.policy,
+ queries,
+ args.local_rollout_forward_batch_size,
+ processing_class.pad_token_id,
+ generation_config,
+ )
+
+ for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
+ query = queries[i : i + args.local_rollout_forward_batch_size]
+ query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
+ response = query_response[:, context_length:]
+ logits = logitss[i : i + args.local_rollout_forward_batch_size]
+ logprob = selective_log_softmax(logits, response)
+ del logits
+ empty_cache()
+
+ if ref_policy is None:
+ with self.null_ref_context():
+ ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
+ else:
+ ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
+ ref_logits = ref_output.logits[:, context_length - 1 : -1]
+ ref_logits /= args.temperature + 1e-7
+ ref_logprob = selective_log_softmax(ref_logits, response)
+ del ref_output, ref_logits
+ empty_cache()
+
+ # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
+ postprocessed_response = response
+ if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
+ postprocessed_response = truncate_response(
+ self.stop_token_id, processing_class.pad_token_id, response
+ )
+
+ # Response Processing 2. run reward model on the truncated responses
+ postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
+ sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
+ unwrapped_value_model = accelerator.unwrap_model(model).value_model
+ full_value, _, _ = get_reward(
+ unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
+ )
+ value = full_value[:, context_length - 1 : -1].squeeze(-1)
+ _, score, _ = get_reward(
+ reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
+ )
+
+ responses.append(response)
+ postprocessed_responses.append(postprocessed_response)
+ logprobs.append(logprob)
+ ref_logprobs.append(ref_logprob)
+ sequence_lengths.append(sequence_length)
+ scores.append(score)
+ values.append(value)
+ responses = torch.cat(responses, 0)
+ postprocessed_responses = torch.cat(postprocessed_responses, 0)
+ logprobs = torch.cat(logprobs, 0)
+ ref_logprobs = torch.cat(ref_logprobs, 0)
+ sequence_lengths = torch.cat(sequence_lengths, 0)
+ scores = torch.cat(scores, 0)
+ values = torch.cat(values, 0)
+ del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
+ empty_cache()
+ gc.collect()
+
+ # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
+ # Completions not passing that filter will receive a lower score.
+ contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1)
+ if self.args.missing_eos_penalty is not None:
+ scores[~contain_eos_token] -= self.args.missing_eos_penalty
+ # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
+
+ # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
+ response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
+ padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
+ logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
+ ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
+ sequence_lengths_p1 = sequence_lengths + 1
+ padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
+ values = torch.masked_fill(values, padding_mask_p1, 0)
+
+ # 4. compute rewards
+ # Formula used by http://joschu.net/blog/kl-approx.html for the k1 and k3 estimators
+ logr = ref_logprobs - logprobs
+ kl = -logr if args.kl_estimator == "k1" else (logr.exp() - 1) - logr # Else statement is k3
+ non_score_reward = -args.kl_coef * kl
+ rewards = non_score_reward.clone()
+ actual_start = torch.arange(rewards.size(0), device=rewards.device)
+ actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
+ rewards[actual_start, actual_end] += scores
+
+ # 5. whiten rewards
+ if args.whiten_rewards:
+ rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
+ rewards = torch.masked_fill(rewards, padding_mask_p1, 0)
+
+ # 6. compute advantages and returns
+ lastgaelam = 0
+ advantages_reversed = []
+ gen_length = responses.shape[1]
+ for t in reversed(range(gen_length)):
+ nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
+ delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
+ lastgaelam = delta + args.gamma * args.lam * lastgaelam
+ advantages_reversed.append(lastgaelam)
+ advantages = torch.stack(advantages_reversed[::-1], axis=1)
+ returns = advantages + values
+ advantages = masked_whiten(advantages, ~padding_mask)
+ advantages = torch.masked_fill(advantages, padding_mask, 0)
+ empty_cache()
+
+ # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
+ for ppo_epoch_idx in range(args.num_ppo_epochs):
+ b_inds = np.random.permutation(args.local_batch_size)
+ minibatch_idx = 0
+ for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
+ mini_batch_end = mini_batch_start + args.local_mini_batch_size
+ mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
+ gradient_accumulation_idx = 0
+ for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
+ with accelerator.accumulate(model):
+ micro_batch_end = micro_batch_start + args.per_device_train_batch_size
+ micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
+ mb_advantage = advantages[micro_batch_inds]
+ mb_responses = responses[micro_batch_inds]
+ mb_query_responses = query_responses[micro_batch_inds]
+ mb_logprobs = logprobs[micro_batch_inds]
+ mb_return = returns[micro_batch_inds]
+ mb_values = values[micro_batch_inds]
+
+ output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
+ logits = output.logits[:, context_length - 1 : -1]
+ logits /= args.temperature + 1e-7
+ new_logprobs = selective_log_softmax(logits, mb_responses)
+ new_logprobs = torch.masked_fill(
+ new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
+ )
+ vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
+ vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
+ vpredclipped = torch.clamp(
+ vpred,
+ mb_values - args.cliprange_value,
+ mb_values + args.cliprange_value,
+ )
+ vf_losses1 = torch.square(vpred - mb_return)
+ vf_losses2 = torch.square(vpredclipped - mb_return)
+ vf_loss_max = torch.max(vf_losses1, vf_losses2)
+ vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
+ vf_clipfrac = masked_mean(
+ (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
+ )
+ logprobs_diff = new_logprobs - mb_logprobs
+ ratio = torch.exp(logprobs_diff)
+ pg_losses = -mb_advantage * ratio
+ pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
+ pg_loss_max = torch.max(pg_losses, pg_losses2)
+ pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
+ loss = pg_loss + args.vf_coef * vf_loss
+ accelerator.backward(loss)
+ optimizer.step()
+ optimizer.zero_grad()
+ with torch.no_grad():
+ pg_clipfrac = masked_mean(
+ (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
+ )
+ prob_dist = torch.nn.functional.softmax(logits, dim=-1)
+ entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
+ approxkl = 0.5 * (logprobs_diff**2).mean()
+ approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
+ pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
+ pg_clipfrac
+ )
+ pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
+ vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
+ vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
+ vf_clipfrac
+ )
+ entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
+ ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
+ gradient_accumulation_idx += 1
+ minibatch_idx += 1
+ # del everything and empty cache
+ # fmt: off
+ del (
+ output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
+ vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
+ pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
+ mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
+ )
+ # fmt: on
+ empty_cache()
+ with torch.no_grad():
+ mean_kl = kl.sum(1).mean()
+ mean_entropy = (-logprobs).sum(1).mean()
+ mean_non_score_reward = non_score_reward.sum(1).mean()
+ rlhf_reward = mean_non_score_reward + scores.mean()
+ eps = int(self.state.episode / (time.time() - start_time))
+ metrics = {}
+ metrics["eps"] = eps
+ metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
+ metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
+ metrics["objective/non_score_reward"] = (
+ self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
+ )
+ metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
+ metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
+ metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
+ metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
+ metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
+ metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
+ metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
+ metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
+ metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
+ metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
+ metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
+ metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
+ metrics["episode"] = self.state.episode
+ self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log
+ self.state.global_step += 1
+ self.log(metrics)
+
+ self.lr_scheduler.step()
+ self.control = self.callback_handler.on_step_end(args, self.state, self.control)
+ if self.control.should_save:
+ self._save_checkpoint(model, trial=None)
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
+ del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
+ empty_cache()
+ gc.collect()
+
+ if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
+ self.generate_completions(sampling=True)
+ empty_cache()
+ del (
+ query_responses,
+ responses,
+ postprocessed_responses,
+ logprobs,
+ ref_logprobs,
+ values,
+ sequence_lengths,
+ contain_eos_token,
+ sequence_lengths_p1,
+ response_idxs,
+ padding_mask,
+ padding_mask_p1,
+ rewards,
+ actual_start,
+ actual_end,
+ advantages,
+ returns,
+ )
+ empty_cache()
+
+ # HF trainer specifics
+ self.control = self.callback_handler.on_train_end(args, self.state, self.control)
+ if self.control.should_save:
+ self._save_checkpoint(model, trial=None)
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
+
+ def generate_completions(self, sampling: bool = False):
+ args = self.args
+ processing_class = self.processing_class
+ generation_config = GenerationConfig(
+ max_new_tokens=self.args.response_length,
+ temperature=(0.01 + 1e-7),
+ top_k=0.0,
+ top_p=1.0,
+ do_sample=True,
+ )
+
+ table = defaultdict(list)
+ with unwrap_model_for_generation(
+ self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
+ ) as unwrapped_model:
+ for batch in self.eval_dataloader:
+ query = batch["input_ids"]
+ with torch.no_grad():
+ context_length = query.shape[1]
+ query_response, _ = batch_generation(
+ unwrapped_model.policy,
+ query,
+ query.shape[0],
+ processing_class.pad_token_id,
+ generation_config,
+ )
+ response = query_response[:, context_length:]
+ postprocessed_response = response
+ if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
+ postprocessed_response = truncate_response(
+ self.stop_token_id, processing_class.pad_token_id, response
+ )
+ table["query"].extend(
+ gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
+ )
+ table["model response"].extend(
+ gather_object(processing_class.batch_decode(postprocessed_response))
+ )
+
+ postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
+ _, score, _ = get_reward(
+ self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
+ )
+ table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())
+
+ if sampling:
+ break
+ df = pd.DataFrame(table)
+
+ if self.accelerator.is_main_process:
+ if is_rich_available():
+ print_rich_table(df.iloc[0 : 0 + 5])
+ if "wandb" in args.report_to:
+ import wandb
+
+ if wandb.run is not None:
+ wandb.log({"completions": wandb.Table(dataframe=df)})
+
+ if "comet_ml" in args.report_to:
+ log_table_to_comet_experiment(
+ name="completions.csv",
+ table=df,
+ )
+
+ # Ensure the model card is saved along with the checkpoint
+ def _save_checkpoint(self, model, trial):
+ if self.args.hub_model_id is None:
+ model_name = Path(self.args.output_dir).name
+ else:
+ model_name = self.args.hub_model_id.split("/")[-1]
+ self.create_model_card(model_name=model_name)
+ super()._save_checkpoint(model, trial)
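
To make the constructor signature above concrete, here is a hedged end-to-end sketch (not part of the patch): the checkpoint, toy dataset, and hyperparameters are placeholders, and the reward/value models are plain `AutoModelForSequenceClassification` heads with `num_labels=1`, mirroring the example scripts. An `eval_dataset` is passed because the trainer always builds an eval dataloader.

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer

from trl.experimental.ppo import PPOConfig, PPOTrainer

model_id = "EleutherAI/pythia-160m"  # placeholder checkpoint

# Left padding so generation continues from the end of the prompt.
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

policy = AutoModelForCausalLM.from_pretrained(model_id)
ref_policy = AutoModelForCausalLM.from_pretrained(model_id)
reward_model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1)
value_model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1)

# Prompt-only dataset tokenized into "input_ids", as read by the trainer's dataloader.
dataset = Dataset.from_dict({"prompt": 8 * ["The weather today is", "Once upon a time"]})
dataset = dataset.map(
    lambda row: {"input_ids": tokenizer(row["prompt"])["input_ids"]},
    remove_columns=["prompt"],
)

training_args = PPOConfig(output_dir="ppo-example", total_episodes=64, num_sample_generations=0)
trainer = PPOTrainer(
    args=training_args,
    processing_class=tokenizer,
    model=policy,
    ref_model=ref_policy,
    reward_model=reward_model,
    train_dataset=dataset,
    value_model=value_model,
    eval_dataset=dataset,
)
trainer.train()
```
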
diff --git a/trl/trainer/ppo_config.py b/trl/trainer/ppo_config.py
index 40d48b82dbf..e38cd190fe8 100644
--- a/trl/trainer/ppo_config.py
+++ b/trl/trainer/ppo_config.py
@@ -12,124 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
-from dataclasses import dataclass, field
-from typing import Literal
+import warnings
+from dataclasses import dataclass
-from ..trainer.utils import OnPolicyConfig
+from ..experimental.ppo import PPOConfig as _PPOConfig
@dataclass
-class PPOConfig(OnPolicyConfig):
- r"""
- Configuration class for the [`PPOTrainer`].
-
- This class includes only the parameters that are specific to PPO training. For a full list of training arguments,
- please refer to the [`~transformers.TrainingArguments`] and [`OnPolicyConfig`] documentation. Note that default
- values in this class may differ from those in [`~transformers.TrainingArguments`].
-
- Using [`~transformers.HfArgumentParser`] we can turn this class into
- [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
- command line.
-
- Parameters:
- exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`):
- Name of this experiment.
- reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
- Path to the reward model.
- model_adapter_name (`str`, *optional*):
- Name of the train target PEFT adapter, when using LoRA with multiple adapters.
- ref_adapter_name (`str`, *optional*):
- Name of the reference PEFT adapter, when using LoRA with multiple adapters.
- num_ppo_epochs (`int`, *optional*, defaults to `4`):
- Number of epochs to train.
- whiten_rewards (`bool`, *optional*, defaults to `False`):
- Whether to whiten the rewards.
- kl_coef (`float`, *optional*, defaults to `0.05`):
- KL coefficient.
- kl_estimator (`Literal["k1", "k3"]`, *optional*, defaults to `"k1"`):
- Which estimator for KL-Divergence to use from [Approximating KL
- Divergence](http://joschu.net/blog/kl-approx.html). Defaults to "k1", a straightforward, unbiased
- estimator. Can be set to "k3", an unbiased estimator with lower variance which "appears to be a strictly
- better estimator". Cannot be set to "k2", as it is used for logging purposes.
- cliprange (`float`, *optional*, defaults to `0.2`):
- Clip range.
- vf_coef (`float`, *optional*, defaults to `0.1`):
- Value function coefficient.
- cliprange_value (`float`, *optional*, defaults to `0.2`):
- Clip range for the value function.
- gamma (`float`, *optional*, defaults to `1.0`):
- Discount factor.
- lam (`float`, *optional*, defaults to `0.95`):
- Lambda value for GAE.
- ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
- This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
- improving generation speed. However, disabling this option allows training models that exceed the VRAM
- capacity of a single GPU, albeit at the cost of slower generation.
- """
-
- exp_name: str = field(
- default=os.path.basename(__file__)[:-3],
- metadata={"help": "Name of this experiment."},
- )
- reward_model_path: str = field(
- default="EleutherAI/pythia-160m",
- metadata={"help": "Path to the reward model."},
- )
- model_adapter_name: str | None = field(
- default=None,
- metadata={"help": "Name of the train target PEFT adapter, when using LoRA with multiple adapters."},
- )
- ref_adapter_name: str | None = field(
- default=None,
- metadata={"help": "Name of the reference PEFT adapter, when using LoRA with multiple adapters."},
- )
- num_ppo_epochs: int = field(
- default=4,
- metadata={"help": "Number of epochs to train."},
- )
- whiten_rewards: bool = field(
- default=False,
- metadata={"help": "Whether to whiten the rewards."},
- )
- kl_coef: float = field(
- default=0.05,
- metadata={"help": "KL coefficient."},
- )
- kl_estimator: Literal["k1", "k3"] = field(
- default="k1",
- metadata={
- "help": "Which estimator for KL-Divergence to use from Approximating KL Divergence "
- "(http://joschu.net/blog/kl-approx.html). Defaults to 'k1', a straightforward, unbiased estimator. Can be "
- "set to 'k3', an unbiased estimator with lower variance which 'appears to be a strictly better "
- "estimator'. Cannot be set to 'k2', as it is used for logging purposes."
- },
- )
- cliprange: float = field(
- default=0.2,
- metadata={"help": "Clip range."},
- )
- vf_coef: float = field(
- default=0.1,
- metadata={"help": "Value function coefficient."},
- )
- cliprange_value: float = field(
- default=0.2,
- metadata={"help": "Clip range for the value function."},
- )
- gamma: float = field(
- default=1.0,
- metadata={"help": "Discount factor."},
- )
- lam: float = field(
- default=0.95,
- metadata={"help": "Lambda value for GAE."},
- )
- ds3_gather_for_generation: bool = field(
- default=True,
- metadata={
- "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
- "generation, improving generation speed. However, disabling this option allows training models that "
- "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation."
- },
- )
+class PPOConfig(_PPOConfig):
+ def __post_init__(self):
+ warnings.warn(
+ "The `PPOConfig` is now located in `trl.experimental`. Please update your imports to "
+ "`from trl.experimental.ppo import PPOConfig`. The current import path will be removed and no longer "
+ "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223."
+ )
+ super().__post_init__()
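
As a small illustration of how this shim behaves (a sketch, assuming the old module path stays importable as shown here): constructing the config through `trl.trainer.ppo_config` still works but emits the warning above.

```python
import warnings

from trl.trainer.ppo_config import PPOConfig  # old path, kept as a thin wrapper until TRL 0.29

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    PPOConfig(output_dir="tmp-ppo")  # __post_init__ triggers the deprecation warning

assert any("trl.experimental.ppo" in str(w.message) for w in caught)
```
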
diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index c4e5182d48f..fb2b94dfa37 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -12,833 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import gc
-import math
-import os
-import textwrap
-import time
import warnings
-from collections import defaultdict
-from contextlib import contextmanager, nullcontext
-from pathlib import Path
-import numpy as np
-import pandas as pd
-import torch
-import torch.nn as nn
-from accelerate import Accelerator, logging
-from accelerate.utils import broadcast, gather_object
-from datasets import Dataset
-from torch.utils.data import DataLoader
-from transformers import (
- BaseImageProcessor,
- DataCollatorWithPadding,
- FeatureExtractionMixin,
- GenerationConfig,
- PreTrainedTokenizerBase,
- ProcessorMixin,
- TrainerCallback,
- TrainerControl,
-)
-from transformers.integrations import get_reporting_integration_callbacks
-from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
-from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback
-from transformers.utils import is_peft_available, is_rich_available
+from ..experimental.ppo import PPOTrainer as _PPOTrainer
-from ..models import create_reference_model
-from ..models.utils import unwrap_model_for_generation
-from .base_trainer import BaseTrainer
-from .ppo_config import PPOConfig
-from .utils import (
- OnlineTrainerState,
- batch_generation,
- disable_dropout_in_model,
- empty_cache,
- exact_div,
- first_true_indices,
- forward,
- get_reward,
- log_table_to_comet_experiment,
- peft_module_casting_to_bf16,
- prepare_deepspeed,
- print_rich_table,
- selective_log_softmax,
- truncate_response,
-)
-
-logger = logging.get_logger(__name__)
-
-if is_peft_available():
- from peft import PeftConfig, PeftModel, get_peft_model
-
-
-INVALID_LOGPROB = 1.0
-
-
-def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: bool | None = None) -> torch.Tensor:
- """Compute mean of tensor with a masked values."""
- if axis is not None:
- return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
- else:
- return (values * mask).sum() / mask.sum()
-
-
-def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor:
- """Compute variance of tensor with masked values."""
- mean = masked_mean(values, mask)
- centered_values = values - mean
- variance = masked_mean(centered_values**2, mask)
- if unbiased:
- mask_sum = mask.sum()
- if mask_sum == 0:
- raise ValueError(
- "The sum of the mask is zero, which can happen when `mini_batch_size=1`;"
- "try increase the `mini_batch_size` or `gradient_accumulation_steps`"
- )
- # note that if mask_sum == 1, then there is a division by zero issue
- # to avoid it you just need to use a larger minibatch_size
- bessel_correction = mask_sum / (mask_sum - 1)
- variance = variance * bessel_correction
- return variance
-
-
-def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
- """Whiten values with masked values."""
- mean, var = masked_mean(values, mask), masked_var(values, mask)
- whitened = (values - mean) * torch.rsqrt(var + 1e-8)
- if not shift_mean:
- whitened += mean
- return whitened
-
-
-# taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29
-# we did this we can do a single `model = accelerator.prepare(model)`
-class PolicyAndValueWrapper(nn.Module):
- def __init__(self, policy, value_model) -> None:
- super().__init__()
- self.policy = policy
- self.value_model = value_model
- self.critic_backbone = getattr(value_model, value_model.base_model_prefix)
- self.is_gradient_checkpointing = policy.is_gradient_checkpointing
-
- def forward(self, **kwargs):
- output = self.critic_backbone(**kwargs)
- logits = self.value_model.score(output.hidden_states[-1])
- return self.policy(**kwargs), logits
-
-
-class PPOTrainer(BaseTrainer):
- """Trainer for Proximal Policy Optimization (PPO).
-
- For details on PPO, see the paper: [Proximal Policy Optimization
- Algorithms](https://huggingface.co/papers/1707.06347).
-
- Args:
- args ([`PPOConfig`]):
- Training arguments.
- processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`]):
- Class to process the data.
- model (`torch.nn.Module`):
- Model to be trained. This is the policy model.
- ref_model (`torch.nn.Module`, *optional*):
- Reference model used to compute the KL divergence. If `None`, a copy of the policy model is created.
- reward_model (`torch.nn.Module`):
- Reward model used to compute the rewards.
- train_dataset ([`~datasets.Dataset`]):
- Dataset for training.
- value_model (`torch.nn.Module`):
- Value model used to predict the value of a state.
- data_collator ([`~transformers.DataCollatorWithPadding`], *optional*):
- Data collator to batch and pad samples from the dataset. If `None`, a default data collator is created
- using the `processing_class`.
- eval_dataset ([`~datasets.Dataset`] or `dict` of [`~datasets.Dataset`], *optional*):
- Dataset for evaluation.
- optimizers (`tuple` of `torch.optim.Optimizer` and `torch.optim.lr_scheduler.LambdaLR`, *optional*, defaults to `(None, None)`):
- Tuple containing the optimizer and the learning rate scheduler to use for training. If `None`, the
- optimizer and the learning rate scheduler are created using the
- [`~transformers.Trainer.create_optimizer_and_scheduler`] method.
- callbacks (`list` of [`~transformers.TrainerCallback`], *optional*):
- Callbacks to use during training.
- peft_config ([`~peft.PeftConfig`], *optional*):
- PEFT configuration to use PEFT for training. If `None`, PEFT is not used. If provided, the policy `model`
- will be wrapped with the specified PEFT adapter.
- """
-
- _tag_names = ["trl", "ppo"]
- _name = "PPO"
- _paper = {
- "title": "Fine-Tuning Language Models from Human Preferences",
- "id": "1909.08593",
- # docstyle-ignore
- "citation": textwrap.dedent("""\
- @article{mziegler2019fine-tuning,
- title = {{Fine-Tuning Language Models from Human Preferences}},
- author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving},
- year = 2019,
- eprint = {arXiv:1909.08593}
- }"""),
- }
-
- def __init__(
- self,
- args: PPOConfig,
- processing_class: PreTrainedTokenizerBase | BaseImageProcessor | FeatureExtractionMixin | ProcessorMixin,
- model: nn.Module,
- ref_model: nn.Module | None,
- reward_model: nn.Module,
- train_dataset: Dataset,
- value_model: nn.Module,
- data_collator: DataCollatorWithPadding | None = None,
- eval_dataset: Dataset | dict[str, Dataset] | None = None,
- # less commonly used
- optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
- callbacks: list[TrainerCallback] | None = None,
- peft_config: "PeftConfig | None" = None,
- ) -> None:
- if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
- warnings.warn(
- "This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
- "it and want it to remain, please share your comments here: "
- "https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
- "TRL_EXPERIMENTAL_SILENCE=1."
- )
- if ref_model is model:
- raise ValueError(
- "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
- "same as `model`, you must make a copy of it, or `None` if you use peft."
- )
-
- self.args = args
- self.processing_class = processing_class
- self.policy_model = model
-
- # Define the collator if not provided
- if data_collator is None:
- data_collator = DataCollatorWithPadding(self.processing_class)
-
- # Handle stop token settings: update policy model's generation_config to use provided stop token
- if args.stop_token and args.stop_token_id:
- raise ValueError("You cannot set both `stop_token` and `stop_token_id`.")
- elif args.stop_token:
- if args.stop_token == "eos":
- self.policy_model.generation_config.eos_token_id = self.stop_token_id = processing_class.eos_token_id
- else:
- raise ValueError(
- f"Unknown `stop_token` {args.stop_token}. Allowed values are: `'eos'` and `None` (no stop token)."
- )
- else:
- self.policy_model.generation_config.eos_token_id = self.stop_token_id = args.stop_token_id # None or int
-
- # Check that the kl estimator is valid
- if self.args.kl_estimator not in {"k1", "k3"}:
- raise ValueError(
- "kl_estimator must be either 'k1' (straightforward, unbiased) or 'k3' (lower variance, unbiased, "
- "appears to be a strictly better estimator). See "
- "[Approximating KL Divergence](http://joschu.net/blog/kl-approx.html) for details."
- )
-
- # peft support
- if not is_peft_available() and peft_config is not None:
- raise ImportError(
- "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
- )
- elif is_peft_available() and peft_config is not None:
- # if model is a peft model and we have a peft_confg, we merge and unload it first
- if isinstance(self.policy_model, PeftModel):
- self.policy_model = self.policy_model.merge_and_unload()
-
- # get peft model with the given config
- self.policy_model = get_peft_model(self.policy_model, peft_config)
- if args.bf16 and getattr(self.policy_model, "is_loaded_in_4bit", False):
- peft_module_casting_to_bf16(self.policy_model)
-
- self.is_peft_model = is_peft_available() and isinstance(self.policy_model, PeftModel)
- self.model_adapter_name = args.model_adapter_name
- self.ref_adapter_name = args.ref_adapter_name
-
- if ref_model:
- self.ref_model = ref_model
- elif self.is_peft_model:
- self.ref_model = None
- else:
- self.ref_model = create_reference_model(self.policy_model)
-
- self.reward_model = reward_model
- self.train_dataset = train_dataset
- self.train_dataset_len = len(train_dataset)
- self.value_model = value_model
- self.data_collator = data_collator
- self.eval_dataset = eval_dataset
- self.optimizer, self.lr_scheduler = optimizers
- self.optimizer_cls_and_kwargs = None # needed for transformers >= 4.47
-
- #########
- # calculate various batch sizes
- #########
- if args.total_episodes is None: # allow the users to define episodes in terms of epochs.
- args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
- accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
- self.accelerator = accelerator
- args.world_size = accelerator.num_processes
- args.local_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps
- args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
- args.batch_size = int(args.local_batch_size * args.world_size)
- args.mini_batch_size = exact_div(
- args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
- )
- args.local_mini_batch_size = exact_div(
- args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
- )
- if args.whiten_rewards:
- assert args.local_mini_batch_size >= 8, (
- f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening"
- )
- # `per_rank_rollout_batch_size` is our `args.local_batch_size`
- # `per_rank_minibatch_size` is our `args.local_mini_batch_size`
- args.num_total_batches = math.ceil(
- args.total_episodes / args.batch_size
- ) # we may train for more than `total_episodes`
- time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
- time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes
- args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
- self.local_seed = args.seed + accelerator.process_index * 100003 # Prime
- if args.num_sample_generations > 0:
- self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
- self.local_dataloader_batch_size = args.local_batch_size
-
- #########
- # setup model, optimizer, and others
- #########
- for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
- if module is not None:
- disable_dropout_in_model(module)
- self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
- self.model.config = self.policy_model.config # needed for pushing to hub
- self.create_optimizer_and_scheduler(
- num_training_steps=args.num_total_batches
- ) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level
-
- #########
- # trainer specifics
- #########
- default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
- self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
- self.callback_handler = CallbackHandler(
- self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
- )
- self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
- self.control = TrainerControl()
- self.state = OnlineTrainerState(
- is_local_process_zero=self.is_local_process_zero(),
- is_world_process_zero=self.is_world_process_zero(),
- stateful_callbacks=[
- cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
- ],
- )
- self.current_flos = 0
- self.hp_search_backend = None
- self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
- self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
- # Create distant repo and output directory if needed
- self.hub_model_id = None
- if self.args.push_to_hub:
- self.init_hf_repo()
- if self.args.should_save:
- os.makedirs(self.args.output_dir, exist_ok=True)
-
- # Add tags for models that have been loaded with the correct transformers version
- if hasattr(self.model, "add_model_tags"):
- self.model.add_model_tags(self._tag_names)
-
- #########
- # setup dataloader
- #########
- self.dataloader = DataLoader(
- self.train_dataset,
- batch_size=self.local_dataloader_batch_size,
- shuffle=True,
- collate_fn=self.data_collator,
- drop_last=True, # needed; otherwise the last batch will be of ragged shape
- )
- # sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
- # see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
- torch.manual_seed(args.seed)
- self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
- torch.manual_seed(self.local_seed) # reset the local seed again
-
- self.eval_dataloader = DataLoader(
- self.eval_dataset,
- batch_size=args.per_device_eval_batch_size,
- collate_fn=self.data_collator,
- drop_last=True,
- ) # no need to shuffle eval dataset
- self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
-
- if self.is_deepspeed_enabled:
- self.reward_model = prepare_deepspeed(
- self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
- )
-
- if self.ref_model is None:
- if not self.is_peft_model:
- raise ValueError("No reference model and model is not a Peft model.")
- else:
- self.ref_model = prepare_deepspeed(
- self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
- )
- else:
- if self.ref_model is None:
- if not self.is_peft_model:
- raise ValueError("No reference model and model is not a Peft model.")
- else:
- self.ref_model = self.ref_model.to(self.accelerator.device)
- self.reward_model = self.reward_model.to(self.accelerator.device)
-
- def get_train_dataloader(self) -> DataLoader:
- return self.dataloader
-
- def get_eval_dataloader(self) -> DataLoader:
- return self.eval_dataloader
-
- @contextmanager
- def null_ref_context(self):
- """Context manager for handling null reference model (that is, peft adapter manipulation)."""
- with (
- self.accelerator.unwrap_model(self.model.policy).disable_adapter()
- if self.is_peft_model and not self.ref_adapter_name
- else nullcontext()
- ):
- if self.ref_adapter_name:
- self.model.policy.set_adapter(self.ref_adapter_name)
- yield
- if self.ref_adapter_name:
- self.model.policy.set_adapter(self.model_adapter_name or "default")
-
- def save_model(self, output_dir: str | None = None, _internal_call: bool = False):
- backup_model = self.model
- self.model = self.model.policy # save only the policy
-
- if self.is_deepspeed_enabled:
- backup_deepspeed = self.deepspeed
- self.deepspeed = self.model
-
- super().save_model(output_dir, _internal_call)
-
- self.model = backup_model
-
- if self.is_deepspeed_enabled:
- self.deepspeed = backup_deepspeed
-
- def train(self):
- args = self.args
- accelerator = self.accelerator
- optimizer = self.optimizer
- model = self.model
- ref_policy = self.ref_model
- reward_model = self.reward_model
- processing_class = self.processing_class
- dataloader = self.dataloader
- device = accelerator.device
-
- def repeat_generator():
- while True:
- yield from dataloader
-
- iter_dataloader = iter(repeat_generator())
- generation_config = GenerationConfig(
- max_new_tokens=args.response_length,
- temperature=(args.temperature + 1e-7),
- top_k=0.0,
- top_p=1.0,
- do_sample=True,
- )
-
- accelerator.print("===training policy===")
- start_time = time.time()
- stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
- approxkl_stats = torch.zeros(stats_shape, device=device)
- pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
- pg_loss_stats = torch.zeros(stats_shape, device=device)
- vf_loss_stats = torch.zeros(stats_shape, device=device)
- vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
- entropy_stats = torch.zeros(stats_shape, device=device)
- ratio_stats = torch.zeros(stats_shape, device=device)
- model.train()
-
- # trainer state initialization
- self.state.global_step = 0
- self.state.episode = 0
- self.state.max_steps = args.num_total_batches
- self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
- # Compute absolute values for logging, eval, and save if given as ratio
- if args.logging_steps is not None:
- if args.logging_steps < 1:
- self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
- else:
- self.state.logging_steps = args.logging_steps
- if args.eval_steps is not None:
- if args.eval_steps < 1:
- self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
- else:
- self.state.eval_steps = args.eval_steps
- if args.save_steps is not None:
- if args.save_steps < 1:
- self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
- else:
- self.state.save_steps = args.save_steps
- self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
-
- # backward compatibility
- if self.is_deepspeed_enabled:
- self.deepspeed = self.model
- self.model_wrapped = self.model
-
- for update in range(1, args.num_total_batches + 1):
- self.state.episode += 1 * args.batch_size
- data = next(iter_dataloader)
- with torch.no_grad():
- queries = data["input_ids"].to(device)
- context_length = queries.shape[1]
- responses = []
- postprocessed_responses = []
- logprobs = []
- ref_logprobs = []
- scores = []
- sequence_lengths = []
- values = []
- with unwrap_model_for_generation(
- self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
- ) as unwrapped_model:
- query_responses, logitss = batch_generation(
- unwrapped_model.policy,
- queries,
- args.local_rollout_forward_batch_size,
- processing_class.pad_token_id,
- generation_config,
- )
-
- for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
- query = queries[i : i + args.local_rollout_forward_batch_size]
- query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
- response = query_response[:, context_length:]
- logits = logitss[i : i + args.local_rollout_forward_batch_size]
- logprob = selective_log_softmax(logits, response)
- del logits
- empty_cache()
-
- if ref_policy is None:
- with self.null_ref_context():
- ref_output = forward(model.policy, query_response, processing_class.pad_token_id)
- else:
- ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
- ref_logits = ref_output.logits[:, context_length - 1 : -1]
- ref_logits /= args.temperature + 1e-7
- ref_logprob = selective_log_softmax(ref_logits, response)
- del ref_output, ref_logits
- empty_cache()
-
- # Response Processing 1. truncate response after the first occurrence of `stop_token_id`
- postprocessed_response = response
- if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
- postprocessed_response = truncate_response(
- self.stop_token_id, processing_class.pad_token_id, response
- )
-
- # Response Processing 2. run reward model on the truncated responses
- postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
- sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
- unwrapped_value_model = accelerator.unwrap_model(model).value_model
- full_value, _, _ = get_reward(
- unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
- )
- value = full_value[:, context_length - 1 : -1].squeeze(-1)
- _, score, _ = get_reward(
- reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
- )
-
- responses.append(response)
- postprocessed_responses.append(postprocessed_response)
- logprobs.append(logprob)
- ref_logprobs.append(ref_logprob)
- sequence_lengths.append(sequence_length)
- scores.append(score)
- values.append(value)
- responses = torch.cat(responses, 0)
- postprocessed_responses = torch.cat(postprocessed_responses, 0)
- logprobs = torch.cat(logprobs, 0)
- ref_logprobs = torch.cat(ref_logprobs, 0)
- sequence_lengths = torch.cat(sequence_lengths, 0)
- scores = torch.cat(scores, 0)
- values = torch.cat(values, 0)
- del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
- empty_cache()
- gc.collect()
-
- # Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
- # Completions not passing that filter will receive a lower score.
- contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1)
- if self.args.missing_eos_penalty is not None:
- scores[~contain_eos_token] -= self.args.missing_eos_penalty
- # accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
-
- # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
- response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
- padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
- logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
- ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
- sequence_lengths_p1 = sequence_lengths + 1
- padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
- values = torch.masked_fill(values, padding_mask_p1, 0)
-
- # 4. compute rewards
- # Formula used by http://joschu.net/blog/kl-approx.html for the k1 and k3 estimators
- logr = ref_logprobs - logprobs
- kl = -logr if args.kl_estimator == "k1" else (logr.exp() - 1) - logr # Else statement is k3
- non_score_reward = -args.kl_coef * kl
- rewards = non_score_reward.clone()
- actual_start = torch.arange(rewards.size(0), device=rewards.device)
- actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
- rewards[actual_start, actual_end] += scores
-
- # 5. whiten rewards
- if args.whiten_rewards:
- rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
- rewards = torch.masked_fill(rewards, padding_mask_p1, 0)
-
- # 6. compute advantages and returns
- lastgaelam = 0
- advantages_reversed = []
- gen_length = responses.shape[1]
- for t in reversed(range(gen_length)):
- nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
- delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
- lastgaelam = delta + args.gamma * args.lam * lastgaelam
- advantages_reversed.append(lastgaelam)
- advantages = torch.stack(advantages_reversed[::-1], axis=1)
- returns = advantages + values
- advantages = masked_whiten(advantages, ~padding_mask)
- advantages = torch.masked_fill(advantages, padding_mask, 0)
- empty_cache()
-
- # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
- for ppo_epoch_idx in range(args.num_ppo_epochs):
- b_inds = np.random.permutation(args.local_batch_size)
- minibatch_idx = 0
- for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
- mini_batch_end = mini_batch_start + args.local_mini_batch_size
- mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
- gradient_accumulation_idx = 0
- for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
- with accelerator.accumulate(model):
- micro_batch_end = micro_batch_start + args.per_device_train_batch_size
- micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
- mb_advantage = advantages[micro_batch_inds]
- mb_responses = responses[micro_batch_inds]
- mb_query_responses = query_responses[micro_batch_inds]
- mb_logprobs = logprobs[micro_batch_inds]
- mb_return = returns[micro_batch_inds]
- mb_values = values[micro_batch_inds]
-
- output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
- logits = output.logits[:, context_length - 1 : -1]
- logits /= args.temperature + 1e-7
- new_logprobs = selective_log_softmax(logits, mb_responses)
- new_logprobs = torch.masked_fill(
- new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
- )
- vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
- vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
- vpredclipped = torch.clamp(
- vpred,
- mb_values - args.cliprange_value,
- mb_values + args.cliprange_value,
- )
- vf_losses1 = torch.square(vpred - mb_return)
- vf_losses2 = torch.square(vpredclipped - mb_return)
- vf_loss_max = torch.max(vf_losses1, vf_losses2)
- vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
- vf_clipfrac = masked_mean(
- (vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
- )
- logprobs_diff = new_logprobs - mb_logprobs
- ratio = torch.exp(logprobs_diff)
- pg_losses = -mb_advantage * ratio
- pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
- pg_loss_max = torch.max(pg_losses, pg_losses2)
- pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
- loss = pg_loss + args.vf_coef * vf_loss
- accelerator.backward(loss)
- optimizer.step()
- optimizer.zero_grad()
- with torch.no_grad():
- pg_clipfrac = masked_mean(
- (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
- )
- prob_dist = torch.nn.functional.softmax(logits, dim=-1)
- entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
- approxkl = 0.5 * (logprobs_diff**2).mean()
- approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
- pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
- pg_clipfrac
- )
- pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
- vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
- vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
- vf_clipfrac
- )
- entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
- ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
- gradient_accumulation_idx += 1
- minibatch_idx += 1
- # del everything and empty cache
- # fmt: off
- del (
- output, vpred_temp, logits, new_logprobs, vpred, vpredclipped,
- vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
- pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
- mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
- )
- # fmt: on
- empty_cache()
- with torch.no_grad():
- mean_kl = kl.sum(1).mean()
- mean_entropy = (-logprobs).sum(1).mean()
- mean_non_score_reward = non_score_reward.sum(1).mean()
- rlhf_reward = mean_non_score_reward + scores.mean()
- eps = int(self.state.episode / (time.time() - start_time))
- metrics = {}
- metrics["eps"] = eps
- metrics["objective/kl"] = self.accelerator.gather_for_metrics(mean_kl).mean().item()
- metrics["objective/entropy"] = self.accelerator.gather_for_metrics(mean_entropy).mean().item()
- metrics["objective/non_score_reward"] = (
- self.accelerator.gather_for_metrics(mean_non_score_reward).mean().item()
- )
- metrics["objective/rlhf_reward"] = self.accelerator.gather_for_metrics(rlhf_reward).mean().item()
- metrics["objective/scores"] = self.accelerator.gather_for_metrics(scores.mean()).mean().item()
- metrics["policy/approxkl_avg"] = self.accelerator.gather_for_metrics(approxkl_stats).mean().item()
- metrics["policy/clipfrac_avg"] = self.accelerator.gather_for_metrics(pg_clipfrac_stats).mean().item()
- metrics["loss/policy_avg"] = self.accelerator.gather_for_metrics(pg_loss_stats).mean().item()
- metrics["loss/value_avg"] = self.accelerator.gather_for_metrics(vf_loss_stats).mean().item()
- metrics["val/clipfrac_avg"] = self.accelerator.gather_for_metrics(vf_clipfrac_stats).mean().item()
- metrics["policy/entropy_avg"] = self.accelerator.gather_for_metrics(entropy_stats).mean().item()
- metrics["val/ratio"] = self.accelerator.gather_for_metrics(ratio_stats).mean().item()
- metrics["val/ratio_var"] = self.accelerator.gather_for_metrics(ratio_stats).var().item()
- metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
- metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
- metrics["episode"] = self.state.episode
- self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log
- self.state.global_step += 1
- self.log(metrics)
-
- self.lr_scheduler.step()
- self.control = self.callback_handler.on_step_end(args, self.state, self.control)
- if self.control.should_save:
- self._save_checkpoint(model, trial=None)
- self.control = self.callback_handler.on_save(self.args, self.state, self.control)
- del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
- empty_cache()
- gc.collect()
-
- if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
- self.generate_completions(sampling=True)
- empty_cache()
- del (
- query_responses,
- responses,
- postprocessed_responses,
- logprobs,
- ref_logprobs,
- values,
- sequence_lengths,
- contain_eos_token,
- sequence_lengths_p1,
- response_idxs,
- padding_mask,
- padding_mask_p1,
- rewards,
- actual_start,
- actual_end,
- advantages,
- returns,
- )
- empty_cache()
-
- # HF trainer specifics
- self.control = self.callback_handler.on_train_end(args, self.state, self.control)
- if self.control.should_save:
- self._save_checkpoint(model, trial=None)
- self.control = self.callback_handler.on_save(self.args, self.state, self.control)
-
- def generate_completions(self, sampling: bool = False):
- args = self.args
- processing_class = self.processing_class
- generation_config = GenerationConfig(
- max_new_tokens=self.args.response_length,
- temperature=(0.01 + 1e-7),
- top_k=0.0,
- top_p=1.0,
- do_sample=True,
+class PPOTrainer(_PPOTrainer):
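+    """Backward-compatibility alias for `trl.experimental.ppo.PPOTrainer`; warns on instantiation."""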
+ def __init__(self, *args, **kwargs):
+ warnings.warn(
+ "The `PPOTrainer` is now located in `trl.experimental`. Please update your imports to "
+ "`from trl.experimental.ppo import PPOTrainer`. The current import path will be removed and no longer "
+ "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223."
)
-
- table = defaultdict(list)
- with unwrap_model_for_generation(
- self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
- ) as unwrapped_model:
- for batch in self.eval_dataloader:
- query = batch["input_ids"]
- with torch.no_grad():
- context_length = query.shape[1]
- query_response, _ = batch_generation(
- unwrapped_model.policy,
- query,
- query.shape[0],
- processing_class.pad_token_id,
- generation_config,
- )
- response = query_response[:, context_length:]
- postprocessed_response = response
- if self.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
- postprocessed_response = truncate_response(
- self.stop_token_id, processing_class.pad_token_id, response
- )
- table["query"].extend(
- gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
- )
- table["model response"].extend(
- gather_object(processing_class.batch_decode(postprocessed_response))
- )
-
- postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
- _, score, _ = get_reward(
- self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
- )
- table["score"].extend(self.accelerator.gather_for_metrics(score).float().cpu().numpy())
-
- if sampling:
- break
- df = pd.DataFrame(table)
-
- if self.accelerator.is_main_process:
- if is_rich_available():
- print_rich_table(df.iloc[0 : 0 + 5])
- if "wandb" in args.report_to:
- import wandb
-
- if wandb.run is not None:
- wandb.log({"completions": wandb.Table(dataframe=df)})
-
- if "comet_ml" in args.report_to:
- log_table_to_comet_experiment(
- name="completions.csv",
- table=df,
- )
-
- # Ensure the model card is saved along with the checkpoint
- def _save_checkpoint(self, model, trial):
- if self.args.hub_model_id is None:
- model_name = Path(self.args.output_dir).name
- else:
- model_name = self.args.hub_model_id.split("/")[-1]
- self.create_model_card(model_name=model_name)
- super()._save_checkpoint(model, trial)
+ super().__init__(*args, **kwargs)
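+
+
+# Illustrative migration sketch; `policy_model`, `reward_model`, `value_model`, `tokenizer`, `train_dataset`
+# and `training_args` are user-supplied placeholders, and the argument names follow the trainer removed
+# above, which the experimental implementation mirrors:
+#
+#     from trl.experimental.ppo import PPOTrainer  # forward-compatible import path
+#
+#     trainer = PPOTrainer(
+#         args=training_args,            # a PPOConfig instance
+#         processing_class=tokenizer,
+#         model=policy_model,
+#         ref_model=None,                # with None, a copy of the policy (or a disabled PEFT adapter) is used
+#         reward_model=reward_model,
+#         train_dataset=train_dataset,
+#         value_model=value_model,
+#     )
+#     trainer.train()
+#
+# Importing `PPOTrainer` from `trl` or `trl.trainer.ppo_trainer` keeps working until TRL 0.29 but emits the
+# warning above when the trainer is constructed.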