6 changes: 4 additions & 2 deletions docs/source/_toctree.yml
@@ -117,8 +117,10 @@
    title: GRPO With Replay Buffer
  - local: gspo_token
    title: GSPO-token
  - local: papo_trainer
    title: PAPO
  - local: merge_model_callback
    title: Merge Model Callback
  - local: openenv
    title: OpenEnv Integration
  - local: papo_trainer
    title: PAPO
  title: Experimental
4 changes: 0 additions & 4 deletions docs/source/callbacks.md
@@ -16,10 +16,6 @@

[[autodoc]] LogCompletionsCallback

## MergeModelCallback

[[autodoc]] MergeModelCallback

## BEMACallback

[[autodoc]] BEMACallback
51 changes: 51 additions & 0 deletions docs/source/merge_model_callback.md
@@ -0,0 +1,51 @@
# Merge Model Callback

This feature allows merging the policy model (the model being trained) with another model during or after training using [mergekit](https://github.com/arcee-ai/mergekit). This can be useful for techniques like model interpolation, task arithmetic, or other model merging strategies.

> **Note**: This is an experimental feature. The Mergekit integration has seen low usage and may be removed in future versions.
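Merging is configured through [`MergeConfig`]. As a minimal sketch, assuming the `MergeConfig` API from TRL's `mergekit_utils` (a merge-method name plus per-method attributes such as the linear interpolation weights), a weighted interpolation could look like this:

```python
from trl.experimental.mergekit import MergeConfig

# Assumed API: MergeConfig(method=...) with "linear", "ties", "dare_ties", or "slerp",
# exposing per-method attributes; the weights and model below are illustrative only.
config = MergeConfig(method="linear")
config.policy_model_weight = 0.7  # weight of the policy model being trained
config.target_model_weight = 0.3  # weight of the model it is merged with
config.target_model_path = "Qwen/Qwen2.5-0.5B"  # hypothetical merge target
```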

## Usage

```python
from trl.experimental.mergekit import MergeConfig, MergeModelCallback
from trl import DPOTrainer
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load your dataset and model
pref_dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")

# Configure the merge
config = MergeConfig()

# Create the callback
merge_callback = MergeModelCallback(
    merge_config=config,
    merge_at_every_checkpoint=False,  # merge only at the end of training
    push_to_hub=False,  # optionally push the merged model to the Hub
)

# Use with any trainer
trainer = DPOTrainer(
    model=model,
    train_dataset=pref_dataset,
    processing_class=tokenizer,
    callbacks=[merge_callback],
)

trainer.train()
```
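With `merge_at_every_checkpoint=False`, the merge runs once at the end of training, and the merged model is written to `<output_dir>/checkpoint-<global_step>/merged`; with `push_to_hub=True`, it is also uploaded to a Hub repository named `<output_dir>_checkpoint-<global_step>_merged` (see `_merge_and_maybe_push` in `callbacks.py` below). A minimal sketch of loading the merged weights afterwards, with a hypothetical output directory and final step:

```python
from transformers import AutoModelForCausalLM

# Hypothetical path: substitute your own output_dir and final global step.
merged_model = AutoModelForCausalLM.from_pretrained("output/checkpoint-100/merged")
```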

## Installation

The mergekit integration requires the `mergekit` package:

```bash
pip install mergekit
```
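Since `MergeModelCallback` raises an `ImportError` when mergekit is missing (see its `__init__` in `callbacks.py` below), scripts that must also run without the extra can guard on TRL's availability check. A small sketch:

```python
from trl.import_utils import is_mergekit_available

callbacks = []
if is_mergekit_available():
    from trl.experimental.mergekit import MergeConfig, MergeModelCallback

    callbacks.append(MergeModelCallback(merge_config=MergeConfig()))
# Pass `callbacks` to any trainer; training proceeds without merging if mergekit is absent.
```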

## API Reference

[[autodoc]] experimental.mergekit.MergeModelCallback
3 changes: 1 addition & 2 deletions tests/test_callbacks.py
@@ -27,10 +27,9 @@
    DPOConfig,
    DPOTrainer,
    LogCompletionsCallback,
    MergeModelCallback,
    WinRateCallback,
)
from trl.mergekit_utils import MergeConfig
from trl.experimental.mergekit import MergeConfig, MergeModelCallback

from .testing_utils import TrlTestCase, require_comet, require_mergekit, require_peft, require_wandb

2 changes: 0 additions & 2 deletions trl/__init__.py
@@ -98,7 +98,6 @@
    ],
    "trainer.callbacks": [
        "BEMACallback",
        "MergeModelCallback",
        "RichProgressCallback",
        "SyncRefModelCallback",
        "WeaveCallback",
@@ -179,7 +178,6 @@
)
from .trainer.callbacks import (
    BEMACallback,
    MergeModelCallback,
    RichProgressCallback,
    SyncRefModelCallback,
    WeaveCallback,
31 changes: 31 additions & 0 deletions trl/experimental/mergekit/__init__.py
@@ -0,0 +1,31 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Mergekit integration for TRL (Experimental).

This module contains utilities for merging models using mergekit.
This is an experimental feature and may be removed in future versions.
"""

from .callbacks import MergeModelCallback
from .mergekit_utils import MergeConfig, merge_models, upload_model_to_hf


__all__ = [
    "MergeConfig",
    "MergeModelCallback",
    "merge_models",
    "upload_model_to_hf",
]
82 changes: 82 additions & 0 deletions trl/experimental/mergekit/callbacks.py
@@ -0,0 +1,82 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Optional

from transformers import TrainerCallback

from ...import_utils import is_mergekit_available
from ...trainer.utils import get_config_model_id
from .mergekit_utils import MergeConfig, merge_models, upload_model_to_hf


class MergeModelCallback(TrainerCallback):
    r"""
    A [`~transformers.TrainerCallback`] that merges the policy model (the model being trained) with another model based
    on a merge configuration.

    Args:
        merge_config ([`MergeConfig`], *optional*):
            Configuration used for the merging process. If not provided, the default [`MergeConfig`] is used.
        merge_at_every_checkpoint (`bool`, *optional*, defaults to `False`):
            Whether to merge the model at every checkpoint.
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to push the merged model to the Hub after merging.

    Example:

    ```python
    from trl.experimental.mergekit import MergeConfig, MergeModelCallback

    config = MergeConfig()
    merge_callback = MergeModelCallback(config)
    trainer = DPOTrainer(..., callbacks=[merge_callback])
    ```
    """

    def __init__(
        self,
        merge_config: Optional["MergeConfig"] = None,
        merge_at_every_checkpoint: bool = False,
        push_to_hub: bool = False,
    ):
        if not is_mergekit_available():
            raise ImportError(
                "MergeModelCallback requires the `mergekit` extra. To install, run `pip install mergekit`."
            )
        self.merge_config = merge_config or MergeConfig()
        self.merge_at_every_checkpoint = merge_at_every_checkpoint
        self.push_to_hub = push_to_hub

    def _merge_and_maybe_push(self, output_dir, global_step, model):
        checkpoint_path = os.path.join(output_dir, f"checkpoint-{global_step}")
        self.merge_config.policy_model_path = checkpoint_path
        if self.merge_config.target_model_path is None:
            self.merge_config.target_model_path = get_config_model_id(model.config)
        merge_path = os.path.join(checkpoint_path, "merged")

        merge_models(self.merge_config.create(), merge_path)

        if self.push_to_hub:
            repo_name = f"{output_dir}_checkpoint-{global_step}_merged"
            upload_model_to_hf(merge_path, repo_name)

    def on_save(self, args, state, control, model=None, **kwargs):
        if self.merge_at_every_checkpoint:
            self._merge_and_maybe_push(args.output_dir, state.global_step, model)

    def on_train_end(self, args, state, control, model=None, **kwargs):
        if not self.merge_at_every_checkpoint:
            self._merge_and_maybe_push(args.output_dir, state.global_step, model)
trl/mergekit_utils.py → trl/experimental/mergekit/mergekit_utils.py
File renamed without changes.
2 changes: 0 additions & 2 deletions trl/trainer/__init__.py
@@ -23,7 +23,6 @@
    "callbacks": [
        "BEMACallback",
        "LogCompletionsCallback",
        "MergeModelCallback",
        "RichProgressCallback",
        "SyncRefModelCallback",
        "WeaveCallback",
@@ -83,7 +82,6 @@
from .callbacks import (
    BEMACallback,
    LogCompletionsCallback,
    MergeModelCallback,
    RichProgressCallback,
    SyncRefModelCallback,
    WeaveCallback,
69 changes: 3 additions & 66 deletions trl/trainer/callbacks.py
@@ -13,7 +13,6 @@
# limitations under the License.

import logging
import os

import pandas as pd
import torch
@@ -34,11 +33,10 @@
from transformers.utils import is_rich_available

from ..data_utils import maybe_apply_chat_template
from ..import_utils import is_mergekit_available, is_weave_available
from ..mergekit_utils import MergeConfig, merge_models, upload_model_to_hf
from ..import_utils import is_weave_available
from ..models.utils import unwrap_model_for_generation
from .judges import BasePairwiseJudge
from .utils import get_config_model_id, log_table_to_comet_experiment
from .utils import log_table_to_comet_experiment


if is_rich_available():
@@ -227,7 +225,7 @@ def on_log(self, args, state, control, logs=None, **kwargs):
        table.add_column("Value", justify="right")

        for metric, val in metrics.items():
            formatted = f"{val:.3f}" if isinstance(val, (float, int)) else str(val)
            formatted = f"{val:.3f}" if isinstance(val, float | int) else str(val)
            table.add_row(metric, formatted)

        tables.append(Panel(table, border_style="cyan", padding=(0, 1)))
@@ -776,67 +774,6 @@ def on_evaluate(self, args, state, control, **kwargs):
        self._last_logged_step = state.global_step


class MergeModelCallback(TrainerCallback):
    r"""
    A [`~transformers.TrainerCallback`] that merges the policy model (the model being trained) with another model based
    on a merge configuration.

    Args:
        merge_config ([`MergeConfig`], *optional*):
            Configuration used for the merging process. If not provided, the default [`MergeConfig`] is used.
        merge_at_every_checkpoint (`bool`, *optional*, defaults to `False`):
            Whether to merge the model at every checkpoint.
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to push the merged model to the Hub after merging.

    Example:

    ```python
    from trl.mergekit_utils import MergeConfig
    from trl import MergeModelCallback

    config = MergeConfig()
    merge_callback = MergeModelCallback(config)
    trainer = DPOTrainer(..., callbacks=[merge_callback])
    ```
    """

    def __init__(
        self,
        merge_config: "MergeConfig | None" = None,
        merge_at_every_checkpoint: bool = False,
        push_to_hub: bool = False,
    ):
        if not is_mergekit_available():
            raise ImportError(
                "MergeModelCallback requires the `mergekit` extra. To install, run `pip install mergekit`."
            )
        self.merge_config = merge_config or MergeConfig()
        self.merge_at_every_checkpoint = merge_at_every_checkpoint
        self.push_to_hub = push_to_hub

    def _merge_and_maybe_push(self, output_dir, global_step, model):
        checkpoint_path = os.path.join(output_dir, f"checkpoint-{global_step}")
        self.merge_config.policy_model_path = checkpoint_path
        if self.merge_config.target_model_path is None:
            self.merge_config.target_model_path = get_config_model_id(model.config)
        merge_path = os.path.join(checkpoint_path, "merged")

        merge_models(self.merge_config.create(), merge_path)

        if self.push_to_hub:
            repo_name = f"{output_dir}_checkpoint-{global_step}_merged"
            upload_model_to_hf(merge_path, repo_name)

    def on_save(self, args, state, control, model=None, **kwargs):
        if self.merge_at_every_checkpoint:
            self._merge_and_maybe_push(args.output_dir, state.global_step, model)

    def on_train_end(self, args, state, control, model=None, **kwargs):
        if not self.merge_at_every_checkpoint:
            self._merge_and_maybe_push(args.output_dir, state.global_step, model)


class BEMACallback(TrainerCallback):
    # docstyle-ignore
    r"""