From c00722ce0a250cd48b685380d5e1041eacdd00ff Mon Sep 17 00:00:00 2001
From: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date: Fri, 27 Sep 2024 15:23:05 +0200
Subject: [PATCH] 🃏 Model card for TRL (#2123)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* template and util
* test for online dpo
* template in package_data
* template in manifest
* standardize push_to_hub
* wandb badge and quick start
* bco
* xpo
* simplify `create_model_card`
* cpo
* kto
* dpo
* gkd
* orpo
* style
* nash-md
* alignprop
* bco citation
* citation template
* cpo citation
* ddpo
* fix alignprop
* dpo
* gkd citation
* kto
* online dpo citation
* orpo citation
* citation in utils
* optional citation
* reward
* optional trainer citation
* sft
* remove add_model_tags bco
* Remove unnecessary code for adding model tags
* Fix model tag issue and update URL format
* Remove unused code for adding model tags
* Add citation for XPOTrainer
* Remove unused code in SFTTrainer
* Add model card generation in RLOOTrainer
* Remove unused import and method call in reward_trainer.py
* Add model card generation
* Remove unused code and update error message in ORPOTrainer class
* Add import statements and create model card in IterativeSFTTrainer
* Add dataset name to push_to_hub() call
* Update trainer.push_to_hub() dataset names
* script args
* test
* better doc
* fix tag test
* fix test tag
* Add tags parameter to create_model_card method
* doc
* script args
* Update trl/templates/model_card.md

Co-authored-by: lewtun

* unittest's `assertIn` instead of `assert`

* Update trl/templates/model_card.md

Co-authored-by: lewtun

---------

Co-authored-by: lewtun
---
 MANIFEST.in                          |  3 +-
 examples/scripts/alignprop.py        |  2 +-
 examples/scripts/bco.py              |  2 +-
 examples/scripts/cpo.py              |  2 +-
 examples/scripts/ddpo.py             |  2 +-
 examples/scripts/dpo.py              |  2 +-
 examples/scripts/dpo_online.py       |  2 +-
 examples/scripts/dpo_visual.py       |  2 +-
 examples/scripts/gkd.py              |  2 +-
 examples/scripts/kto.py              |  2 +-
 examples/scripts/nash_md.py          |  2 +-
 examples/scripts/orpo.py             |  2 +-
 examples/scripts/ppo/ppo.py          |  2 +-
 examples/scripts/ppo/ppo_tldr.py     |  2 +-
 examples/scripts/reward_modeling.py  |  2 +-
 examples/scripts/rloo/rloo.py        |  2 +-
 examples/scripts/rloo/rloo_tldr.py   |  2 +-
 examples/scripts/sft.py              |  2 +-
 examples/scripts/sft_vlm.py          |  2 +-
 examples/scripts/xpo.py              |  2 +-
 setup.py                             |  2 +-
 tests/test_dpo_trainer.py            |  6 +-
 tests/test_sft_trainer.py            |  6 +-
 tests/test_utils.py                  | 45 ++++++++++++-
 trl/templates/lm_model_card.md       | 54 ++++++++++++++++
 trl/trainer/alignprop_trainer.py     | 91 +++++++++++++++------------
 trl/trainer/bco_trainer.py           | 59 +++++++++++++----
 trl/trainer/cpo_trainer.py           | 61 ++++++++++++++----
 trl/trainer/ddpo_trainer.py          | 94 ++++++++++++++++------------
 trl/trainer/dpo_trainer.py           | 62 ++++++++++++++----
 trl/trainer/gkd_trainer.py           | 61 +++++++++++++++++-
 trl/trainer/iterative_sft_trainer.py | 57 +++++++++++++----
 trl/trainer/kto_trainer.py           | 62 +++++++++++++-----
 trl/trainer/nash_md_trainer.py       | 60 +++++++++++++++++-
 trl/trainer/online_dpo_trainer.py    | 66 +++++++++++++++----
 trl/trainer/orpo_trainer.py          | 60 ++++++++++++++----
 trl/trainer/ppov2_trainer.py         | 68 ++++++++++++++++----
 trl/trainer/reward_trainer.py        | 60 +++++++++++++----
 trl/trainer/rloo_trainer.py          | 70 +++++++++++++++++----
 trl/trainer/sft_trainer.py           | 43 +++++++++++++
 trl/trainer/utils.py                 | 82 +++++++++++++++++++++++-
 trl/trainer/xpo_trainer.py           | 58 ++++++++++++++++-
 42 files
changed, 1023 insertions(+), 245 deletions(-) create mode 100644 trl/templates/lm_model_card.md diff --git a/MANIFEST.in b/MANIFEST.in index f0d7acb4da..26496e93f1 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,4 +2,5 @@ include settings.ini include LICENSE include CONTRIBUTING.md include README.md -recursive-exclude * __pycache__ \ No newline at end of file +recursive-exclude * __pycache__ +include trl/templates/*.md \ No newline at end of file diff --git a/examples/scripts/alignprop.py b/examples/scripts/alignprop.py index df1bd61a99..1948080f4b 100644 --- a/examples/scripts/alignprop.py +++ b/examples/scripts/alignprop.py @@ -132,4 +132,4 @@ def image_outputs_logger(image_pair_data, global_step, accelerate_logger): # Save and push to hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: - trainer.push_to_hub() + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/examples/scripts/bco.py b/examples/scripts/bco.py index d1f1f51ced..d00b039c21 100644 --- a/examples/scripts/bco.py +++ b/examples/scripts/bco.py @@ -164,4 +164,4 @@ def mean_pooling(model_output, attention_mask): # Save and push to hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: - trainer.push_to_hub() + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/examples/scripts/cpo.py b/examples/scripts/cpo.py index 286c92acf3..341ea67cac 100644 --- a/examples/scripts/cpo.py +++ b/examples/scripts/cpo.py @@ -121,4 +121,4 @@ def process(row): # Save and push to hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: - trainer.push_to_hub() + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/examples/scripts/ddpo.py b/examples/scripts/ddpo.py index 2359eb9a69..92924c51e4 100644 --- a/examples/scripts/ddpo.py +++ b/examples/scripts/ddpo.py @@ -212,4 +212,4 @@ def image_outputs_logger(image_data, global_step, accelerate_logger): # Save and push to hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: - trainer.push_to_hub() + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/examples/scripts/dpo.py b/examples/scripts/dpo.py index 6adf1e0f7f..5fe7ddf1ca 100644 --- a/examples/scripts/dpo.py +++ b/examples/scripts/dpo.py @@ -142,4 +142,4 @@ # Save and push to hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: - trainer.push_to_hub() + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/examples/scripts/dpo_online.py b/examples/scripts/dpo_online.py index 72f8f0bc28..73abbcd898 100644 --- a/examples/scripts/dpo_online.py +++ b/examples/scripts/dpo_online.py @@ -120,4 +120,4 @@ # Save and push to hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: - trainer.push_to_hub() + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/examples/scripts/dpo_visual.py b/examples/scripts/dpo_visual.py index d47ee7a5ad..fef224921f 100644 --- a/examples/scripts/dpo_visual.py +++ b/examples/scripts/dpo_visual.py @@ -135,4 +135,4 @@ def process(row): # Save and push to hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: - trainer.push_to_hub() + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/examples/scripts/gkd.py b/examples/scripts/gkd.py index cc018bb713..e79e94c85b 100644 --- a/examples/scripts/gkd.py +++ b/examples/scripts/gkd.py @@ -134,4 +134,4 @@ # Save and push to hub trainer.save_model(training_args.output_dir) if 
training_args.push_to_hub: - trainer.push_to_hub() + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/examples/scripts/kto.py b/examples/scripts/kto.py index 5cea2af7b2..aefdc812af 100644 --- a/examples/scripts/kto.py +++ b/examples/scripts/kto.py @@ -132,4 +132,4 @@ def format_dataset(example): # Save and push to hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: - trainer.push_to_hub() + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/examples/scripts/nash_md.py b/examples/scripts/nash_md.py index 185b0bceba..b9dd544103 100644 --- a/examples/scripts/nash_md.py +++ b/examples/scripts/nash_md.py @@ -123,4 +123,4 @@ # Save and push to hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: - trainer.push_to_hub() + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/examples/scripts/orpo.py b/examples/scripts/orpo.py index 8b6109ad33..521f86c129 100644 --- a/examples/scripts/orpo.py +++ b/examples/scripts/orpo.py @@ -122,4 +122,4 @@ def process(row): # Save and push to hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: - trainer.push_to_hub() + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/examples/scripts/ppo/ppo.py b/examples/scripts/ppo/ppo.py index e0697ed2d8..f257816544 100644 --- a/examples/scripts/ppo/ppo.py +++ b/examples/scripts/ppo/ppo.py @@ -133,6 +133,6 @@ def tokenize(element): # Save and push to hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: - trainer.push_to_hub() + trainer.push_to_hub(dataset_name="trl-internal-testing/descriptiveness-sentiment-trl-style") trainer.generate_completions() diff --git a/examples/scripts/ppo/ppo_tldr.py b/examples/scripts/ppo/ppo_tldr.py index 7050d71f69..58968046fb 100644 --- a/examples/scripts/ppo/ppo_tldr.py +++ b/examples/scripts/ppo/ppo_tldr.py @@ -138,6 +138,6 @@ def tokenize(element): # Save and push to hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: - trainer.push_to_hub() + trainer.push_to_hub(dataset_name="trl-internal-testing/tldr-preference-sft-trl-style") trainer.generate_completions() diff --git a/examples/scripts/reward_modeling.py b/examples/scripts/reward_modeling.py index 5ba483137e..3fae956119 100644 --- a/examples/scripts/reward_modeling.py +++ b/examples/scripts/reward_modeling.py @@ -130,4 +130,4 @@ # Save and push to hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: - trainer.push_to_hub() + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/examples/scripts/rloo/rloo.py b/examples/scripts/rloo/rloo.py index 8fd75d2420..5f85a9580b 100644 --- a/examples/scripts/rloo/rloo.py +++ b/examples/scripts/rloo/rloo.py @@ -133,6 +133,6 @@ def tokenize(element): # Save and push to hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: - trainer.push_to_hub() + trainer.push_to_hub(dataset_name="trl-internal-testing/descriptiveness-sentiment-trl-style") trainer.generate_completions() diff --git a/examples/scripts/rloo/rloo_tldr.py b/examples/scripts/rloo/rloo_tldr.py index 0f76a5af75..0a86381e0b 100644 --- a/examples/scripts/rloo/rloo_tldr.py +++ b/examples/scripts/rloo/rloo_tldr.py @@ -137,6 +137,6 @@ def tokenize(element): # Save and push to hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: - trainer.push_to_hub() + trainer.push_to_hub(dataset_name="trl-internal-testing/tldr-preference-sft-trl-style") 
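All of the example-script hunks above and below make the same one-line change: the dataset identifier the script already receives is forwarded to `push_to_hub()`, so the generated model card can name and link the training dataset. Below is a minimal sketch of the resulting script pattern, not taken from the patch; the `ScriptArguments` dataclass, model id, and dataset id are illustrative placeholders.

```python
from dataclasses import dataclass

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from trl import KTOConfig, KTOTrainer


@dataclass
class ScriptArguments:
    # Hypothetical CLI argument, mirroring the `script_args.dataset_name` used by the examples.
    dataset_name: str = "username/my-kto-dataset"


parser = HfArgumentParser((ScriptArguments, KTOConfig))
script_args, training_args = parser.parse_args_into_dataclasses()

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
dataset = load_dataset(script_args.dataset_name, split="train")

trainer = KTOTrainer(model=model, args=training_args, train_dataset=dataset, tokenizer=tokenizer)
trainer.train()

# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
    # New in this patch: forwarding the dataset name lets the generated README
    # reference and link the dataset the model was trained on.
    trainer.push_to_hub(dataset_name=script_args.dataset_name)
```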
trainer.generate_completions() diff --git a/examples/scripts/sft.py b/examples/scripts/sft.py index bc3158ddc1..068bdd36d7 100644 --- a/examples/scripts/sft.py +++ b/examples/scripts/sft.py @@ -106,4 +106,4 @@ # Save and push to hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: - trainer.push_to_hub() + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/examples/scripts/sft_vlm.py b/examples/scripts/sft_vlm.py index f61f51575e..48db9276d2 100644 --- a/examples/scripts/sft_vlm.py +++ b/examples/scripts/sft_vlm.py @@ -129,6 +129,6 @@ def collate_fn(examples): # Save and push to hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: - trainer.push_to_hub() + trainer.push_to_hub(dataset_name=script_args.dataset_name) if trainer.accelerator.is_main_process: processor.push_to_hub(training_args.hub_model_id) diff --git a/examples/scripts/xpo.py b/examples/scripts/xpo.py index b811ea0ea1..235935e593 100644 --- a/examples/scripts/xpo.py +++ b/examples/scripts/xpo.py @@ -107,4 +107,4 @@ # Save and push to hub trainer.save_model(training_args.output_dir) if training_args.push_to_hub: - trainer.push_to_hub() + trainer.push_to_hub(dataset_name=script_args.dataset_name) diff --git a/setup.py b/setup.py index 0e801c7078..ffeb402edf 100644 --- a/setup.py +++ b/setup.py @@ -132,7 +132,7 @@ "console_scripts": ["trl=trl.commands.cli:main"], }, include_package_data=True, - package_data={"trl": ["commands/scripts/config/*", "commands/scripts/*"]}, + package_data={"trl": ["commands/scripts/config/*", "commands/scripts/*", "templates/*.md"]}, packages=find_packages(exclude={"tests"}), install_requires=REQUIRED_PKGS, extras_require=EXTRAS, diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 7d6e5c4670..4518718f1c 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -782,7 +782,8 @@ def test_dpo_lora_tags(self): peft_config=lora_config, ) - assert trainer.model.model_tags == trainer._tag_names + for tag in ["dpo", "trl"]: + self.assertIn(tag, trainer.model.model_tags) @require_peft def test_dpo_tags(self): @@ -817,7 +818,8 @@ def test_dpo_tags(self): eval_dataset=dummy_dataset["test"], ) - assert trainer.model.model_tags == trainer._tag_names + for tag in ["dpo", "trl"]: + self.assertIn(tag, trainer.model.model_tags) @require_peft def test_dpo_lora_force_use_ref(self): diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index d7fea0a81e..5b499b982c 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -1055,7 +1055,8 @@ def test_peft_sft_trainer_tag(self): peft_config=peft_config, ) - assert trainer.model.model_tags == trainer._tag_names + for tag in ["sft", "trl"]: + self.assertIn(tag, trainer.model.model_tags) @require_peft def test_sft_trainer_tag(self): @@ -1080,7 +1081,8 @@ def test_sft_trainer_tag(self): eval_dataset=self.eval_dataset, ) - assert trainer.model.model_tags == trainer._tag_names + for tag in ["sft", "trl"]: + self.assertIn(tag, trainer.model.model_tags) def test_sft_trainer_only_train_packing(self): with tempfile.TemporaryDirectory() as tmp_dir: diff --git a/tests/test_utils.py b/tests/test_utils.py index 8acb554947..d23e18c841 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -20,7 +20,7 @@ from transformers.utils import is_peft_available from trl.trainer.model_config import ModelConfig -from trl.trainer.utils import decode_and_strip_padding, get_peft_config, pad +from trl.trainer.utils import decode_and_strip_padding, 
generate_model_card, get_peft_config, pad if is_peft_available(): @@ -126,3 +126,46 @@ def test_example_without_padding(self): inputs = self.tokenizer(["Hello", "Hello"], padding=False, return_tensors="pt") decoded = decode_and_strip_padding(inputs["input_ids"], self.tokenizer) self.assertEqual(decoded, ["Hello", "Hello"]) + + +class TestGenerateModelCard(unittest.TestCase): + def test_full(self): + model_card = generate_model_card( + base_model="username/my_base_model", + model_name="my_model", + hub_model_id="username/my_hub_model", + dataset_name="username/my_dataset", + tags=["trl", "trainer-tag"], + wandb_url="https://wandb.ai/username/project_id/runs/abcd1234", + trainer_name="My Trainer", + trainer_citation="@article{my_trainer, ...}", + paper_title="My Paper", + paper_id="1234.56789", + ) + card_text = str(model_card) + assert "[username/my_base_model](https://huggingface.co/username/my_base_model)" in card_text + assert "my_model" in card_text + assert 'pipeline("text-generation", model="username/my_hub_model", device="cuda")' in card_text + assert "datasets: username/my_dataset" in card_text + assert "](https://wandb.ai/username/project_id/runs/abcd1234)" in card_text + assert "My Trainer" in card_text + assert "```bibtex\n@article{my_trainer, ...}\n```" in card_text + assert "[My Paper](https://huggingface.co/papers/1234.56789)" in card_text + + def test_val_none(self): + model_card = generate_model_card( + base_model=None, + model_name="my_model", + hub_model_id="username/my_hub_model", + dataset_name=None, + tags=None, + wandb_url=None, + trainer_name="My Trainer", + trainer_citation=None, + paper_title=None, + paper_id=None, + ) + card_text = str(model_card) + assert "my_model" in card_text + assert 'pipeline("text-generation", model="username/my_hub_model", device="cuda")' in card_text + assert "My Trainer" in card_text diff --git a/trl/templates/lm_model_card.md b/trl/templates/lm_model_card.md new file mode 100644 index 0000000000..316c5d829e --- /dev/null +++ b/trl/templates/lm_model_card.md @@ -0,0 +1,54 @@ +--- +{{ card_data }} +--- + +# Model Card for {{ model_name }} + +This model is a fine-tuned version of [{{ base_model }}](https://huggingface.co/{{ base_model }}){% if dataset_name %} on the [{{ dataset_name }}](https://huggingface.co/datasets/{{ dataset_name }}) dataset{% endif %}. +It has been trained using [TRL](https://github.com/huggingface/trl). + +## Quick start + +```python +from transformers import pipeline + +question = "If you had a time machine, but could only go to the past or the future once and never return, which would you choose and why?" +generator = pipeline("text-generation", model="{{ hub_model_id }}", device="cuda") +output = generator([{"role": "user", "content": question}], max_new_tokens=128, return_full_text=False)[0] +print(output["generated_text"]) +``` + +## Training procedure + +{% if wandb_url %}[Visualize in Weights & Biases]({{ wandb_url }}){% endif %} + +This model was trained with {{ trainer_name }}{% if paper_id %}, a method introduced in [{{ paper_title }}](https://huggingface.co/papers/{{ paper_id }}){% endif %}. 
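The placeholders in this template are filled in by the new `generate_model_card` helper in `trl/trainer/utils.py`, exercised by `TestGenerateModelCard` earlier in this patch. A hedged usage sketch, assuming only the keyword arguments shown in that test and the `.save()` call the trainers use further below; the concrete values are illustrative:

```python
from trl.trainer.utils import generate_model_card

card = generate_model_card(
    base_model="Qwen/Qwen2-0.5B-Instruct",   # rendered as the "fine-tuned version of" link
    model_name="Qwen2-0.5B-DPO",             # "# Model Card for ..." heading
    hub_model_id="username/Qwen2-0.5B-DPO",  # used in the Quick start pipeline() snippet
    dataset_name="username/my_dataset",      # linked dataset, also listed in the card metadata
    tags=["trl", "dpo"],
    wandb_url=None,                          # the "Visualize in Weights & Biases" line is omitted when None
    trainer_name="DPO",
    trainer_citation=None,                   # optional BibTeX block in the Citations section
    paper_title="Direct Preference Optimization: Your Language Model is Secretly a Reward Model",
    paper_id="2305.18290",
)

card.save("README.md")  # the trainers save it to args.output_dir/README.md
print(str(card))        # the rendered Markdown, front matter included
```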
+ +### Framework versions + +- TRL: {{ trl_version }} +- Transformers: {{ transformers_version }} +- Pytorch: {{ pytorch_version }} +- Datasets: {{ datasets_version }} +- Tokenizers: {{ tokenizers_version }} + +## Citations + +{% if trainer_citation %}Cite {{ trainer_name }} as: + +```bibtex +{{ trainer_citation }} +```{% endif %} + +Cite TRL as: + +```bibtex +{% raw %}@misc{vonwerra2022trl, + title = {{TRL: Transformer Reinforcement Learning}}, + author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec}, + year = 2020, + journal = {GitHub repository}, + publisher = {GitHub}, + howpublished = {\url{https://github.com/huggingface/trl}} +}{% endraw %} +``` diff --git a/trl/trainer/alignprop_trainer.py b/trl/trainer/alignprop_trainer.py index 404e23408c..19342597da 100644 --- a/trl/trainer/alignprop_trainer.py +++ b/trl/trainer/alignprop_trainer.py @@ -12,40 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -import warnings +import textwrap from collections import defaultdict -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple, Union from warnings import warn import torch from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed -from huggingface_hub import whoami +from transformers import is_wandb_available from ..models import DDPOStableDiffusionPipeline from . import AlignPropConfig, BaseTrainer +from .utils import generate_model_card -logger = get_logger(__name__) - -MODEL_CARD_TEMPLATE = """--- -license: apache-2.0 -library_name: transformers -tags: -- trl -- alignprop -- diffusers -- reinforcement-learning -- text-to-image -- stable-diffusion ---- - -# {model_name} - -This is a pipeline that finetunes a diffusion model with reward backpropagation while using randomized truncation (https://huggingface.co/papers/2310.03739). The model can be used for image generation conditioned with text. +if is_wandb_available(): + import wandb -""" +logger = get_logger(__name__) class AlignPropTrainer(BaseTrainer): @@ -400,27 +386,54 @@ def train(self, epochs: Optional[int] = None): for epoch in range(self.first_epoch, epochs): global_step = self.step(epoch, global_step) - def create_model_card(self, path: str, model_name: Optional[str] = "TRL AlignProp Model") -> None: - """Creates and saves a model card for a TRL model. + def _save_pretrained(self, save_directory): + self.sd_pipeline.save_pretrained(save_directory) + self.create_model_card() + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, List[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. Args: - path (`str`): The path to save the model card to. - model_name (`str`, *optional*): The name of the model, defaults to `TRL AlignProp Model`. + model_name (`str`, *optional*, defaults to `None`): + The name of the model. + dataset_name (`str`, *optional*, defaults to `None`): + The name of the dataset used for training. + tags (`str`, `List[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. 
""" - try: - user = whoami()["name"] - # handle the offline case - except Exception: - warnings.warn("Cannot retrieve user information assuming you are running in offline mode.") + if not self.is_world_process_zero(): return - if not os.path.exists(path): - os.makedirs(path) - - model_card_content = MODEL_CARD_TEMPLATE.format(model_name=model_name, model_id=f"{user}/{path}") - with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f: - f.write(model_card_content) + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + citation = textwrap.dedent("""\ + @article{prabhudesai2024aligning, + title = {{Aligning Text-to-Image Diffusion Models with Reward Backpropagation}}, + author = {Mihir Prabhudesai and Anirudh Goyal and Deepak Pathak and Katerina Fragkiadaki}, + year = 2024, + eprint = {arXiv:2310.03739} + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + trainer_name="AlignProp", + trainer_citation=citation, + paper_title="Aligning Text-to-Image Diffusion Models with Reward Backpropagation", + paper_id="2310.03739", + ) - def _save_pretrained(self, save_directory): - self.sd_pipeline.save_pretrained(save_directory) - self.create_model_card(save_directory) + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/bco_trainer.py b/trl/trainer/bco_trainer.py index d062ef6371..a836e3b928 100644 --- a/trl/trainer/bco_trainer.py +++ b/trl/trainer/bco_trainer.py @@ -15,11 +15,11 @@ import inspect import os import random +import textwrap import warnings from collections import defaultdict from contextlib import contextmanager, nullcontext from copy import deepcopy -from functools import wraps from operator import itemgetter from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union @@ -53,9 +53,9 @@ DPODataCollatorWithPadding, RunningMoments, disable_dropout_in_model, + generate_model_card, pad_to_length, peft_module_casting_to_bf16, - trl_sanitze_kwargs_for_tagging, ) @@ -1450,17 +1450,50 @@ def log(self, logs: Dict[str, float]) -> None: del self._stored_metrics[train_eval] return super().log(logs) - @wraps(Trainer.push_to_hub) - def push_to_hub( + def create_model_card( self, - commit_message: Optional[str] = "End of training", - blocking: bool = True, - **kwargs, - ) -> str: + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, List[str], None] = None, + ): """ - Overwrite the `push_to_hub` method in order to force-add the tag "bco" when pushing the - model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. - Unlike the parent class, we don't use the `token` argument to mitigate security risks. + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str`, *optional*, defaults to `None`): + The name of the model. + dataset_name (`str`, *optional*, defaults to `None`): + The name of the dataset used for training. + tags (`str`, `List[str]` or None, *optional*, defaults to `None`): + Tags to be associated with the model card. 
""" - kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs) - return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs) + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + citation = textwrap.dedent("""\ + @article{jung2024binary, + title = {{Binary Classifier Optimization for Large Language Model Alignment}}, + author = {Seungjae Jung and Gunsoo Han and Daniel Wontae Nam and Kyoung{-}Woon On}, + year = 2024, + eprint = {arXiv:2404.04656} + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + trainer_name="BCO", + trainer_citation=citation, + paper_title="Binary Classifier Optimization for Large Language Model Alignment", + paper_id="2404.04656", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/cpo_trainer.py b/trl/trainer/cpo_trainer.py index 22b60aa940..738f563323 100644 --- a/trl/trainer/cpo_trainer.py +++ b/trl/trainer/cpo_trainer.py @@ -14,11 +14,12 @@ # limitations under the License. import inspect +import os import random +import textwrap import warnings from collections import defaultdict from contextlib import nullcontext -from functools import wraps from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import numpy as np @@ -47,9 +48,9 @@ add_bos_token_if_needed, add_eos_token_if_needed, disable_dropout_in_model, + generate_model_card, pad_to_length, peft_module_casting_to_bf16, - trl_sanitze_kwargs_for_tagging, ) @@ -971,17 +972,51 @@ def _shift_right(self, input_ids): return shifted_input_ids - @wraps(Trainer.push_to_hub) - def push_to_hub( + def create_model_card( self, - commit_message: Optional[str] = "End of training", - blocking: bool = True, - **kwargs, - ) -> str: + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, List[str], None] = None, + ): """ - Overwrite the `push_to_hub` method in order to force-add the tag "cpo" when pushing the - model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. - Unlike the parent class, we don't use the `token` argument to mitigate security risks. + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str`, *optional*, defaults to `None`): + The name of the model. + dataset_name (`str`, *optional*, defaults to `None`): + The name of the dataset used for training. + tags (`str`, `List[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. 
""" - kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs) - return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs) + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + citation = textwrap.dedent("""\ + @inproceedings{xu2024contrastive, + title = {{Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation}}, + author = {Haoran Xu and Amr Sharaf and Yunmo Chen and Weiting Tan and Lingfeng Shen and Benjamin Van Durme and Kenton Murray and Young Jin Kim}, + year = 2024, + booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024}, + publisher = {OpenReview.net}, + url = {https://openreview.net/forum?id=51iwkioZpn} + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + trainer_name="CPO", + trainer_citation=citation, + paper_title="Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation", + paper_id="2401.08417", + ) + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/ddpo_trainer.py b/trl/trainer/ddpo_trainer.py index 19ba5142e6..df6cd94af1 100644 --- a/trl/trainer/ddpo_trainer.py +++ b/trl/trainer/ddpo_trainer.py @@ -13,43 +13,28 @@ # limitations under the License. import os -import warnings +import textwrap from collections import defaultdict from concurrent import futures -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, List, Optional, Tuple, Union from warnings import warn import torch from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed -from huggingface_hub import whoami +from transformers import is_wandb_available from ..models import DDPOStableDiffusionPipeline from . import BaseTrainer, DDPOConfig -from .utils import PerPromptStatTracker +from .utils import PerPromptStatTracker, generate_model_card -logger = get_logger(__name__) - -MODEL_CARD_TEMPLATE = """--- -license: apache-2.0 -library_name: transformers -tags: -- trl -- ddpo -- diffusers -- reinforcement-learning -- text-to-image -- stable-diffusion ---- +if is_wandb_available(): + import wandb -# {model_name} -This is a diffusion model that has been fine-tuned with reinforcement learning to - guide the model outputs according to a value, function, or human feedback. The model can be used for image generation conditioned with text. - -""" +logger = get_logger(__name__) class DDPOTrainer(BaseTrainer): @@ -603,27 +588,56 @@ def train(self, epochs: Optional[int] = None): for epoch in range(self.first_epoch, epochs): global_step = self.step(epoch, global_step) - def create_model_card(self, path: str, model_name: Optional[str] = "TRL DDPO Model") -> None: - """Creates and saves a model card for a TRL model. 
+ def _save_pretrained(self, save_directory): + self.sd_pipeline.save_pretrained(save_directory) + self.create_model_card() + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, List[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. Args: - path (`str`): The path to save the model card to. - model_name (`str`, *optional*): The name of the model, defaults to `TRL DDPO Model`. + model_name (`str`, *optional*, defaults to `None`): + The name of the model. + dataset_name (`str`, *optional*, defaults to `None`): + The name of the dataset used for training. + tags (`str`, `List[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. """ - try: - user = whoami()["name"] - # handle the offline case - except Exception: - warnings.warn("Cannot retrieve user information assuming you are running in offline mode.") + if not self.is_world_process_zero(): return - if not os.path.exists(path): - os.makedirs(path) - - model_card_content = MODEL_CARD_TEMPLATE.format(model_name=model_name, model_id=f"{user}/{path}") - with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f: - f.write(model_card_content) + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + citation = textwrap.dedent("""\ + @inproceedings{black2024training, + title = {{Training Diffusion Models with Reinforcement Learning}}, + author = {Kevin Black and Michael Janner and Yilun Du and Ilya Kostrikov and Sergey Levine}, + year = 2024, + booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024}, + publisher = {OpenReview.net}, + url = {https://openreview.net/forum?id=YCWjhGrJFD}, + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + trainer_name="DDPO", + trainer_citation=citation, + paper_title="Training Diffusion Models with Reinforcement Learning", + paper_id="2305.13301", + ) - def _save_pretrained(self, save_directory): - self.sd_pipeline.save_pretrained(save_directory) - self.create_model_card(save_directory) + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 3c9ff4624b..e04e4693d8 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
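Each trainer in this patch now ships a `create_model_card` override like the two above, in place of the old `push_to_hub` override. The method can also be called directly after training; when `trainer.push_to_hub(dataset_name=...)` is used instead, `transformers.Trainer.push_to_hub` forwards its remaining keyword arguments to `create_model_card`, which is what makes the one-line script changes above work. A hedged sketch, reusing the `trainer`, `training_args`, and `script_args` objects from the earlier example-script sketch; the model name and tag are illustrative:

```python
# Write the generated card locally without pushing; push_to_hub() produces the
# same README before uploading it.
trainer.create_model_card(
    model_name="Qwen2-0.5B-KTO",
    dataset_name=script_args.dataset_name,
    tags=["experiment-1"],
)
# The card lands at f"{training_args.output_dir}/README.md".
```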
import inspect +import os import random +import textwrap import warnings from collections import defaultdict from contextlib import contextmanager, nullcontext from copy import deepcopy -from functools import wraps from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import torch @@ -53,9 +54,9 @@ add_eos_token_if_needed, cap_exp, disable_dropout_in_model, + generate_model_card, pad_to_length, peft_module_casting_to_bf16, - trl_sanitze_kwargs_for_tagging, ) @@ -1677,17 +1678,52 @@ def log(self, logs: Dict[str, float]) -> None: del self._stored_metrics[train_eval] return super().log(logs) - @wraps(Trainer.push_to_hub) - def push_to_hub( + def create_model_card( self, - commit_message: Optional[str] = "End of training", - blocking: bool = True, - **kwargs, - ) -> str: + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, List[str], None] = None, + ): """ - Overwrite the `push_to_hub` method in order to force-add the tag "dpo" when pushing the - model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. - Unlike the parent class, we don't use the `token` argument to mitigate security risks. + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str`, *optional*, defaults to `None`): + The name of the model. + dataset_name (`str`, *optional*, defaults to `None`): + The name of the dataset used for training. + tags (`str`, `List[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. """ - kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs) - return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs) + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + citation = textwrap.dedent("""\ + @inproceedings{rafailov2023direct, + title = {{Direct Preference Optimization: Your Language Model is Secretly a Reward Model}}, + author = {Rafael Rafailov and Archit Sharma and Eric Mitchell and Christopher D. Manning and Stefano Ermon and Chelsea Finn}, + year = 2023, + booktitle = {Advances in Neural Information Processing Systems 36: Annual Conference on Neural Information Processing Systems 2023, NeurIPS 2023, New Orleans, LA, USA, December 10 - 16, 2023}, + url = {http://papers.nips.cc/paper_files/paper/2023/hash/a85b405ed65c6477a4fe8302b5e06ce7-Abstract-Conference.html}, + editor = {Alice Oh and Tristan Naumann and Amir Globerson and Kate Saenko and Moritz Hardt and Sergey Levine}, + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + trainer_name="DPO", + trainer_citation=citation, + paper_title="Direct Preference Optimization: Your Language Model is Secretly a Reward Model", + paper_id="2305.18290", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py index 73e4303b9c..860cc18b80 100644 --- a/trl/trainer/gkd_trainer.py +++ b/trl/trainer/gkd_trainer.py @@ -11,23 +11,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +import os import random +import textwrap import warnings from copy import deepcopy -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Union import torch import torch.nn as nn import torch.nn.functional as F from accelerate.utils import is_deepspeed_available -from transformers import AutoModelForCausalLM, GenerationConfig, PreTrainedModel +from transformers import AutoModelForCausalLM, GenerationConfig, PreTrainedModel, is_wandb_available from ..import_utils import is_liger_kernel_available from ..models import PreTrainedModelWrapper from ..models.utils import unwrap_model_for_generation from .gkd_config import GKDConfig from .sft_trainer import SFTTrainer -from .utils import DataCollatorForChatML, disable_dropout_in_model, empty_cache +from .utils import DataCollatorForChatML, disable_dropout_in_model, empty_cache, generate_model_card if is_deepspeed_available(): @@ -36,6 +38,9 @@ if is_liger_kernel_available(): from liger_kernel.transformers import AutoLigerKernelForCausalLM +if is_wandb_available(): + import wandb + class GKDTrainer(SFTTrainer): _tag_names = ["trl", "gkd"] @@ -259,3 +264,53 @@ def _prepare_deepspeed(self, model: PreTrainedModelWrapper): model, *_ = deepspeed.initialize(model=model, config=config_kwargs) model.eval() return model + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, List[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str`, *optional*, defaults to `None`): + The name of the model. + dataset_name (`str`, *optional*, defaults to `None`): + The name of the dataset used for training. + tags (`str`, `List[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. 
+ """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + citation = textwrap.dedent("""\ + @inproceedings{agarwal2024on-policy, + title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}}, + author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem}, + year = 2024, + booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024}, + publisher = {OpenReview.net}, + url = {https://openreview.net/forum?id=3zKtaqxLhW}, + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + trainer_name="GKD", + trainer_citation=citation, + paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes", + paper_id="2306.13649", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/iterative_sft_trainer.py b/trl/trainer/iterative_sft_trainer.py index 81161f66db..0e0d19f08e 100644 --- a/trl/trainer/iterative_sft_trainer.py +++ b/trl/trainer/iterative_sft_trainer.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import warnings -from functools import wraps from typing import Callable, Dict, List, Optional, Tuple, Union import torch @@ -26,18 +26,23 @@ PreTrainedTokenizerBase, Trainer, TrainingArguments, + is_wandb_available, ) from transformers.trainer_utils import EvalLoopOutput from transformers.utils import is_peft_available from ..core import PPODecorators -from .utils import trl_sanitze_kwargs_for_tagging +from .utils import generate_model_card if is_peft_available(): from peft import PeftModel +if is_wandb_available(): + import wandb + + class IterativeSFTTrainer(Trainer): """ The IterativeSFTTrainer can be used to finetune models with methods that requires some steps between optimization. @@ -138,6 +143,10 @@ def __init__( preprocess_logits_for_metrics=preprocess_logits_for_metrics, ) + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + self.create_optimizer_and_scheduler(self.args.max_steps) # prepare model, optimizer and lr_scheduler @@ -381,17 +390,39 @@ def _maybe_log_save_evaluate(self): self.log(logs) - @wraps(Trainer.push_to_hub) - def push_to_hub( + def create_model_card( self, - commit_message: Optional[str] = "End of training", - blocking: bool = True, - **kwargs, - ) -> str: + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, List[str], None] = None, + ): """ - Overwrite the `push_to_hub` method in order to force-add the tag "iterative-sft" when pushing the - model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. - Unlike the parent class, we don't use the `token` argument to mitigate security risks. + Creates a draft of a model card using the information available to the `Trainer`. 
+ + Args: + model_name (`str`, *optional*, defaults to `None`): + The name of the model. + dataset_name (`str`, *optional*, defaults to `None`): + The name of the dataset used for training. + tags (`str`, `List[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. """ - kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs) - return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs) + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + trainer_name="Iterative SFT", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/kto_trainer.py b/trl/trainer/kto_trainer.py index 172d53fc58..b670db9c6e 100644 --- a/trl/trainer/kto_trainer.py +++ b/trl/trainer/kto_trainer.py @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect +import os import random +import textwrap import warnings from collections import defaultdict from contextlib import contextmanager, nullcontext from copy import deepcopy -from functools import wraps from operator import itemgetter from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union @@ -37,10 +38,10 @@ PreTrainedModel, PreTrainedTokenizerBase, Trainer, + TrainerCallback, TrainingArguments, is_wandb_available, ) -from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import EvalLoopOutput, has_length from transformers.utils import is_peft_available @@ -49,9 +50,9 @@ from .utils import ( DPODataCollatorWithPadding, disable_dropout_in_model, + generate_model_card, pad_to_length, peft_module_casting_to_bf16, - trl_sanitze_kwargs_for_tagging, ) @@ -1427,17 +1428,50 @@ def log(self, logs: Dict[str, float]) -> None: del self._stored_metrics[train_eval] return super().log(logs) - @wraps(Trainer.push_to_hub) - def push_to_hub( + def create_model_card( self, - commit_message: Optional[str] = "End of training", - blocking: bool = True, - **kwargs, - ) -> str: + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, List[str], None] = None, + ): """ - Overwrite the `push_to_hub` method in order to force-add the tag "kto" when pushing the - model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. - Unlike the parent class, we don't use the `token` argument to mitigate security risks. + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str`, *optional*, defaults to `None`): + The name of the model. + dataset_name (`str`, *optional*, defaults to `None`): + The name of the dataset used for training. + tags (`str`, `List[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. 
""" - kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs) - return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs) + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + citation = textwrap.dedent("""\ + @article{ethayarajh2024kto, + title = {{KTO: Model Alignment as Prospect Theoretic Optimization}}, + author = {Kawin Ethayarajh and Winnie Xu and Niklas Muennighoff and Dan Jurafsky and Douwe Kiela}, + year = 2024, + eprint = {arXiv:2402.01306}, + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + trainer_name="KTO", + trainer_citation=citation, + paper_title="KTO: Model Alignment as Prospect Theoretic Optimization", + paper_id="2402.01306", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/nash_md_trainer.py b/trl/trainer/nash_md_trainer.py index d0f9ecb924..b15c0d5c59 100644 --- a/trl/trainer/nash_md_trainer.py +++ b/trl/trainer/nash_md_trainer.py @@ -12,13 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import textwrap from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from datasets import Dataset, IterableDataset -from transformers import PreTrainedTokenizerBase, TrainerCallback +from transformers import PreTrainedTokenizerBase, TrainerCallback, is_wandb_available from transformers.modeling_utils import PreTrainedModel from transformers.trainer_utils import EvalPrediction from transformers.training_args import OptimizerNames @@ -29,13 +31,17 @@ from ..models.utils import unwrap_model_for_generation from .nash_md_config import NashMDConfig from .online_dpo_trainer import OnlineDPOTrainer -from .utils import empty_cache, get_reward, truncate_right +from .utils import empty_cache, generate_model_card, get_reward, truncate_right if is_apex_available(): from apex import amp +if is_wandb_available(): + import wandb + + class NashMDTrainer(OnlineDPOTrainer): r""" Initialize NashMDTrainer as a subclass of [`OnlineDPOConfig`]. @@ -381,3 +387,53 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, self.accelerator.backward(loss, **kwargs) return loss.detach() / self.args.gradient_accumulation_steps + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, List[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str`, *optional*, defaults to `None`): + The name of the model. + dataset_name (`str`, *optional*, defaults to `None`): + The name of the dataset used for training. + tags (`str`, `List[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. 
+ """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + citation = textwrap.dedent("""\ + @inproceedings{munos2024nash, + title = {Nash Learning from Human Feedback}, + author = {R{\'{e}}mi Munos and Michal Valko and Daniele Calandriello and Mohammad Gheshlaghi Azar and Mark Rowland and Zhaohan Daniel Guo and Yunhao Tang and Matthieu Geist and Thomas Mesnard and C{\\^{o}}me Fiegel and Andrea Michi and Marco Selvi and Sertan Girgin and Nikola Momchev and Olivier Bachem and Daniel J. Mankowitz and Doina Precup and Bilal Piot}, + year = 2024, + booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024}, + publisher = {OpenReview.net}, + url = {https://openreview.net/forum?id=Y5AmNYiyCQ} + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + trainer_name="Nash-MD", + trainer_citation=citation, + paper_title="Nash Learning from Human Feedback", + paper_id="2312.00886", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 28e285421d..f495f5af1c 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import textwrap import warnings from functools import wraps from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -31,6 +33,7 @@ Trainer, TrainerCallback, is_apex_available, + is_wandb_available, ) from transformers.modeling_utils import PreTrainedModel from transformers.trainer_utils import EvalPrediction, seed_worker @@ -46,9 +49,9 @@ DPODataCollatorWithPadding, disable_dropout_in_model, empty_cache, + generate_model_card, get_reward, prepare_deepspeed, - trl_sanitze_kwargs_for_tagging, truncate_right, ) @@ -68,6 +71,9 @@ else: IS_SAGEMAKER_MP_POST_1_10 = False +if is_wandb_available(): + import wandb + logger = logging.get_logger(__name__) @@ -237,6 +243,10 @@ def __init__( preprocess_logits_for_metrics=preprocess_logits_for_metrics, ) + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + self._beta = args.beta # Placed after the super().__init__ because we need self.is_deepspeed_enabled and self.accelerator @@ -538,17 +548,49 @@ def _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, igno self._save_checkpoint(model, trial, metrics=metrics) self.control = self.callback_handler.on_save(self.args, self.state, self.control) - @wraps(Trainer.push_to_hub) - def push_to_hub( + def create_model_card( self, - commit_message: Optional[str] = "End of training", - blocking: bool = True, - **kwargs, - ) -> str: + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, List[str], None] = None, + ): """ - Overwrite the `push_to_hub` method in order to force-add the tag "online-dpo" when pushing the - model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. 
- Unlike the parent class, we don't use the `token` argument to mitigate security risks. + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str`, *optional*, defaults to `None`): + The name of the model. + dataset_name (`str`, *optional*, defaults to `None`): + The name of the dataset used for training. + tags (`str`, `List[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. """ - kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs) - return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs) + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + citation = textwrap.dedent("""\ + @article{guo2024direct, + title = {{Direct Language Model Alignment from Online AI Feedback}}, + author = {Shangmin Guo and Biao Zhang and Tianlin Liu and Tianqi Liu and Misha Khalman and Felipe Llinares and Alexandre Ram{\'{e}} and Thomas Mesnard and Yao Zhao and Bilal Piot and Johan Ferret and Mathieu Blondel}, + year = 2024, + eprint = {arXiv:2402.04792} + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + trainer_name="Online DPO", + trainer_citation=citation, + paper_title="Direct Language Model Alignment from Online AI Feedback", + paper_id="2402.04792", + ) + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index 3e0d8a3d16..7aa991ad93 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -15,12 +15,13 @@ # limitations under the License. import inspect +import os import random +import textwrap import warnings from collections import defaultdict from contextlib import nullcontext from copy import deepcopy -from functools import wraps from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union import numpy as np @@ -52,9 +53,9 @@ add_bos_token_if_needed, add_eos_token_if_needed, disable_dropout_in_model, + generate_model_card, pad_to_length, peft_module_casting_to_bf16, - trl_sanitze_kwargs_for_tagging, ) @@ -993,17 +994,50 @@ def _shift_right(self, input_ids): return shifted_input_ids - @wraps(Trainer.push_to_hub) - def push_to_hub( + def create_model_card( self, - commit_message: Optional[str] = "End of training", - blocking: bool = True, - **kwargs, - ) -> str: + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, List[str], None] = None, + ): """ - Overwrite the `push_to_hub` method in order to force-add the tag "orpo" when pushing the - model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. - Unlike the parent class, we don't use the `token` argument to mitigate security risks. + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str`, *optional*, defaults to `None`): + The name of the model. + dataset_name (`str`, *optional*, defaults to `None`): + The name of the dataset used for training. 
+ tags (`str`, `List[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. """ - kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs) - return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs) + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + citation = textwrap.dedent("""\ + @article{hong2024orpo, + title = {{ORPO: Monolithic Preference Optimization without Reference Model}}, + author = {Jiwoo Hong and Noah Lee and James Thorne}, + year = 2024, + eprint = {arXiv:2403.07691} + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + trainer_name="ORPO", + trainer_citation=citation, + paper_title="ORPO: Monolithic Preference Optimization without Reference Model", + paper_id="2403.07691", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/ppov2_trainer.py b/trl/trainer/ppov2_trainer.py index 44e9b425ef..e6fb9d70ef 100644 --- a/trl/trainer/ppov2_trainer.py +++ b/trl/trainer/ppov2_trainer.py @@ -15,9 +15,9 @@ import gc import math import os +import textwrap import time from collections import defaultdict -from functools import wraps from typing import Dict, List, Optional, Tuple, Union import numpy as np @@ -36,6 +36,7 @@ Trainer, TrainerCallback, TrainerControl, + is_wandb_available, ) from transformers.integrations import get_reporting_integration_callbacks from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK @@ -56,7 +57,11 @@ truncate_response, ) from .ppov2_config import PPOv2Config -from .utils import trl_sanitze_kwargs_for_tagging +from .utils import generate_model_card + + +if is_wandb_available(): + import wandb INVALID_LOGPROB = 1.0 @@ -199,6 +204,10 @@ def __init__( if self.args.should_save: os.makedirs(self.args.output_dir, exist_ok=True) + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + ######### ### setup dataloader ######### @@ -639,17 +648,50 @@ def generate_completions(self, sampling: bool = False): if wandb.run is not None: wandb.log({"completions": wandb.Table(dataframe=df)}) - @wraps(Trainer.push_to_hub) - def push_to_hub( + def create_model_card( self, - commit_message: Optional[str] = "End of training", - blocking: bool = True, - **kwargs, - ) -> str: + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, List[str], None] = None, + ): """ - Overwrite the `push_to_hub` method in order to force-add the tag "ppo" when pushing the - model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. - Unlike the parent class, we don't use the `token` argument to mitigate security risks. + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str`, *optional*, defaults to `None`): + The name of the model. + dataset_name (`str`, *optional*, defaults to `None`): + The name of the dataset used for training. 
+ tags (`str`, `List[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. """ - kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs) - return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs) + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + citation = textwrap.dedent("""\ + @article{mziegler2019fine-tuning, + title = {{Fine-Tuning Language Models from Human Preferences}}, + author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving}, + year = 2019, + eprint = {arXiv:1909.08593} + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + trainer_name="PPO", + trainer_citation=citation, + paper_title="Fine-Tuning Language Models from Human Preferences", + paper_id="1909.08593", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py index 932938aa06..c1cf590372 100644 --- a/trl/trainer/reward_trainer.py +++ b/trl/trainer/reward_trainer.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import inspect +import os import warnings from collections import defaultdict from dataclasses import FrozenInstanceError, replace -from functools import wraps from typing import Any, Callable, Dict, List, Optional, Tuple, Union import pandas as pd @@ -24,7 +24,14 @@ from accelerate import PartialState from accelerate.utils import gather_object from datasets import Dataset -from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer, TrainingArguments +from transformers import ( + DataCollator, + PreTrainedModel, + PreTrainedTokenizerBase, + Trainer, + TrainingArguments, + is_wandb_available, +) from transformers.trainer_callback import TrainerCallback from transformers.trainer_pt_utils import nested_detach from transformers.trainer_utils import EvalPrediction @@ -36,14 +43,17 @@ RewardDataCollatorWithPadding, compute_accuracy, decode_and_strip_padding, + generate_model_card, print_rich_table, - trl_sanitze_kwargs_for_tagging, ) if is_peft_available(): from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training +if is_wandb_available(): + import wandb + def _tokenize(batch: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizerBase") -> Dict[str, List[Any]]: """Tokenize a batch from a reward modelling dataset.""" @@ -345,17 +355,39 @@ def visualize_samples(self, num_print_samples: int): if wandb.run is not None: wandb.log({"completions": wandb.Table(dataframe=df)}) - @wraps(Trainer.push_to_hub) - def push_to_hub( + def create_model_card( self, - commit_message: Optional[str] = "End of training", - blocking: bool = True, - **kwargs, - ) -> str: + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, List[str], None] = None, + ): """ - Overwrite the `push_to_hub` method in order to force-add the tag "reward-trainer" when pushing the - model on the Hub. 
Please refer to `~transformers.Trainer.push_to_hub` for more details. - Unlike the parent class, we don't use the `token` argument to mitigate security risks. + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str`, *optional*, defaults to `None`): + The name of the model. + dataset_name (`str`, *optional*, defaults to `None`): + The name of the dataset used for training. + tags (`str`, `List[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. """ - kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs) - return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs) + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + trainer_name="Reward", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 03b494e9a5..516eeb4dfb 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -15,9 +15,9 @@ import gc import math import os +import textwrap import time from collections import defaultdict -from functools import wraps from typing import Dict, List, Optional, Tuple, Union import numpy as np @@ -36,6 +36,7 @@ Trainer, TrainerCallback, TrainerControl, + is_wandb_available, ) from transformers.integrations import get_reporting_integration_callbacks from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK @@ -55,9 +56,12 @@ truncate_response, ) from .rloo_config import RLOOConfig -from .utils import trl_sanitze_kwargs_for_tagging +from .utils import generate_model_card +if is_wandb_available(): + import wandb + INVALID_LOGPROB = 1.0 @@ -172,6 +176,10 @@ def __init__( os.makedirs(self.args.output_dir, exist_ok=True) self.backup_model = None + # Add tags for models that have been loaded with the correct transformers version + if hasattr(self.model, "add_model_tags"): + self.model.add_model_tags(self._tag_names) + ######### ### setup dataloader ######### @@ -530,17 +538,53 @@ def generate_completions(self, sampling: bool = False): if wandb.run is not None: wandb.log({"completions": wandb.Table(dataframe=df)}) - @wraps(Trainer.push_to_hub) - def push_to_hub( + def create_model_card( self, - commit_message: Optional[str] = "End of training", - blocking: bool = True, - **kwargs, - ) -> str: + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, List[str], None] = None, + ): """ - Overwrite the `push_to_hub` method in order to force-add the tag "rloo" when pushing the - model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. - Unlike the parent class, we don't use the `token` argument to mitigate security risks. + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str`, *optional*, defaults to `None`): + The name of the model. + dataset_name (`str`, *optional*, defaults to `None`): + The name of the dataset used for training. 
+ tags (`str`, `List[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. """ - kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs) - return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs) + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + citation = textwrap.dedent("""\ + @inproceedings{ahmadian2024back, + title = {{Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs}}, + author = {Arash Ahmadian and Chris Cremer and Matthias Gall{\'{e}} and Marzieh Fadaee and Julia Kreutzer and Olivier Pietquin and Ahmet {\"{U}}st{\"{u}}n and Sara Hooker}, + year = 2024, + booktitle = {Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), {ACL} 2024, Bangkok, Thailand, August 11-16, 2024}, + publisher = {Association for Computational Linguistics}, + pages = {12248--12267}, + editor = {Lun{-}Wei Ku and Andre Martins and Vivek Srikumar}, + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + trainer_name="RLOO", + trainer_citation=citation, + paper_title="Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs", + paper_id="2402.14740", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 49bea851e7..f105559f2a 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -13,6 +13,7 @@ # limitations under the License. import dataclasses import inspect +import os import warnings from functools import wraps from typing import Callable, Dict, List, Optional, Tuple, Union @@ -33,6 +34,7 @@ PreTrainedModel, PreTrainedTokenizerBase, Trainer, + is_wandb_available, ) from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import EvalPrediction @@ -44,6 +46,7 @@ from .utils import ( ConstantLengthDataset, DataCollatorForCompletionOnlyLM, + generate_model_card, peft_module_casting_to_bf16, trl_sanitze_kwargs_for_tagging, ) @@ -55,6 +58,9 @@ if is_liger_kernel_available(): from liger_kernel.transformers import AutoLigerKernelForCausalLM +if is_wandb_available(): + import wandb + class SFTTrainer(Trainer): r""" @@ -611,3 +617,40 @@ def data_generator(constant_length_iterator): raise ValueError( "You need to pass a `dataset_text_field` or `formatting_func` argument to the SFTTrainer if you want to use the `ConstantLengthDataset`." ) + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, List[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str`, *optional*, defaults to `None`): + The name of the model. + dataset_name (`str`, *optional*, defaults to `None`): + The name of the dataset used for training. + tags (`str`, `List[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. 
+ """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + trainer_name="SFT", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md")) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 772286605e..b514b228ea 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -12,18 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. import dataclasses +import importlib.resources as pkg_resources import json import random import warnings from collections import deque from dataclasses import dataclass +from importlib.metadata import version from typing import Any, Dict, List, Literal, Optional, Tuple, Union import numpy as np import pandas as pd import torch -from accelerate import Accelerator -from accelerate.state import AcceleratorState, PartialState +import torch.utils.data +from accelerate import Accelerator, PartialState +from accelerate.state import AcceleratorState +from huggingface_hub import ModelCard, ModelCardData from rich.console import Console from rich.table import Table from torch.nn.utils.rnn import pad_sequence @@ -1366,3 +1370,77 @@ def decode_and_strip_padding(inputs: torch.Tensor, tokenizer: PreTrainedTokenize """ decoded = tokenizer.batch_decode(inputs, skip_special_tokens=False) return [d.replace(tokenizer.pad_token, "") for d in decoded] + + +def generate_model_card( + base_model: Optional[str], + model_name: str, + hub_model_id: str, + dataset_name: Optional[str], + tags: Union[str, List[str], None], + wandb_url: Optional[str], + trainer_name: str, + trainer_citation: Optional[str], + paper_title: Optional[str], + paper_id: Optional[str], +) -> ModelCard: + """ + Generate a `ModelCard` from a template. + + Args: + base_model (`str` or `None`): + Base model name. + model_name (`str`): + Model name. + hub_model_id (`str`): + Hub model ID as `username/model_id`. + dataset_name (`str` or `None`): + Dataset name. + tags (`str`, `List[str]`, or `None`): + Tags. + wandb_url (`str` or `None`): + Weights & Biases run URL. + trainer_name (`str`): + Trainer name. + trainer_citation (`str` or `None`): + Trainer citation as a BibTeX entry. + paper_title (`str` or `None`): + Paper title. + paper_id (`str` or `None`): + ArXiv paper ID as `YYMM.NNNNN`. + + Returns: + `ModelCard`: + A ModelCard object. 
+ """ + if tags is None: + tags = [] + elif isinstance(tags, str): + tags = [tags] + card_data = ModelCardData( + base_model=base_model, + datasets=dataset_name, + library_name="transformers", + licence="license", + model_name=model_name, + tags=["generated_from_trainer", *tags], + ) + card = ModelCard.from_template( + card_data, + template_path=str(pkg_resources.files("trl").joinpath("templates/lm_model_card.md")), + base_model=base_model, + model_name=model_name, + hub_model_id=hub_model_id, + dataset_name=dataset_name, + wandb_url=wandb_url, + trainer_name=trainer_name, + trainer_citation=trainer_citation, + paper_title=paper_title, + paper_id=paper_id, + trl_version=version("trl"), + transformers_version=version("transformers"), + pytorch_version=version("torch"), + datasets_version=version("datasets"), + tokenizers_version=version("tokenizers"), + ) + return card diff --git a/trl/trainer/xpo_trainer.py b/trl/trainer/xpo_trainer.py index d57215e850..4fe778fb64 100644 --- a/trl/trainer/xpo_trainer.py +++ b/trl/trainer/xpo_trainer.py @@ -12,13 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import textwrap from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from datasets import Dataset, IterableDataset -from transformers import PreTrainedTokenizerBase, TrainerCallback, is_apex_available +from transformers import PreTrainedTokenizerBase, TrainerCallback, is_apex_available, is_wandb_available from transformers.modeling_utils import PreTrainedModel from transformers.trainer_utils import EvalPrediction from transformers.training_args import OptimizerNames @@ -26,7 +28,7 @@ from ..data_utils import maybe_apply_chat_template from ..models.utils import unwrap_model_for_generation from .online_dpo_trainer import OnlineDPOTrainer -from .utils import empty_cache, get_reward, truncate_right +from .utils import empty_cache, generate_model_card, get_reward, truncate_right from .xpo_config import XPOConfig @@ -34,6 +36,10 @@ from apex import amp +if is_wandb_available(): + import wandb + + class XPOTrainer(OnlineDPOTrainer): r""" Initialize XPOTrainer as a subclass of [`OnlineDPOConfig`]. @@ -437,3 +443,51 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, self.accelerator.backward(loss, **kwargs) return loss.detach() / self.args.gradient_accumulation_steps + + def create_model_card( + self, + model_name: Optional[str] = None, + dataset_name: Optional[str] = None, + tags: Union[str, List[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + model_name (`str`, *optional*, defaults to `None`): + The name of the model. + dataset_name (`str`, *optional*, defaults to `None`): + The name of the dataset used for training. + tags (`str`, `List[str]` or `None`, *optional*, defaults to `None`): + Tags to be associated with the model card. + """ + if not self.is_world_process_zero(): + return + + if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): + base_model = self.model.config._name_or_path + else: + base_model = None + + citation = textwrap.dedent("""\ + @article{jung2024binary, + title = {{Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF}}, + author = {Tengyang Xie and Dylan J. 
Foster and Akshay Krishnamurthy and Corby Rosset and Ahmed Awadallah and Alexander Rakhlin}, + year = 2024, + eprint = {arXiv:2405.21046} + }""") + + model_card = generate_model_card( + base_model=base_model, + model_name=model_name, + hub_model_id=self.hub_model_id, + dataset_name=dataset_name, + tags=tags, + wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, + trainer_name="XPO", + trainer_citation=citation, + paper_title="Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF", + paper_id="2405.21046", + ) + + model_card.save(os.path.join(self.args.output_dir, "README.md"))
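
Below is a brief usage sketch of the `generate_model_card` utility added in trl/trainer/utils.py; the base-model, dataset, and Hub ids are illustrative placeholders rather than values taken from this patch, and every parameter is passed explicitly to match the signature shown above.

    # Illustrative sketch only (placeholder ids): draft a TRL model card outside of a trainer.
    from trl.trainer.utils import generate_model_card

    card = generate_model_card(
        base_model="Qwen/Qwen2-0.5B",        # placeholder base model id
        model_name="my-model",               # placeholder model name
        hub_model_id="username/my-model",    # placeholder Hub repo id
        dataset_name="stanfordnlp/imdb",     # placeholder dataset name
        tags=["sft"],
        wandb_url=None,                      # no Weights & Biases run to link
        trainer_name="SFT",
        trainer_citation=None,               # no extra BibTeX entry
        paper_title=None,
        paper_id=None,
    )
    card.save("README.md")

In practice, the trainers' create_model_card() methods wrap this helper on the main process, filling in the trainer name, citation, and paper metadata before saving README.md to args.output_dir.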