Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[xxxTrainer] multi-tags support for tagging #1133

Merged
merged 2 commits into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions trl/trainer/ddpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class DDPOTrainer(BaseTrainer):
**image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images
"""

_tag_name = "trl-ddpo"
_tag_names = ["trl", "ddpo"]

def __init__(
self,
Expand Down Expand Up @@ -585,6 +585,6 @@ def push_to_hub(self, commit_message: Optional[str] = "End of training", blockin
Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = trl_sanitze_kwargs_for_tagging(tag_name=self._tag_name, kwargs=kwargs)
kwargs = trl_sanitze_kwargs_for_tagging(tag_names=self._tag_names, kwargs=kwargs)

return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
4 changes: 2 additions & 2 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class DPOTrainer(Trainer):
Dict of Optional kwargs to pass when instantiating the ref model from a string
"""

_tag_name = "trl-dpo"
_tag_names = ["trl", "dpo"]

def __init__(
self,
Expand Down Expand Up @@ -1144,6 +1144,6 @@ def push_to_hub(self, commit_message: Optional[str] = "End of training", blockin
Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = trl_sanitze_kwargs_for_tagging(tag_name=self._tag_name, kwargs=kwargs)
kwargs = trl_sanitze_kwargs_for_tagging(tag_names=self._tag_names, kwargs=kwargs)

return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
4 changes: 2 additions & 2 deletions trl/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class PPOTrainer(BaseTrainer):
**lr_scheduler** (`torch.optim.lr_scheduler`, *optional*) -- Learning rate scheduler to be used for training.
"""

_tag_name = "trl-ppo"
_tag_names = ["trl", "ppo"]

def __init__(
self,
Expand Down Expand Up @@ -1452,6 +1452,6 @@ def push_to_hub(self, commit_message: Optional[str] = "End of training", blockin
Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = trl_sanitze_kwargs_for_tagging(tag_name=self._tag_name, kwargs=kwargs)
kwargs = trl_sanitze_kwargs_for_tagging(tag_names=self._tag_names, kwargs=kwargs)

return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
4 changes: 2 additions & 2 deletions trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class SFTTrainer(Trainer):
dataset_kwargs: (`Optional[Dict]`, *optional*):
Dict of Optional kwargs to pass when creating packed or non-packed datasets
"""
_tag_name = "trl-sft"
_tag_names = ["trl", "sft"]

def __init__(
self,
Expand Down Expand Up @@ -334,7 +334,7 @@ def push_to_hub(self, commit_message: Optional[str] = "End of training", blockin
Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = trl_sanitze_kwargs_for_tagging(tag_name=self._tag_name, kwargs=kwargs)
kwargs = trl_sanitze_kwargs_for_tagging(tag_names=self._tag_names, kwargs=kwargs)

return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)

Expand Down
12 changes: 8 additions & 4 deletions trl/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,12 +639,16 @@ def peft_module_casting_to_bf16(model):
module = module.to(torch.bfloat16)


def trl_sanitze_kwargs_for_tagging(tag_name, kwargs=None):
def trl_sanitze_kwargs_for_tagging(tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]

if kwargs is not None:
if "tags" not in kwargs:
kwargs["tags"] = [tag_name]
kwargs["tags"] = tag_names
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
kwargs["tags"].append(tag_name)
kwargs["tags"].extend(tag_names)
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
kwargs["tags"] = [kwargs["tags"], tag_name]
tag_names.append(kwargs["tags"])
kwargs["tags"] = tag_names
return kwargs
Loading