Merged
4 changes: 2 additions & 2 deletions docs/source/_toctree.yml
@@ -64,8 +64,6 @@
title: GRPO
- local: kto_trainer
title: KTO
- local: nash_md_trainer
title: Nash-MD
- local: orpo_trainer
title: ORPO
- local: ppo_trainer
@@ -117,6 +115,8 @@
title: Judges
- local: minillm
title: MiniLLM
- local: nash_md_trainer
title: Nash-MD
- local: papo_trainer
title: PAPO
- local: xpo_trainer
2 changes: 1 addition & 1 deletion docs/source/dataset_formats.md
@@ -393,7 +393,7 @@ Choosing the right dataset type depends on the task you are working on and the s
| [`experimental.gkd.GKDTrainer`] | [Prompt-completion](#prompt-completion) |
| [`GRPOTrainer`] | [Prompt-only](#prompt-only) |
| [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
| [`NashMDTrainer`] | [Prompt-only](#prompt-only) |
| [`experimental.nash_md.NashMDTrainer`] | [Prompt-only](#prompt-only) |
| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`PPOTrainer`] | Tokenized language modeling |
2 changes: 1 addition & 1 deletion docs/source/example_overview.md
@@ -51,7 +51,7 @@ Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl
| [`examples/scripts/gspo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gspo_vlm.py) | This script shows how to use GSPO via the [`GRPOTrainer`] to fine-tune a multimodal model for reasoning using the [lmms-lab/multimodal-open-r1-8k-verified](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset. |
| [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py) | This script shows how to use the [`KTOTrainer`] to fine-tune a model. |
| [`examples/scripts/mpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/mpo_vlm.py) | This script shows how to use MPO via the [`DPOTrainer`] to align a model based on preferences using the [HuggingFaceH4/rlaif-v_formatted](https://huggingface.co/datasets/HuggingFaceH4/rlaif-v_formatted) dataset and a set of loss weights. |
| [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py) | This script shows how to use the [`NashMDTrainer`] to fine-tune a model. |
| [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py) | This script shows how to use the [`experimental.nash_md.NashMDTrainer`] to fine-tune a model. |
| [`examples/scripts/online_dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a model. |
| [`examples/scripts/online_dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo_vlm.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a Vision Language Model. |
| [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
2 changes: 1 addition & 1 deletion docs/source/index.md
@@ -25,8 +25,8 @@ Below is the current list of TRL trainers, organized by method type (⚡️ = vL
- [`GRPOTrainer`] ⚡️
- [`RLOOTrainer`] ⚡️
- [`OnlineDPOTrainer`] ⚡️
- [`NashMDTrainer`] ⚡️
- [`PPOTrainer`]
- [`experimental.nash_md.NashMDTrainer`] 🧪 ⚡️
- [`experimental.xpo.XPOTrainer`] 🧪 ⚡️

### Reward modeling
14 changes: 7 additions & 7 deletions docs/source/nash_md_trainer.md
@@ -28,8 +28,8 @@ Below is the script to train the model:
```python
# train_nash_md.py
from datasets import load_dataset
from trl import NashMDConfig, NashMDTrainer
from trl.experimental.judges import PairRMJudge
from trl.experimental.nash_md import NashMDConfig, NashMDTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
@@ -64,7 +64,7 @@ The best programming language depends on personal preference, the complexity of

## Expected dataset type

Nash-MD requires a [prompt-only dataset](dataset_formats#prompt-only). The [`NashMDTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
Nash-MD requires a [prompt-only dataset](dataset_formats#prompt-only). The [`experimental.nash_md.NashMDTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
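
For illustration, a minimal sketch of what one training example looks like in each supported format (the example prompts below are our own, not taken from the docs):

```python
# Standard (prompt-only): the prompt is a plain string.
standard_example = {"prompt": "The sky is"}

# Conversational (prompt-only): the prompt is a list of chat messages;
# the trainer applies the tokenizer's chat template automatically.
conversational_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
```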

## Usage tips

@@ -91,7 +91,7 @@ Instead of a judge, you can choose to use a reward model -- see [Reward Bench](ht

### Encourage EOS token generation

We may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`NashMDConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`NashMDConfig`]:
We may want the model to generate completions within a given length. During training, the model will generate completions up to the maximum length specified in the `max_new_tokens` argument of [`experimental.nash_md.NashMDConfig`]. If you want to penalize the model for not generating an EOS token before reaching the maximum length, you can use the `missing_eos_penalty` argument of [`experimental.nash_md.NashMDConfig`]:

```python
training_args = NashMDConfig(..., max_new_tokens=128, missing_eos_penalty=1.0)
@@ -144,16 +144,16 @@ While training and evaluating, we record the following reward metrics:
* `logps/rejected`: The mean log probabilities of the reference completions.
* `val/model_contain_eos_token`: The number of times the model's output contains the EOS token.
* `val/ref_contain_eos_token`: The number of times the mixture's output contains the EOS token.
* `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`NashMDConfig`].
* `mixture_coef`: Logit mixture coefficient for the model and reference model. Typically fixed, but can be made dynamic by passing a list to [`NashMDConfig`].
* `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`experimental.nash_md.NashMDConfig`].
* `mixture_coef`: Logit mixture coefficient for the model and reference model. Typically fixed, but can be made dynamic by passing a list to [`experimental.nash_md.NashMDConfig`].
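
As a rough sketch of the dynamic variant described above (not part of the diff; `output_dir` and the schedule values are illustrative), both parameters accept either a single float or a per-epoch list:

```python
from trl.experimental.nash_md import NashMDConfig

# A single float keeps the coefficient fixed for the whole run; a list
# schedules one value per epoch, and the last value is reused afterwards.
training_args = NashMDConfig(
    output_dir="Qwen2-0.5B-NashMD",
    beta=[0.1, 0.05, 0.03],        # weight of the deviation-from-reference term, per epoch
    mixture_coef=[0.5, 0.4, 0.3],  # model/reference logit mixture coefficient, per epoch
)
```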

## NashMDTrainer

[[autodoc]] NashMDTrainer
[[autodoc]] experimental.nash_md.NashMDTrainer
- train
- save_model
- push_to_hub

## NashMDConfig

[[autodoc]] NashMDConfig
[[autodoc]] experimental.nash_md.NashMDConfig
10 changes: 5 additions & 5 deletions docs/source/vllm_integration.md
@@ -10,9 +10,9 @@ This document will guide you through the process of using vLLM with TRL for fast
>
> - [`GRPOTrainer`]
> - [`OnlineDPOTrainer`]
> - [`NashMDTrainer`]
> - [`experimental.xpo.XPOTrainer`]
> - [`RLOOTrainer`]
> - [`experimental.nash_md.NashMDTrainer`]
> - [`experimental.xpo.XPOTrainer`]

## 🚀 How can I use vLLM with TRL to speed up training?

@@ -105,7 +105,7 @@ trainer.train()

```python
from datasets import load_dataset
from trl import NashMDTrainer, NashMDConfig
from trl.experimental.nash_md import NashMDConfig, NashMDTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

@@ -379,7 +379,7 @@ training_args = OnlineDPOConfig(
<hfoption id="NashMD">

```python
from trl import NashMDConfig
from trl.experimental.nash_md import NashMDConfig

training_args = NashMDConfig(
...,
@@ -454,7 +454,7 @@ training_args = OnlineDPOConfig(
<hfoption id="NashMD">

```python
from trl import NashMDConfig
from trl.experimental.nash_md import NashMDConfig

training_args = NashMDConfig(
...,
3 changes: 1 addition & 2 deletions examples/scripts/nash_md.py
@@ -63,14 +63,13 @@
from trl import (
LogCompletionsCallback,
ModelConfig,
NashMDConfig,
NashMDTrainer,
ScriptArguments,
TrlParser,
get_kbit_device_map,
get_quantization_config,
)
from trl.experimental.judges import HfPairwiseJudge, OpenAIPairwiseJudge, PairRMJudge
from trl.experimental.nash_md import NashMDConfig, NashMDTrainer


# Enable logging in a Hugging Face Space
@@ -17,9 +17,9 @@
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from transformers.utils import is_peft_available

from trl import NashMDConfig, NashMDTrainer
from trl.experimental.nash_md import NashMDConfig, NashMDTrainer

from .testing_utils import RandomPairwiseJudge, TrlTestCase, require_llm_blender, require_peft
from ..testing_utils import RandomPairwiseJudge, TrlTestCase, require_llm_blender, require_peft


if is_peft_available():
23 changes: 23 additions & 0 deletions tests/experimental/test_trainers_args.py
@@ -18,6 +18,7 @@

from trl.experimental.bco import BCOConfig, BCOTrainer
from trl.experimental.cpo import CPOConfig, CPOTrainer
from trl.experimental.nash_md import NashMDConfig, NashMDTrainer
from trl.experimental.xpo import XPOConfig, XPOTrainer

from ..testing_utils import TrlTestCase, require_sklearn
@@ -113,6 +114,28 @@ def test_cpo(self):
assert trainer.args.model_init_kwargs == {"trust_remote_code": True}
assert trainer.args.dataset_num_proc == 4

@pytest.mark.parametrize("mixtures_coef_list", [False, True])
def test_nash_md(self, mixtures_coef_list):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
ref_model = AutoModelForCausalLM.from_pretrained(model_id)
reward_model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1)
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
training_args = NashMDConfig(
self.tmp_dir,
mixture_coef=0.5 if not mixtures_coef_list else [0.5, 0.6],
)
trainer = NashMDTrainer(
args=training_args,
processing_class=tokenizer,
model=model,
ref_model=ref_model,
reward_funcs=reward_model,
train_dataset=dataset,
)
assert trainer.args.mixture_coef == (0.5 if not mixtures_coef_list else [0.5, 0.6])

@pytest.mark.parametrize("alpha_list", [False, True])
def test_xpo(self, alpha_list):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
24 changes: 0 additions & 24 deletions tests/test_trainers_args.py
@@ -22,8 +22,6 @@
FDivergenceType,
KTOConfig,
KTOTrainer,
NashMDConfig,
NashMDTrainer,
OnlineDPOConfig,
OnlineDPOTrainer,
ORPOConfig,
@@ -150,28 +148,6 @@ def test_kto(self):
assert trainer.args.ref_model_init_kwargs == {"trust_remote_code": True}
assert trainer.args.dataset_num_proc == 4

@pytest.mark.parametrize("mixtures_coef_list", [False, True])
def test_nash_md(self, mixtures_coef_list):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
ref_model = AutoModelForCausalLM.from_pretrained(model_id)
reward_model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1)
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
training_args = NashMDConfig(
self.tmp_dir,
mixture_coef=0.5 if not mixtures_coef_list else [0.5, 0.6],
)
trainer = NashMDTrainer(
args=training_args,
processing_class=tokenizer,
model=model,
ref_model=ref_model,
reward_funcs=reward_model,
train_dataset=dataset,
)
assert trainer.args.mixture_coef == (0.5 if not mixtures_coef_list else [0.5, 0.6])

@pytest.mark.parametrize("beta_list", [False, True])
def test_online_dpo(self, beta_list):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
Expand Down
19 changes: 19 additions & 0 deletions trl/experimental/nash_md/__init__.py
@@ -0,0 +1,19 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .nash_md_config import NashMDConfig
from .nash_md_trainer import NashMDTrainer


__all__ = ["NashMDConfig", "NashMDTrainer"]
46 changes: 46 additions & 0 deletions trl/experimental/nash_md/nash_md_config.py
@@ -0,0 +1,46 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field

from ...trainer.online_dpo_config import OnlineDPOConfig


@dataclass
class NashMDConfig(OnlineDPOConfig):
r"""
Configuration class for the [`experimental.nash_md.NashMDTrainer`].

Subclass of [`OnlineDPOConfig`]; it accepts all of its arguments and adds the following:

Parameters:
mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`):
Logit mixture coefficient for the model and reference model. If a list of floats is provided then the
mixture coefficient is selected for each new epoch and the last coefficient is used for the rest of the
epochs.
"""

mixture_coef: list[float] = field(
default_factory=lambda: [0.5],
metadata={
"help": "Logit mixture coefficient for the model and reference model. If a list of floats is provided "
"then the mixture coefficient is selected for each new epoch and the last coefficient is used for the "
"rest of the epochs."
},
)

def __post_init__(self):
super().__post_init__()
if hasattr(self.mixture_coef, "__len__") and len(self.mixture_coef) == 1:
self.mixture_coef = self.mixture_coef[0]
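
A small usage sketch of how this `__post_init__` behaves (the `output_dir` value is arbitrary):

```python
from trl.experimental.nash_md import NashMDConfig

# A single-element list is unwrapped to a plain float...
config = NashMDConfig(output_dir="out", mixture_coef=[0.5])
print(config.mixture_coef)  # 0.5

# ...while a longer list is kept as a per-epoch schedule.
config = NashMDConfig(output_dir="out", mixture_coef=[0.5, 0.6])
print(config.mixture_coef)  # [0.5, 0.6]
```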