🃏 Model card: "unsloth" tag #2173

Merged 2 commits on Oct 7, 2024
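Every trainer's `create_model_card` receives the same seven-line addition: normalize the incoming `tags` argument to a list, then append `"unsloth"` when the model config carries an `unsloth_version` attribute, which Unsloth-patched models stamp onto their config. A minimal sketch of the pattern in isolation (`collect_tags` and `config` are hypothetical stand-ins, not names from the diff):

```python
def collect_tags(tags, config):
    # Accept None, a single string, or a list of strings; always return a list.
    tags = tags or []
    if isinstance(tags, str):
        tags = [tags]

    # Unsloth-patched models expose `unsloth_version` on their config;
    # its presence is the signal to tag the model card with "unsloth".
    if hasattr(config, "unsloth_version"):
        tags.append("unsloth")

    return tags
```

Because the trainers now do this normalization themselves, `generate_model_card` in `trl/trainer/utils.py` drops its own `None`/`str` handling and requires a plain `List[str]` (see that file's diff below), and the test fixture switches from `tags=None` to `tags=[]`.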
2 changes: 1 addition & 1 deletion tests/test_utils.py
@@ -158,7 +158,7 @@ def test_val_none(self):
model_name="my_model",
hub_model_id="username/my_hub_model",
dataset_name=None,
-tags=None,
+tags=[],
wandb_url=None,
trainer_name="My Trainer",
trainer_citation=None,
7 changes: 7 additions & 0 deletions trl/trainer/alignprop_trainer.py
@@ -415,6 +415,13 @@ def create_model_card(
else:
base_model = None

tags = tags or []
if isinstance(tags, str):
tags = [tags]

if hasattr(self.model.config, "unsloth_version"):
tags.append("unsloth")

citation = textwrap.dedent("""\
@article{prabhudesai2024aligning,
title = {{Aligning Text-to-Image Diffusion Models with Reward Backpropagation}},
7 changes: 7 additions & 0 deletions trl/trainer/bco_trainer.py
@@ -1483,6 +1483,13 @@ def create_model_card(
else:
base_model = None

tags = tags or []
if isinstance(tags, str):
tags = [tags]

if hasattr(self.model.config, "unsloth_version"):
tags.append("unsloth")

citation = textwrap.dedent("""\
@article{jung2024binary,
title = {{Binary Classifier Optimization for Large Language Model Alignment}},
7 changes: 7 additions & 0 deletions trl/trainer/cpo_trainer.py
@@ -1018,6 +1018,13 @@ def create_model_card(
else:
base_model = None

tags = tags or []
if isinstance(tags, str):
tags = [tags]

if hasattr(self.model.config, "unsloth_version"):
tags.append("unsloth")

citation = textwrap.dedent("""\
@inproceedings{xu2024contrastive,
title = {{Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation}},
7 changes: 7 additions & 0 deletions trl/trainer/ddpo_trainer.py
@@ -617,6 +617,13 @@ def create_model_card(
else:
base_model = None

tags = tags or []
if isinstance(tags, str):
tags = [tags]

if hasattr(self.model.config, "unsloth_version"):
tags.append("unsloth")

citation = textwrap.dedent("""\
@inproceedings{black2024training,
title = {{Training Diffusion Models with Reinforcement Learning}},
7 changes: 7 additions & 0 deletions trl/trainer/dpo_trainer.py
@@ -1727,6 +1727,13 @@ def create_model_card(
else:
base_model = None

tags = tags or []
if isinstance(tags, str):
tags = [tags]

if hasattr(self.model.config, "unsloth_version"):
tags.append("unsloth")

citation = textwrap.dedent("""\
@inproceedings{rafailov2023direct,
title = {{Direct Preference Optimization: Your Language Model is Secretly a Reward Model}},
7 changes: 7 additions & 0 deletions trl/trainer/gkd_trainer.py
@@ -343,6 +343,13 @@ def create_model_card(
else:
base_model = None

tags = tags or []
if isinstance(tags, str):
tags = [tags]

if hasattr(self.model.config, "unsloth_version"):
tags.append("unsloth")

citation = textwrap.dedent("""\
@inproceedings{agarwal2024on-policy,
title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},
7 changes: 7 additions & 0 deletions trl/trainer/iterative_sft_trainer.py
@@ -423,6 +423,13 @@ def create_model_card(
else:
base_model = None

tags = tags or []
if isinstance(tags, str):
tags = [tags]

if hasattr(self.model.config, "unsloth_version"):
tags.append("unsloth")

model_card = generate_model_card(
base_model=base_model,
model_name=model_name,
7 changes: 7 additions & 0 deletions trl/trainer/kto_trainer.py
@@ -1459,6 +1459,13 @@ def create_model_card(
else:
base_model = None

tags = tags or []
if isinstance(tags, str):
tags = [tags]

if hasattr(self.model.config, "unsloth_version"):
tags.append("unsloth")

citation = textwrap.dedent("""\
@article{ethayarajh2024kto,
title = {{KTO: Model Alignment as Prospect Theoretic Optimization}},
7 changes: 7 additions & 0 deletions trl/trainer/nash_md_trainer.py
@@ -424,6 +424,13 @@ def create_model_card(
else:
base_model = None

tags = tags or []
if isinstance(tags, str):
tags = [tags]

if hasattr(self.model.config, "unsloth_version"):
tags.append("unsloth")

citation = textwrap.dedent("""\
@inproceedings{munos2024nash,
title = {Nash Learning from Human Feedback},
7 changes: 7 additions & 0 deletions trl/trainer/online_dpo_trainer.py
@@ -580,6 +580,13 @@ def create_model_card(
else:
base_model = None

tags = tags or []
if isinstance(tags, str):
tags = [tags]

if hasattr(self.model.config, "unsloth_version"):
tags.append("unsloth")

citation = textwrap.dedent("""\
@article{guo2024direct,
title = {{Direct Language Model Alignment from Online AI Feedback}},
7 changes: 7 additions & 0 deletions trl/trainer/orpo_trainer.py
@@ -1026,6 +1026,13 @@ def create_model_card(
else:
base_model = None

tags = tags or []
if isinstance(tags, str):
tags = [tags]

if hasattr(self.model.config, "unsloth_version"):
tags.append("unsloth")

citation = textwrap.dedent("""\
@article{hong2024orpo,
title = {{ORPO: Monolithic Preference Optimization without Reference Model}},
7 changes: 7 additions & 0 deletions trl/trainer/ppov2_trainer.py
@@ -682,6 +682,13 @@ def create_model_card(
else:
base_model = None

tags = tags or []
if isinstance(tags, str):
tags = [tags]

if hasattr(self.model.config, "unsloth_version"):
tags.append("unsloth")

citation = textwrap.dedent("""\
@article{mziegler2019fine-tuning,
title = {{Fine-Tuning Language Models from Human Preferences}},
7 changes: 7 additions & 0 deletions trl/trainer/reward_trainer.py
@@ -391,6 +391,13 @@ def create_model_card(
else:
base_model = None

tags = tags or []
if isinstance(tags, str):
tags = [tags]

if hasattr(self.model.config, "unsloth_version"):
tags.append("unsloth")

model_card = generate_model_card(
base_model=base_model,
model_name=model_name,
7 changes: 7 additions & 0 deletions trl/trainer/rloo_trainer.py
@@ -572,6 +572,13 @@ def create_model_card(
else:
base_model = None

tags = tags or []
if isinstance(tags, str):
tags = [tags]

if hasattr(self.model.config, "unsloth_version"):
tags.append("unsloth")

citation = textwrap.dedent("""\
@inproceedings{ahmadian2024back,
title = {{Back to Basics: Revisiting REINFORCE-Style Optimization for Learning from Human Feedback in LLMs}},
24 changes: 7 additions & 17 deletions trl/trainer/sft_trainer.py
@@ -15,7 +15,6 @@
import inspect
import os
import warnings
-from functools import wraps
from typing import Callable, Dict, List, Optional, Tuple, Union

import datasets
@@ -52,7 +51,6 @@
DataCollatorForCompletionOnlyLM,
generate_model_card,
peft_module_casting_to_bf16,
-trl_sanitze_kwargs_for_tagging,
)


@@ -435,21 +433,6 @@ def make_inputs_require_grad(module, input, output):
elif self.args.max_steps == -1 and args.packing:
self.train_dataset.infinite = False

-@wraps(Trainer.push_to_hub)
-def push_to_hub(
-self,
-commit_message: Optional[str] = "End of training",
-blocking: bool = True,
-**kwargs,
-) -> str:
-"""
-Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the
-model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
-Unlike the parent class, we don't use the `token` argument to mitigate security risks.
-"""
-kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)
-return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)

Comment on lines -438 to -452
Member Author

Forgot to remove this in #2123.

def _prepare_dataset(
self,
dataset,
@@ -639,6 +622,13 @@ def create_model_card(
else:
base_model = None

tags = tags or []
if isinstance(tags, str):
tags = [tags]

if hasattr(self.model.config, "unsloth_version"):
tags.append("unsloth")

model_card = generate_model_card(
base_model=base_model,
model_name=model_name,
8 changes: 2 additions & 6 deletions trl/trainer/utils.py
@@ -1388,7 +1388,7 @@ def generate_model_card(
model_name: str,
hub_model_id: str,
dataset_name: Optional[str],
-tags: Union[str, List[str], None],
+tags: List[str],
wandb_url: Optional[str],
trainer_name: str,
trainer_citation: Optional[str] = None,
@@ -1407,7 +1407,7 @@
Hub model ID as `username/model_id`.
dataset_name (`str` or `None`):
Dataset name.
-tags (`str`, `List[str]`, or `None`):
+tags (`List[str]`):
Tags.
wandb_url (`str` or `None`):
Weights & Biases run URL.
@@ -1424,10 +1424,6 @@
`ModelCard`:
A ModelCard object.
"""
-if tags is None:
-    tags = []
-elif isinstance(tags, str):
-    tags = [tags]
card_data = ModelCardData(
base_model=base_model,
datasets=dataset_name,
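Since callers now normalize tags before reaching this function, `generate_model_card` can require a plain list. A hypothetical call under the new signature, reusing the values from the test at the top of the diff and assuming the remaining optional parameters keep their defaults:

```python
from trl.trainer.utils import generate_model_card

# Illustrative values only; the substantive change is that `tags` must already
# be a list here, since None/str normalization now happens in the trainers.
card = generate_model_card(
    base_model=None,
    model_name="my_model",
    hub_model_id="username/my_hub_model",
    dataset_name=None,
    tags=["trl", "unsloth"],
    wandb_url=None,
    trainer_name="My Trainer",
    trainer_citation=None,
)
print(card.data.tags)  # tags land in the card's YAML metadata
```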
7 changes: 7 additions & 0 deletions trl/trainer/xpo_trainer.py
@@ -481,6 +481,13 @@ def create_model_card(
else:
base_model = None

tags = tags or []
if isinstance(tags, str):
tags = [tags]

if hasattr(self.model.config, "unsloth_version"):
tags.append("unsloth")

citation = textwrap.dedent("""\
@article{jung2024binary,
title = {{Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF}},