From fb0a5b034a0d92942e633f4326a4411f49ffd0f4 Mon Sep 17 00:00:00 2001
From: Behrooz
Date: Sun, 2 Nov 2025 19:37:38 -0800
Subject: [PATCH 1/4] fix: Remove chat template setting from non-SFT trainer scripts

Resolves #4404

- Remove SIMPLE_CHAT_TEMPLATE import from 7 trainer scripts
- Remove chat template setting for non-SFT trainers (DPO, CPO, ORPO, PPO, Nash-MD, XPO, Online DPO)
- Chat templates only make sense for SFT (instruction tuning), not for preference optimization or reward-based training
- Scripts modified:
  - examples/scripts/online_dpo.py
  - examples/scripts/orpo.py
  - examples/scripts/cpo.py
  - examples/scripts/nash_md.py
  - examples/scripts/xpo.py
  - examples/scripts/ppo/ppo.py
  - examples/scripts/ppo/ppo_tldr.py
---
 examples/scripts/cpo.py          | 3 ---
 examples/scripts/nash_md.py      | 3 ---
 examples/scripts/online_dpo.py   | 3 ---
 examples/scripts/orpo.py         | 3 ---
 examples/scripts/ppo/ppo.py      | 3 ---
 examples/scripts/ppo/ppo_tldr.py | 3 ---
 examples/scripts/xpo.py          | 3 ---
 7 files changed, 21 deletions(-)

diff --git a/examples/scripts/cpo.py b/examples/scripts/cpo.py
index 2d9049136c6..fef9cdf1247 100644
--- a/examples/scripts/cpo.py
+++ b/examples/scripts/cpo.py
@@ -64,7 +64,6 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
 
 from trl import CPOConfig, CPOTrainer, ModelConfig, ScriptArguments, get_peft_config
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
 
 
 # Enable logging in a Hugging Face Space
@@ -90,8 +89,6 @@
     # Dataset
     ################
     dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
-    if tokenizer.chat_template is None:
-        tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
 
     ################
     # Training
diff --git a/examples/scripts/nash_md.py b/examples/scripts/nash_md.py
index e3f1486c75d..fdb8ca09a3e 100644
--- a/examples/scripts/nash_md.py
+++ b/examples/scripts/nash_md.py
@@ -73,7 +73,6 @@
     get_kbit_device_map,
     get_quantization_config,
 )
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
 
 
 # Enable logging in a Hugging Face Space
@@ -128,8 +127,6 @@
     )
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
-    if tokenizer.chat_template is None:
-        tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
 
     dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
 
diff --git a/examples/scripts/online_dpo.py b/examples/scripts/online_dpo.py
index 91569c8b4f1..4ed7afe884d 100644
--- a/examples/scripts/online_dpo.py
+++ b/examples/scripts/online_dpo.py
@@ -69,7 +69,6 @@
     get_peft_config,
     get_quantization_config,
 )
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
 
 
 # Enable logging in a Hugging Face Space
@@ -131,8 +130,6 @@
         trust_remote_code=model_args.trust_remote_code,
         **model_kwargs,
     )
-    if tokenizer.chat_template is None:
-        tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
     if tokenizer.pad_token_id is None:
         tokenizer.pad_token = tokenizer.eos_token
 
diff --git a/examples/scripts/orpo.py b/examples/scripts/orpo.py
index d392bb7bf1b..e256a4277ad 100644
--- a/examples/scripts/orpo.py
+++ b/examples/scripts/orpo.py
@@ -64,7 +64,6 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
 
 from trl import ModelConfig, ORPOConfig, ORPOTrainer, ScriptArguments, get_peft_config
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
 
 
 # Enable logging in a Hugging Face Space
@@ -91,8 +90,6 @@
     # Dataset
     ################
     dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
-    if tokenizer.chat_template is None:
-        tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
 
     ################
     # Training
diff --git a/examples/scripts/ppo/ppo.py b/examples/scripts/ppo/ppo.py
index 5dfcda55429..2f5471996c2 100644
--- a/examples/scripts/ppo/ppo.py
+++ b/examples/scripts/ppo/ppo.py
@@ -43,7 +43,6 @@
     get_peft_config,
     get_quantization_config,
 )
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
 
 
 # Enable logging in a Hugging Face Space
@@ -106,8 +105,6 @@
         model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code
     )
     tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-    if tokenizer.chat_template is None:
-        tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
     value_model = AutoModelForSequenceClassification.from_pretrained(
         training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1
     )
diff --git a/examples/scripts/ppo/ppo_tldr.py b/examples/scripts/ppo/ppo_tldr.py
index 4ef1cf4e7b6..7962758ec40 100644
--- a/examples/scripts/ppo/ppo_tldr.py
+++ b/examples/scripts/ppo/ppo_tldr.py
@@ -43,7 +43,6 @@
     get_peft_config,
     get_quantization_config,
 )
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
 
 
 # Enable logging in a Hugging Face Space
@@ -113,8 +112,6 @@
         model_args.model_name_or_path, padding_side="left", trust_remote_code=model_args.trust_remote_code
     )
     tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-    if tokenizer.chat_template is None:
-        tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
     value_model = AutoModelForSequenceClassification.from_pretrained(
         training_args.reward_model_path, trust_remote_code=model_args.trust_remote_code, num_labels=1
     )
diff --git a/examples/scripts/xpo.py b/examples/scripts/xpo.py
index e4e7c6301a6..70c13226c5d 100644
--- a/examples/scripts/xpo.py
+++ b/examples/scripts/xpo.py
@@ -57,7 +57,6 @@
     get_kbit_device_map,
     get_quantization_config,
 )
-from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
 
 
 # Enable logging in a Hugging Face Space
@@ -113,8 +112,6 @@
     )
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
-    if tokenizer.chat_template is None:
-        tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
 
     dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

From e955b3d11b2cc7441057bf2c999e510d5919d489 Mon Sep 17 00:00:00 2001
From: Behrooz
Date: Sun, 2 Nov 2025 20:02:52 -0800
Subject: [PATCH 2/4] refactor: Move Mergekit integration to experimental submodule

Resolves #4395

- Move mergekit_utils.py to trl/experimental/mergekit/
- Move MergeModelCallback to trl/experimental/mergekit/callbacks.py
- Create trl/experimental/mergekit/__init__.py for module exports
- Remove MergeModelCallback from main trl exports (trl/__init__.py, trl/trainer/__init__.py)
- Remove MergeModelCallback from trl/trainer/callbacks.py
- Update test imports to use trl.experimental.mergekit
- Update callbacks documentation with experimental status note

This change consolidates underutilized Mergekit functionality into the experimental namespace as part of v1.0 preparations. The module may be deprecated or removed in future versions if usage does not materialize.

To use MergeModelCallback after this change:

```python
from trl.experimental.mergekit import MergeConfig, MergeModelCallback
```
---
 docs/source/callbacks.md               | 22 ++++-
 tests/test_callbacks.py                |  3 +-
 trl/__init__.py                        |  2 -
 trl/experimental/mergekit/__init__.py  | 31 +++++++
 trl/experimental/mergekit/callbacks.py | 82 +++++++++++++++++++
 .../mergekit}/mergekit_utils.py        |  0
 trl/trainer/__init__.py                |  2 -
 trl/trainer/callbacks.py               | 66 +--------------
 8 files changed, 134 insertions(+), 74 deletions(-)
 create mode 100644 trl/experimental/mergekit/__init__.py
 create mode 100644 trl/experimental/mergekit/callbacks.py
 rename trl/{ => experimental/mergekit}/mergekit_utils.py (100%)

diff --git a/docs/source/callbacks.md b/docs/source/callbacks.md
index b89de277309..f19d292324d 100644
--- a/docs/source/callbacks.md
+++ b/docs/source/callbacks.md
@@ -16,10 +16,6 @@
 
 [[autodoc]] LogCompletionsCallback
 
-## MergeModelCallback
-
-[[autodoc]] MergeModelCallback
-
 ## BEMACallback
 
 [[autodoc]] BEMACallback
@@ -27,3 +23,21 @@
 ## WeaveCallback
 
 [[autodoc]] WeaveCallback
+
+## Experimental Callbacks
+
+Some callbacks have been moved to the experimental submodule due to low usage and potential removal in future versions.
+
+### MergeModelCallback
+
+`MergeModelCallback` has been moved to `trl.experimental.mergekit`. To use it:
+
+```python
+from trl.experimental.mergekit import MergeModelCallback, MergeConfig
+
+config = MergeConfig()
+merge_callback = MergeModelCallback(config)
+trainer = DPOTrainer(..., callbacks=[merge_callback])
+```
+
+For more details, see the [mergekit integration documentation](https://github.com/arcee-ai/mergekit).
diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py
index 811bcf79f37..a0e88318808 100644
--- a/tests/test_callbacks.py
+++ b/tests/test_callbacks.py
@@ -27,10 +27,9 @@
     DPOConfig,
     DPOTrainer,
     LogCompletionsCallback,
-    MergeModelCallback,
     WinRateCallback,
 )
-from trl.mergekit_utils import MergeConfig
+from trl.experimental.mergekit import MergeConfig, MergeModelCallback
 
 from .testing_utils import TrlTestCase, require_comet, require_mergekit, require_peft, require_wandb
 
diff --git a/trl/__init__.py b/trl/__init__.py
index ee6aa770b8e..5e4b43e45df 100644
--- a/trl/__init__.py
+++ b/trl/__init__.py
@@ -110,7 +110,6 @@
     ],
     "trainer.callbacks": [
         "BEMACallback",
-        "MergeModelCallback",
         "RichProgressCallback",
         "SyncRefModelCallback",
         "WeaveCallback",
@@ -191,7 +190,6 @@
     )
     from .trainer.callbacks import (
         BEMACallback,
-        MergeModelCallback,
         RichProgressCallback,
         SyncRefModelCallback,
         WeaveCallback,
diff --git a/trl/experimental/mergekit/__init__.py b/trl/experimental/mergekit/__init__.py
new file mode 100644
index 00000000000..2d8f59384a1
--- /dev/null
+++ b/trl/experimental/mergekit/__init__.py
@@ -0,0 +1,31 @@
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Mergekit integration for TRL (Experimental).
+
+This module contains utilities for merging models using mergekit.
+This is an experimental feature and may be removed in future versions.
+"""
+
+from .callbacks import MergeModelCallback
+from .mergekit_utils import MergeConfig, merge_models, upload_model_to_hf
+
+
+__all__ = [
+    "MergeConfig",
+    "MergeModelCallback",
+    "merge_models",
+    "upload_model_to_hf",
+]
diff --git a/trl/experimental/mergekit/callbacks.py b/trl/experimental/mergekit/callbacks.py
new file mode 100644
index 00000000000..bab73e34ca3
--- /dev/null
+++ b/trl/experimental/mergekit/callbacks.py
@@ -0,0 +1,82 @@
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Optional
+
+from transformers import TrainerCallback
+
+from ...import_utils import is_mergekit_available
+from ...trainer.utils import get_config_model_id
+from .mergekit_utils import MergeConfig, merge_models, upload_model_to_hf
+
+
+class MergeModelCallback(TrainerCallback):
+    r"""
+    A [`~transformers.TrainerCallback`] that merges the policy model (the model being trained) with another model based
+    on a merge configuration.
+
+    Args:
+        merge_config ([`MergeConfig`], *optional*):
+            Configuration used for the merging process. If not provided, the default [`MergeConfig`] is used.
+        merge_at_every_checkpoint (`bool`, *optional*, defaults to `False`):
+            Whether to merge the model at every checkpoint.
+        push_to_hub (`bool`, *optional*, defaults to `False`):
+            Whether to push the merged model to the Hub after merging.
+
+    Example:
+
+    ```python
+    from trl.experimental.mergekit import MergeConfig, MergeModelCallback
+
+    config = MergeConfig()
+    merge_callback = MergeModelCallback(config)
+    trainer = DPOTrainer(..., callbacks=[merge_callback])
+    ```
+    """
+
+    def __init__(
+        self,
+        merge_config: Optional["MergeConfig"] = None,
+        merge_at_every_checkpoint: bool = False,
+        push_to_hub: bool = False,
+    ):
+        if not is_mergekit_available():
+            raise ImportError(
+                "MergeModelCallback requires the `mergekit` extra. To install, run `pip install mergekit`."
+            )
+        self.merge_config = merge_config or MergeConfig()
+        self.merge_at_every_checkpoint = merge_at_every_checkpoint
+        self.push_to_hub = push_to_hub
+
+    def _merge_and_maybe_push(self, output_dir, global_step, model):
+        checkpoint_path = os.path.join(output_dir, f"checkpoint-{global_step}")
+        self.merge_config.policy_model_path = checkpoint_path
+        if self.merge_config.target_model_path is None:
+            self.merge_config.target_model_path = get_config_model_id(model.config)
+        merge_path = os.path.join(checkpoint_path, "merged")
+
+        merge_models(self.merge_config.create(), merge_path)
+
+        if self.push_to_hub:
+            repo_name = f"{output_dir}_checkpoint-{global_step}_merged"
+            upload_model_to_hf(merge_path, repo_name)
+
+    def on_save(self, args, state, control, model=None, **kwargs):
+        if self.merge_at_every_checkpoint:
+            self._merge_and_maybe_push(args.output_dir, state.global_step, model)
+
+    def on_train_end(self, args, state, control, model=None, **kwargs):
+        if not self.merge_at_every_checkpoint:
+            self._merge_and_maybe_push(args.output_dir, state.global_step, model)
diff --git a/trl/mergekit_utils.py b/trl/experimental/mergekit/mergekit_utils.py
similarity index 100%
rename from trl/mergekit_utils.py
rename to trl/experimental/mergekit/mergekit_utils.py
diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py
index 98846bf7159..fac99e9711c 100644
--- a/trl/trainer/__init__.py
+++ b/trl/trainer/__init__.py
@@ -23,7 +23,6 @@
     "callbacks": [
         "BEMACallback",
         "LogCompletionsCallback",
-        "MergeModelCallback",
         "RichProgressCallback",
         "SyncRefModelCallback",
         "WeaveCallback",
@@ -83,7 +82,6 @@
     from .callbacks import (
         BEMACallback,
         LogCompletionsCallback,
-        MergeModelCallback,
         RichProgressCallback,
         SyncRefModelCallback,
         WeaveCallback,
diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py
index 49cf4ab77b0..95458d3a266 100644
--- a/trl/trainer/callbacks.py
+++ b/trl/trainer/callbacks.py
@@ -36,11 +36,10 @@
 from transformers.utils import is_rich_available
 
 from ..data_utils import maybe_apply_chat_template
-from ..import_utils import is_mergekit_available, is_weave_available
-from ..mergekit_utils import MergeConfig, merge_models, upload_model_to_hf
+from ..import_utils import is_weave_available
 from ..models.utils import unwrap_model_for_generation
 from .judges import BasePairwiseJudge
-from .utils import get_config_model_id, log_table_to_comet_experiment
+from .utils import log_table_to_comet_experiment
 
 
 if is_rich_available():
@@ -778,67 +777,6 @@ def on_evaluate(self, args, state, control, **kwargs):
         self._last_logged_step = state.global_step
 
 
-class MergeModelCallback(TrainerCallback):
-    r"""
-    A [`~transformers.TrainerCallback`] that merges the policy model (the model being trained) with another model based
-    on a merge configuration.
-
-    Args:
-        merge_config ([`MergeConfig`], *optional*):
-            Configuration used for the merging process. If not provided, the default [`MergeConfig`] is used.
-        merge_at_every_checkpoint (`bool`, *optional*, defaults to `False`):
-            Whether to merge the model at every checkpoint.
-        push_to_hub (`bool`, *optional*, defaults to `False`):
-            Whether to push the merged model to the Hub after merging.
-
-    Example:
-
-    ```python
-    from trl.mergekit_utils import MergeConfig
-    from trl import MergeModelCallback
-
-    config = MergeConfig()
-    merge_callback = MergeModelCallback(config)
-    trainer = DPOTrainer(..., callbacks=[merge_callback])
-    ```
-    """
-
-    def __init__(
-        self,
-        merge_config: Optional["MergeConfig"] = None,
-        merge_at_every_checkpoint: bool = False,
-        push_to_hub: bool = False,
-    ):
-        if not is_mergekit_available():
-            raise ImportError(
-                "MergeModelCallback requires the `mergekit` extra. To install, run `pip install mergekit`."
-            )
-        self.merge_config = merge_config or MergeConfig()
-        self.merge_at_every_checkpoint = merge_at_every_checkpoint
-        self.push_to_hub = push_to_hub
-
-    def _merge_and_maybe_push(self, output_dir, global_step, model):
-        checkpoint_path = os.path.join(output_dir, f"checkpoint-{global_step}")
-        self.merge_config.policy_model_path = checkpoint_path
-        if self.merge_config.target_model_path is None:
-            self.merge_config.target_model_path = get_config_model_id(model.config)
-        merge_path = os.path.join(checkpoint_path, "merged")
-
-        merge_models(self.merge_config.create(), merge_path)
-
-        if self.push_to_hub:
-            repo_name = f"{output_dir}_checkpoint-{global_step}_merged"
-            upload_model_to_hf(merge_path, repo_name)
-
-    def on_save(self, args, state, control, model=None, **kwargs):
-        if self.merge_at_every_checkpoint:
-            self._merge_and_maybe_push(args.output_dir, state.global_step, model)
-
-    def on_train_end(self, args, state, control, model=None, **kwargs):
-        if not self.merge_at_every_checkpoint:
-            self._merge_and_maybe_push(args.output_dir, state.global_step, model)
-
-
 class BEMACallback(TrainerCallback):
     # docstyle-ignore
     r"""

From c2ed9e1adacaeb7f194e936b38486be7c36842e1 Mon Sep 17 00:00:00 2001
From: Behrooz
Date: Tue, 4 Nov 2025 18:14:26 -0800
Subject: [PATCH 3/4] docs: Move MergeModelCallback to dedicated experimental documentation page

- Created docs/source/merge_model_callback.md following the BEMA pattern
- Added merge_model_callback to the experimental section in _toctree.yml
- Removed the inline experimental section from callbacks.md

This addresses reviewer feedback to have a dedicated documentation page under the Experimental section (like BEMA) instead of inline docs.
---
 docs/source/_toctree.yml            |  6 ++--
 docs/source/callbacks.md            | 18 ----------
 docs/source/merge_model_callback.md | 51 +++++++++++++++++++++++++++++
 3 files changed, 55 insertions(+), 20 deletions(-)
 create mode 100644 docs/source/merge_model_callback.md

diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 74f79544b33..8727f5d150a 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -117,8 +117,10 @@
       title: GRPO With Replay Buffer
     - local: gspo_token
      title: GSPO-token
-    - local: papo_trainer
-      title: PAPO
+    - local: merge_model_callback
+      title: Merge Model Callback
     - local: openenv
       title: OpenEnv Integration
+    - local: papo_trainer
+      title: PAPO
   title: Experimental
\ No newline at end of file
diff --git a/docs/source/callbacks.md b/docs/source/callbacks.md
index f19d292324d..16c1dd6a584 100644
--- a/docs/source/callbacks.md
+++ b/docs/source/callbacks.md
@@ -23,21 +23,3 @@
 ## WeaveCallback
 
 [[autodoc]] WeaveCallback
-
-## Experimental Callbacks
-
-Some callbacks have been moved to the experimental submodule due to low usage and potential removal in future versions.
-
-### MergeModelCallback
-
-`MergeModelCallback` has been moved to `trl.experimental.mergekit`. To use it:
-
-```python
-from trl.experimental.mergekit import MergeModelCallback, MergeConfig
-
-config = MergeConfig()
-merge_callback = MergeModelCallback(config)
-trainer = DPOTrainer(..., callbacks=[merge_callback])
-```
-
-For more details, see the [mergekit integration documentation](https://github.com/arcee-ai/mergekit).
diff --git a/docs/source/merge_model_callback.md b/docs/source/merge_model_callback.md
new file mode 100644
index 00000000000..7c52792488a
--- /dev/null
+++ b/docs/source/merge_model_callback.md
@@ -0,0 +1,51 @@
+# Merge Model Callback
+
+This feature allows merging the policy model (the model being trained) with another model during or after training using [mergekit](https://github.com/arcee-ai/mergekit). This can be useful for techniques like model interpolation, task arithmetic, or other model merging strategies.
+
+> **Note**: This is an experimental feature. The Mergekit integration has seen low usage and may be removed in future versions.
+
+## Usage
+
+```python
+from trl.experimental.mergekit import MergeConfig, MergeModelCallback
+from trl import DPOTrainer
+from datasets import load_dataset
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+# Load your dataset and model
+pref_dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
+model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+
+# Configure the merge
+config = MergeConfig()
+
+# Create the callback
+merge_callback = MergeModelCallback(
+    merge_config=config,
+    merge_at_every_checkpoint=False,  # Merge only at the end of training
+    push_to_hub=False,  # Optionally push merged model to Hub
+)
+
+# Use with any trainer
+trainer = DPOTrainer(
+    model=model,
+    train_dataset=pref_dataset,
+    processing_class=tokenizer,
+    callbacks=[merge_callback],
+)
+
+trainer.train()
+```
+
+## Installation
+
+The mergekit integration requires the `mergekit` package:
+
+```bash
+pip install mergekit
+```
+
+## API Reference
+
+[[autodoc]] experimental.mergekit.MergeModelCallback

From cd5456ea02767d5b53f1f73b5213c053bd7edfad Mon Sep 17 00:00:00 2001
From: Behrooz
Date: Wed, 5 Nov 2025 20:20:39 -0800
Subject: [PATCH 4/4] Fix ruff linting errors in callbacks.py

---
 trl/trainer/callbacks.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py
index fc3dde142d3..4b2c432614c 100644
--- a/trl/trainer/callbacks.py
+++ b/trl/trainer/callbacks.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import logging
-import os
 
 import pandas as pd
 import torch
@@ -226,7 +225,7 @@
         table.add_column("Value", justify="right")
 
         for metric, val in metrics.items():
-            formatted = f"{val:.3f}" if isinstance(val, (float, int)) else str(val)
+            formatted = f"{val:.3f}" if isinstance(val, float | int) else str(val)
             table.add_row(metric, formatted)
 
         tables.append(Panel(table, border_style="cyan", padding=(0, 1)))