Skip to content

Commit

Permalink
🃏 Model card for TRL (#2123)
Browse files Browse the repository at this point in the history
* template and util

* test for online dpo

* template in package_data

* template in manifest

* standardize push_to_hub

* wandb badge and quick start

* bco

* xpo

* simplify `create_model_card`

* cpo

* kto

* dpo

* gkd

* orpo

* style

* nash-md

* alignprop

* bco citation

* citation template

* cpo citation

* ddpo

* fix alignprop

* dpo

* gkd citation

* kto

* online dpo citation

* orpo citation

* citation in utils

* optional citation

* reward

* optional trainer citation

* sft

* remove add_model_tags bco

* Remove unnecessary code for adding model tags

* Fix model tag issue and update URL format

* Remove unused code for adding model tags

* Add citation for XPOTrainer

* Remove unused code in SFTTrainer

* Add model card generation in RLOOTrainer

* Remove unused import and method call in reward_trainer.py

* Add model card generation

* Remove unused code and update error message in ORPOTrainer class

* Add import statements and create model card in IterativeSFTTrainer

* Add dataset name to push_to_hub() call

* Update trainer.push_to_hub() dataset names

* script args

* test

* better doc

* fix tag test

* fix test tag

* Add tags parameter to create_model_card method

* doc

* script args

* Update trl/templates/model_card.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* unittest's `assertIn` instead of `assert`

* Update trl/templates/model_card.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
  • Loading branch information
qgallouedec and lewtun authored Sep 27, 2024
1 parent 124189c commit c00722c
Show file tree
Hide file tree
Showing 42 changed files with 1,023 additions and 245 deletions.
3 changes: 2 additions & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ include settings.ini
include LICENSE
include CONTRIBUTING.md
include README.md
recursive-exclude * __pycache__
recursive-exclude * __pycache__
include trl/templates/*.md
2 changes: 1 addition & 1 deletion examples/scripts/alignprop.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,4 @@ def image_outputs_logger(image_pair_data, global_step, accelerate_logger):
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub()
trainer.push_to_hub(dataset_name=script_args.dataset_name)
2 changes: 1 addition & 1 deletion examples/scripts/bco.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,4 +164,4 @@ def mean_pooling(model_output, attention_mask):
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub()
trainer.push_to_hub(dataset_name=script_args.dataset_name)
2 changes: 1 addition & 1 deletion examples/scripts/cpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,4 @@ def process(row):
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub()
trainer.push_to_hub(dataset_name=script_args.dataset_name)
2 changes: 1 addition & 1 deletion examples/scripts/ddpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,4 +212,4 @@ def image_outputs_logger(image_data, global_step, accelerate_logger):
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub()
trainer.push_to_hub(dataset_name=script_args.dataset_name)
2 changes: 1 addition & 1 deletion examples/scripts/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,4 @@
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub()
trainer.push_to_hub(dataset_name=script_args.dataset_name)
2 changes: 1 addition & 1 deletion examples/scripts/dpo_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,4 +120,4 @@
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub()
trainer.push_to_hub(dataset_name=script_args.dataset_name)
2 changes: 1 addition & 1 deletion examples/scripts/dpo_visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,4 @@ def process(row):
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub()
trainer.push_to_hub(dataset_name=script_args.dataset_name)
2 changes: 1 addition & 1 deletion examples/scripts/gkd.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,4 @@
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub()
trainer.push_to_hub(dataset_name=script_args.dataset_name)
2 changes: 1 addition & 1 deletion examples/scripts/kto.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,4 @@ def format_dataset(example):
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub()
trainer.push_to_hub(dataset_name=script_args.dataset_name)
2 changes: 1 addition & 1 deletion examples/scripts/nash_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,4 @@
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub()
trainer.push_to_hub(dataset_name=script_args.dataset_name)
2 changes: 1 addition & 1 deletion examples/scripts/orpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,4 @@ def process(row):
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub()
trainer.push_to_hub(dataset_name=script_args.dataset_name)
2 changes: 1 addition & 1 deletion examples/scripts/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,6 @@ def tokenize(element):
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub()
trainer.push_to_hub(dataset_name="trl-internal-testing/descriptiveness-sentiment-trl-style")

trainer.generate_completions()
2 changes: 1 addition & 1 deletion examples/scripts/ppo/ppo_tldr.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,6 @@ def tokenize(element):
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub()
trainer.push_to_hub(dataset_name="trl-internal-testing/tldr-preference-sft-trl-style")

trainer.generate_completions()
2 changes: 1 addition & 1 deletion examples/scripts/reward_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,4 @@
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub()
trainer.push_to_hub(dataset_name=script_args.dataset_name)
2 changes: 1 addition & 1 deletion examples/scripts/rloo/rloo.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,6 @@ def tokenize(element):
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub()
trainer.push_to_hub(dataset_name="trl-internal-testing/descriptiveness-sentiment-trl-style")

trainer.generate_completions()
2 changes: 1 addition & 1 deletion examples/scripts/rloo/rloo_tldr.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,6 @@ def tokenize(element):
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub()
trainer.push_to_hub(dataset_name="trl-internal-testing/tldr-preference-sft-trl-style")

trainer.generate_completions()
2 changes: 1 addition & 1 deletion examples/scripts/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,4 @@
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub()
trainer.push_to_hub(dataset_name=script_args.dataset_name)
2 changes: 1 addition & 1 deletion examples/scripts/sft_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,6 @@ def collate_fn(examples):
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub()
trainer.push_to_hub(dataset_name=script_args.dataset_name)
if trainer.accelerator.is_main_process:
processor.push_to_hub(training_args.hub_model_id)
2 changes: 1 addition & 1 deletion examples/scripts/xpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,4 @@
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub()
trainer.push_to_hub(dataset_name=script_args.dataset_name)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@
"console_scripts": ["trl=trl.commands.cli:main"],
},
include_package_data=True,
package_data={"trl": ["commands/scripts/config/*", "commands/scripts/*"]},
package_data={"trl": ["commands/scripts/config/*", "commands/scripts/*", "templates/*.md"]},
packages=find_packages(exclude={"tests"}),
install_requires=REQUIRED_PKGS,
extras_require=EXTRAS,
Expand Down
6 changes: 4 additions & 2 deletions tests/test_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,8 @@ def test_dpo_lora_tags(self):
peft_config=lora_config,
)

assert trainer.model.model_tags == trainer._tag_names
for tag in ["dpo", "trl"]:
self.assertIn(tag, trainer.model.model_tags)

@require_peft
def test_dpo_tags(self):
Expand Down Expand Up @@ -817,7 +818,8 @@ def test_dpo_tags(self):
eval_dataset=dummy_dataset["test"],
)

assert trainer.model.model_tags == trainer._tag_names
for tag in ["dpo", "trl"]:
self.assertIn(tag, trainer.model.model_tags)

@require_peft
def test_dpo_lora_force_use_ref(self):
Expand Down
6 changes: 4 additions & 2 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,7 +1055,8 @@ def test_peft_sft_trainer_tag(self):
peft_config=peft_config,
)

assert trainer.model.model_tags == trainer._tag_names
for tag in ["sft", "trl"]:
self.assertIn(tag, trainer.model.model_tags)

@require_peft
def test_sft_trainer_tag(self):
Expand All @@ -1080,7 +1081,8 @@ def test_sft_trainer_tag(self):
eval_dataset=self.eval_dataset,
)

assert trainer.model.model_tags == trainer._tag_names
for tag in ["sft", "trl"]:
self.assertIn(tag, trainer.model.model_tags)

def test_sft_trainer_only_train_packing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
Expand Down
45 changes: 44 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from transformers.utils import is_peft_available

from trl.trainer.model_config import ModelConfig
from trl.trainer.utils import decode_and_strip_padding, get_peft_config, pad
from trl.trainer.utils import decode_and_strip_padding, generate_model_card, get_peft_config, pad


if is_peft_available():
Expand Down Expand Up @@ -126,3 +126,46 @@ def test_example_without_padding(self):
inputs = self.tokenizer(["Hello", "Hello"], padding=False, return_tensors="pt")
decoded = decode_and_strip_padding(inputs["input_ids"], self.tokenizer)
self.assertEqual(decoded, ["Hello", "Hello"])


class TestGenerateModelCard(unittest.TestCase):
def test_full(self):
model_card = generate_model_card(
base_model="username/my_base_model",
model_name="my_model",
hub_model_id="username/my_hub_model",
dataset_name="username/my_dataset",
tags=["trl", "trainer-tag"],
wandb_url="https://wandb.ai/username/project_id/runs/abcd1234",
trainer_name="My Trainer",
trainer_citation="@article{my_trainer, ...}",
paper_title="My Paper",
paper_id="1234.56789",
)
card_text = str(model_card)
assert "[username/my_base_model](https://huggingface.co/username/my_base_model)" in card_text
assert "my_model" in card_text
assert 'pipeline("text-generation", model="username/my_hub_model", device="cuda")' in card_text
assert "datasets: username/my_dataset" in card_text
assert "](https://wandb.ai/username/project_id/runs/abcd1234)" in card_text
assert "My Trainer" in card_text
assert "```bibtex\n@article{my_trainer, ...}\n```" in card_text
assert "[My Paper](https://huggingface.co/papers/1234.56789)" in card_text

def test_val_none(self):
model_card = generate_model_card(
base_model=None,
model_name="my_model",
hub_model_id="username/my_hub_model",
dataset_name=None,
tags=None,
wandb_url=None,
trainer_name="My Trainer",
trainer_citation=None,
paper_title=None,
paper_id=None,
)
card_text = str(model_card)
assert "my_model" in card_text
assert 'pipeline("text-generation", model="username/my_hub_model", device="cuda")' in card_text
assert "My Trainer" in card_text
54 changes: 54 additions & 0 deletions trl/templates/lm_model_card.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
---
{{ card_data }}
---

# Model Card for {{ model_name }}

This model is a fine-tuned version of [{{ base_model }}](https://huggingface.co/{{ base_model }}){% if dataset_name %} on the [{{ dataset_name }}](https://huggingface.co/datasets/{{ dataset_name }}) dataset{% endif %}.
It has been trained using [TRL](https://github.com/huggingface/trl).

## Quick start

```python
from transformers import pipeline

question = "If you had a time machine, but could only go to the past or the future once and never return, which would you choose and why?"
generator = pipeline("text-generation", model="{{ hub_model_id }}", device="cuda")
output = generator([{"role": "user", "content": question}], max_new_tokens=128, return_full_text=False)[0]
print(output["generated_text"])
```

## Training procedure

{% if wandb_url %}[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>]({{ wandb_url }}){% endif %}

This model was trained with {{ trainer_name }}{% if paper_id %}, a method introduced in [{{ paper_title }}](https://huggingface.co/papers/{{ paper_id }}){% endif %}.

### Framework versions

- TRL: {{ trl_version }}
- Transformers: {{ transformers_version }}
- Pytorch: {{ pytorch_version }}
- Datasets: {{ datasets_version }}
- Tokenizers: {{ tokenizers_version }}

## Citations

{% if trainer_citation %}Cite {{ trainer_name }} as:

```bibtex
{{ trainer_citation }}
```{% endif %}
Cite TRL as:
```bibtex
{% raw %}@misc{vonwerra2022trl,
title = {{TRL: Transformer Reinforcement Learning}},
author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec},
year = 2020,
journal = {GitHub repository},
publisher = {GitHub},
howpublished = {\url{https://github.com/huggingface/trl}}
}{% endraw %}
```
Loading

0 comments on commit c00722c

Please sign in to comment.