diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 74f79544b33..8727f5d150a 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -117,8 +117,10 @@ title: GRPO With Replay Buffer - local: gspo_token title: GSPO-token - - local: papo_trainer - title: PAPO + - local: merge_model_callback + title: Merge Model Callback - local: openenv title: OpenEnv Integration + - local: papo_trainer + title: PAPO title: Experimental \ No newline at end of file diff --git a/docs/source/callbacks.md b/docs/source/callbacks.md index b89de277309..16c1dd6a584 100644 --- a/docs/source/callbacks.md +++ b/docs/source/callbacks.md @@ -16,10 +16,6 @@ [[autodoc]] LogCompletionsCallback -## MergeModelCallback - -[[autodoc]] MergeModelCallback - ## BEMACallback [[autodoc]] BEMACallback diff --git a/docs/source/merge_model_callback.md b/docs/source/merge_model_callback.md new file mode 100644 index 00000000000..7c52792488a --- /dev/null +++ b/docs/source/merge_model_callback.md @@ -0,0 +1,51 @@ +# Merge Model Callback + +This feature allows merging the policy model (the model being trained) with another model during or after training using [mergekit](https://github.com/arcee-ai/mergekit). This can be useful for techniques like model interpolation, task arithmetic, or other model merging strategies. + +> **Note**: This is an experimental feature. The Mergekit integration has seen low usage and may be removed in future versions. + +## Usage + +```python +from trl.experimental.mergekit import MergeConfig, MergeModelCallback +from trl import DPOTrainer +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +# Load your dataset and model +pref_dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") +model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") +tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") + +# Configure the merge +config = MergeConfig() + +# Create the callback +merge_callback = MergeModelCallback( + merge_config=config, + merge_at_every_checkpoint=False, # Merge only at the end of training + push_to_hub=False, # Optionally push merged model to Hub +) + +# Use with any trainer +trainer = DPOTrainer( + model=model, + train_dataset=pref_dataset, + processing_class=tokenizer, + callbacks=[merge_callback], +) + +trainer.train() +``` + +## Installation + +The mergekit integration requires the `mergekit` package: + +```bash +pip install mergekit +``` + +## API Reference + +[[autodoc]] experimental.mergekit.MergeModelCallback diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index b5323e9292a..76e4bec8b97 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -27,10 +27,9 @@ DPOConfig, DPOTrainer, LogCompletionsCallback, - MergeModelCallback, WinRateCallback, ) -from trl.mergekit_utils import MergeConfig +from trl.experimental.mergekit import MergeConfig, MergeModelCallback from .testing_utils import TrlTestCase, require_comet, require_mergekit, require_peft, require_wandb diff --git a/trl/__init__.py b/trl/__init__.py index 44228d2092e..37c01cbe671 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -98,7 +98,6 @@ ], "trainer.callbacks": [ "BEMACallback", - "MergeModelCallback", "RichProgressCallback", "SyncRefModelCallback", "WeaveCallback", @@ -179,7 +178,6 @@ ) from .trainer.callbacks import ( BEMACallback, - MergeModelCallback, RichProgressCallback, SyncRefModelCallback, WeaveCallback, diff --git a/trl/experimental/mergekit/__init__.py b/trl/experimental/mergekit/__init__.py new file mode 100644 index 00000000000..2d8f59384a1 --- /dev/null +++ b/trl/experimental/mergekit/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Mergekit integration for TRL (Experimental). + +This module contains utilities for merging models using mergekit. +This is an experimental feature and may be removed in future versions. +""" + +from .callbacks import MergeModelCallback +from .mergekit_utils import MergeConfig, merge_models, upload_model_to_hf + + +__all__ = [ + "MergeConfig", + "MergeModelCallback", + "merge_models", + "upload_model_to_hf", +] diff --git a/trl/experimental/mergekit/callbacks.py b/trl/experimental/mergekit/callbacks.py new file mode 100644 index 00000000000..bab73e34ca3 --- /dev/null +++ b/trl/experimental/mergekit/callbacks.py @@ -0,0 +1,82 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Optional + +from transformers import TrainerCallback + +from ...import_utils import is_mergekit_available +from ...trainer.utils import get_config_model_id +from .mergekit_utils import MergeConfig, merge_models, upload_model_to_hf + + +class MergeModelCallback(TrainerCallback): + r""" + A [`~transformers.TrainerCallback`] that merges the policy model (the model being trained) with another model based + on a merge configuration. + + Args: + merge_config ([`MergeConfig`], *optional*): + Configuration used for the merging process. If not provided, the default [`MergeConfig`] is used. + merge_at_every_checkpoint (`bool`, *optional*, defaults to `False`): + Whether to merge the model at every checkpoint. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether to push the merged model to the Hub after merging. + + Example: + + ```python + from trl.experimental.mergekit import MergeConfig, MergeModelCallback + + config = MergeConfig() + merge_callback = MergeModelCallback(config) + trainer = DPOTrainer(..., callbacks=[merge_callback]) + ``` + """ + + def __init__( + self, + merge_config: Optional["MergeConfig"] = None, + merge_at_every_checkpoint: bool = False, + push_to_hub: bool = False, + ): + if not is_mergekit_available(): + raise ImportError( + "MergeModelCallback requires the `mergekit` extra. To install, run `pip install mergekit`." + ) + self.merge_config = merge_config or MergeConfig() + self.merge_at_every_checkpoint = merge_at_every_checkpoint + self.push_to_hub = push_to_hub + + def _merge_and_maybe_push(self, output_dir, global_step, model): + checkpoint_path = os.path.join(output_dir, f"checkpoint-{global_step}") + self.merge_config.policy_model_path = checkpoint_path + if self.merge_config.target_model_path is None: + self.merge_config.target_model_path = get_config_model_id(model.config) + merge_path = os.path.join(checkpoint_path, "merged") + + merge_models(self.merge_config.create(), merge_path) + + if self.push_to_hub: + repo_name = f"{output_dir}_checkpoint-{global_step}_merged" + upload_model_to_hf(merge_path, repo_name) + + def on_save(self, args, state, control, model=None, **kwargs): + if self.merge_at_every_checkpoint: + self._merge_and_maybe_push(args.output_dir, state.global_step, model) + + def on_train_end(self, args, state, control, model=None, **kwargs): + if not self.merge_at_every_checkpoint: + self._merge_and_maybe_push(args.output_dir, state.global_step, model) diff --git a/trl/mergekit_utils.py b/trl/experimental/mergekit/mergekit_utils.py similarity index 100% rename from trl/mergekit_utils.py rename to trl/experimental/mergekit/mergekit_utils.py diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py index 98846bf7159..fac99e9711c 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -23,7 +23,6 @@ "callbacks": [ "BEMACallback", "LogCompletionsCallback", - "MergeModelCallback", "RichProgressCallback", "SyncRefModelCallback", "WeaveCallback", @@ -83,7 +82,6 @@ from .callbacks import ( BEMACallback, LogCompletionsCallback, - MergeModelCallback, RichProgressCallback, SyncRefModelCallback, WeaveCallback, diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py index 252f38ef31e..4b2c432614c 100644 --- a/trl/trainer/callbacks.py +++ b/trl/trainer/callbacks.py @@ -13,7 +13,6 @@ # limitations under the License. import logging -import os import pandas as pd import torch @@ -34,11 +33,10 @@ from transformers.utils import is_rich_available from ..data_utils import maybe_apply_chat_template -from ..import_utils import is_mergekit_available, is_weave_available -from ..mergekit_utils import MergeConfig, merge_models, upload_model_to_hf +from ..import_utils import is_weave_available from ..models.utils import unwrap_model_for_generation from .judges import BasePairwiseJudge -from .utils import get_config_model_id, log_table_to_comet_experiment +from .utils import log_table_to_comet_experiment if is_rich_available(): @@ -227,7 +225,7 @@ def on_log(self, args, state, control, logs=None, **kwargs): table.add_column("Value", justify="right") for metric, val in metrics.items(): - formatted = f"{val:.3f}" if isinstance(val, (float, int)) else str(val) + formatted = f"{val:.3f}" if isinstance(val, float | int) else str(val) table.add_row(metric, formatted) tables.append(Panel(table, border_style="cyan", padding=(0, 1))) @@ -776,67 +774,6 @@ def on_evaluate(self, args, state, control, **kwargs): self._last_logged_step = state.global_step -class MergeModelCallback(TrainerCallback): - r""" - A [`~transformers.TrainerCallback`] that merges the policy model (the model being trained) with another model based - on a merge configuration. - - Args: - merge_config ([`MergeConfig`], *optional*): - Configuration used for the merging process. If not provided, the default [`MergeConfig`] is used. - merge_at_every_checkpoint (`bool`, *optional*, defaults to `False`): - Whether to merge the model at every checkpoint. - push_to_hub (`bool`, *optional*, defaults to `False`): - Whether to push the merged model to the Hub after merging. - - Example: - - ```python - from trl.mergekit_utils import MergeConfig - from trl import MergeModelCallback - - config = MergeConfig() - merge_callback = MergeModelCallback(config) - trainer = DPOTrainer(..., callbacks=[merge_callback]) - ``` - """ - - def __init__( - self, - merge_config: "MergeConfig | None" = None, - merge_at_every_checkpoint: bool = False, - push_to_hub: bool = False, - ): - if not is_mergekit_available(): - raise ImportError( - "MergeModelCallback requires the `mergekit` extra. To install, run `pip install mergekit`." - ) - self.merge_config = merge_config or MergeConfig() - self.merge_at_every_checkpoint = merge_at_every_checkpoint - self.push_to_hub = push_to_hub - - def _merge_and_maybe_push(self, output_dir, global_step, model): - checkpoint_path = os.path.join(output_dir, f"checkpoint-{global_step}") - self.merge_config.policy_model_path = checkpoint_path - if self.merge_config.target_model_path is None: - self.merge_config.target_model_path = get_config_model_id(model.config) - merge_path = os.path.join(checkpoint_path, "merged") - - merge_models(self.merge_config.create(), merge_path) - - if self.push_to_hub: - repo_name = f"{output_dir}_checkpoint-{global_step}_merged" - upload_model_to_hf(merge_path, repo_name) - - def on_save(self, args, state, control, model=None, **kwargs): - if self.merge_at_every_checkpoint: - self._merge_and_maybe_push(args.output_dir, state.global_step, model) - - def on_train_end(self, args, state, control, model=None, **kwargs): - if not self.merge_at_every_checkpoint: - self._merge_and_maybe_push(args.output_dir, state.global_step, model) - - class BEMACallback(TrainerCallback): # docstyle-ignore r"""