From 3877dabb0af920837ed48a81352aa51c4ed43fa2 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 22 Dec 2023 13:18:34 +0000 Subject: [PATCH 1/2] multi-tags support for tagging --- trl/trainer/ddpo_trainer.py | 4 ++-- trl/trainer/dpo_trainer.py | 4 ++-- trl/trainer/ppo_trainer.py | 4 ++-- trl/trainer/sft_trainer.py | 4 ++-- trl/trainer/utils.py | 12 ++++++++---- 5 files changed, 16 insertions(+), 12 deletions(-) diff --git a/trl/trainer/ddpo_trainer.py b/trl/trainer/ddpo_trainer.py index b6cd432b18..84b8cfce5b 100644 --- a/trl/trainer/ddpo_trainer.py +++ b/trl/trainer/ddpo_trainer.py @@ -48,7 +48,7 @@ class DDPOTrainer(BaseTrainer): **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images """ - _tag_name = "trl-ddpo" + _tag_names = ["trl", "ddpo"] def __init__( self, @@ -585,6 +585,6 @@ def push_to_hub(self, commit_message: Optional[str] = "End of training", blockin Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. """ - kwargs = trl_sanitze_kwargs_for_tagging(tag_name=self._tag_name, kwargs=kwargs) + kwargs = trl_sanitze_kwargs_for_tagging(tag_names=self._tag_names, kwargs=kwargs) return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 97b126ddb2..b96f8e0c8a 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -122,7 +122,7 @@ class DPOTrainer(Trainer): Dict of Optional kwargs to pass when instantiating the ref model from a string """ - _tag_name = "trl-dpo" + _tag_names = ["trl", "dpo"] def __init__( self, @@ -1144,6 +1144,6 @@ def push_to_hub(self, commit_message: Optional[str] = "End of training", blockin Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. """ - kwargs = trl_sanitze_kwargs_for_tagging(tag_name=self._tag_name, kwargs=kwargs) + kwargs = trl_sanitze_kwargs_for_tagging(tag_names=self._tag_names, kwargs=kwargs) return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs) diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index bd466226f3..d88e8f2729 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -140,7 +140,7 @@ class PPOTrainer(BaseTrainer): **lr_scheduler** (`torch.optim.lr_scheduler`, *optional*) -- Learning rate scheduler to be used for training. """ - _tag_name = "trl-ppo" + _tag_names = ["trl", "ppo"] def __init__( self, @@ -1452,6 +1452,6 @@ def push_to_hub(self, commit_message: Optional[str] = "End of training", blockin Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. """ - kwargs = trl_sanitze_kwargs_for_tagging(tag_name=self._tag_name, kwargs=kwargs) + kwargs = trl_sanitze_kwargs_for_tagging(tag_names=self._tag_names, kwargs=kwargs) return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 78243758cc..259e70b8f0 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -115,7 +115,7 @@ class SFTTrainer(Trainer): dataset_kwargs: (`Optional[Dict]`, *optional*): Dict of Optional kwargs to pass when creating packed or non-packed datasets """ - _tag_name = "trl-sft" + _tag_names = ["trl", "sft"] def __init__( self, @@ -334,7 +334,7 @@ def push_to_hub(self, commit_message: Optional[str] = "End of training", blockin Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. """ - kwargs = trl_sanitze_kwargs_for_tagging(tag_name=self._tag_name, kwargs=kwargs) + kwargs = trl_sanitze_kwargs_for_tagging(tag_names=self._tag_names, kwargs=kwargs) return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 33d6c6f092..23ee648fd7 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -639,12 +639,16 @@ def peft_module_casting_to_bf16(model): module = module.to(torch.bfloat16) -def trl_sanitze_kwargs_for_tagging(tag_name, kwargs=None): +def trl_sanitze_kwargs_for_tagging(tag_names, kwargs=None): + if isinstance(tag_names, str): + tag_names = [tag_names] + if kwargs is not None: if "tags" not in kwargs: - kwargs["tags"] = [tag_name] + kwargs["tags"] = tag_names elif "tags" in kwargs and isinstance(kwargs["tags"], list): - kwargs["tags"].append(tag_name) + kwargs["tags"].extend(tag_names) elif "tags" in kwargs and isinstance(kwargs["tags"], str): - kwargs["tags"] = [kwargs["tags"], tag_name] + tag_names.appendkwargs["tags"]() + kwargs["tags"] = tag_names return kwargs From 82a5c5839174cdea8d6b3cb39eac1efa513d2af2 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Fri, 22 Dec 2023 13:20:08 +0000 Subject: [PATCH 2/2] oops --- trl/trainer/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 23ee648fd7..b6e4cce2d3 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -649,6 +649,6 @@ def trl_sanitze_kwargs_for_tagging(tag_names, kwargs=None): elif "tags" in kwargs and isinstance(kwargs["tags"], list): kwargs["tags"].extend(tag_names) elif "tags" in kwargs and isinstance(kwargs["tags"], str): - tag_names.appendkwargs["tags"]() + tag_names.append(kwargs["tags"]) kwargs["tags"] = tag_names return kwargs