[PPOTrainer / DDPOTrainer] Fix ppo & ddpo push to Hub (#1141)
* fix ppo push to Hub

* fix also ddpo

* more tags
younesbelkada authored Dec 26, 2023
1 parent 8f5b492 commit 3efb484
Showing 2 changed files with 44 additions and 25 deletions.
55 changes: 43 additions & 12 deletions trl/trainer/ddpo_trainer.py
@@ -13,26 +13,45 @@
# limitations under the License.

import os
+import warnings
from collections import defaultdict
from concurrent import futures
-from functools import wraps
from typing import Any, Callable, Optional, Tuple
from warnings import warn

import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
-from transformers import Trainer
+from huggingface_hub import whoami

from ..models import DDPOStableDiffusionPipeline
from . import BaseTrainer, DDPOConfig
-from .utils import PerPromptStatTracker, trl_sanitze_kwargs_for_tagging
+from .utils import PerPromptStatTracker


logger = get_logger(__name__)


+MODEL_CARD_TEMPLATE = """---
+license: apache-2.0
+tags:
+- trl
+- ddpo
+- diffusers
+- reinforcement-learning
+- text-to-image
+- stable-diffusion
+---
+# {model_name}
+This is a diffusion model that has been fine-tuned with reinforcement learning to
+guide the model outputs according to a value function or human feedback. The model can be used for image generation conditioned with text.
+"""


class DDPOTrainer(BaseTrainer):
"""
The DDPOTrainer uses Denoising Diffusion Policy Optimization to optimize diffusion models.
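An editorial aside on the new MODEL_CARD_TEMPLATE above: `create_model_card` (added in the next hunk) formats it with both `model_name` and `model_id`, even though the template only references `{model_name}`. That is harmless because `str.format` silently ignores surplus keyword arguments. A minimal sketch, using a hypothetical shortened template rather than the real one:

```python
# Sketch (not code from this commit): str.format substitutes only the fields
# named in the template and ignores extra keyword arguments.
TEMPLATE = "# {model_name}\nA TRL DDPO model card.\n"

# model_id is unused by this template, yet passing it does not raise.
card = TEMPLATE.format(model_name="my-ddpo-model", model_id="user/my-ddpo-model")
print(card)
# => # my-ddpo-model
#    A TRL DDPO model card.
```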
@@ -576,15 +595,27 @@ def train(self, epochs: Optional[int] = None):
for epoch in range(self.first_epoch, epochs):
global_step = self.step(epoch, global_step)

-    def _save_pretrained(self, save_directory):
-        self.sd_pipeline.save_pretrained(save_directory)
-
-    @wraps(Trainer.push_to_hub)
-    def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
-        """
-        Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the
-        model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
-        """
-        kwargs = trl_sanitze_kwargs_for_tagging(tag_names=self._tag_names, kwargs=kwargs)
-
-        return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
+    def create_model_card(self, path: str, model_name: Optional[str] = "TRL DDPO Model") -> None:
+        """Creates and saves a model card for a TRL model.
+
+        Args:
+            path (`str`): The path to save the model card to.
+            model_name (`str`, *optional*): The name of the model, defaults to `TRL DDPO Model`.
+        """
+        try:
+            user = whoami()["name"]
+        # handle the offline case
+        except Exception:  # noqa
+            warnings.warn("Cannot retrieve user information, assuming you are running in offline mode.")
+            return
+
+        if not os.path.exists(path):
+            os.makedirs(path)
+
+        model_card_content = MODEL_CARD_TEMPLATE.format(model_name=model_name, model_id=f"{user}/{path}")
+        with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f:
+            f.write(model_card_content)

    def _save_pretrained(self, save_directory):
        self.sd_pipeline.save_pretrained(save_directory)
+        self.create_model_card(save_directory)
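Taken together, `_save_pretrained` now writes a model card on every save, and `create_model_card` degrades gracefully offline: if `whoami()` cannot reach the Hub (no network or no login token), the method warns and returns early instead of failing the save. A standalone sketch of that pattern, under the assumption that `whoami` raises whenever credentials are unavailable (the function name `write_model_card` is hypothetical, not part of the commit):

```python
# Standalone sketch of the offline-tolerant model-card pattern; not the
# DDPOTrainer method itself.
import os
import warnings

from huggingface_hub import whoami


def write_model_card(path: str, model_name: str = "TRL DDPO Model") -> None:
    try:
        user = whoami()["name"]  # requires network access and a login token
    except Exception:
        # Offline or logged out: skip the card rather than break save_pretrained.
        warnings.warn("Cannot retrieve user information, assuming offline mode.")
        return

    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f:
        f.write(f"# {model_name}\n\nCard written for Hub user {user}.\n")
```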
14 changes: 1 addition & 13 deletions trl/trainer/ppo_trainer.py
@@ -18,7 +18,6 @@
import typing
import warnings
from contextlib import nullcontext
-from functools import wraps
from typing import Callable, List, Optional, Union

import datasets
@@ -36,7 +35,6 @@
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    PreTrainedTokenizerFast,
-    Trainer,
)

from ..core import (
@@ -57,7 +55,6 @@
from ..import_utils import is_npu_available, is_torch_greater_2_0, is_xpu_available
from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper, create_reference_model
from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig, RunningMoments
-from .utils import trl_sanitze_kwargs_for_tagging


if is_deepspeed_available():
@@ -67,6 +64,7 @@
license: apache-2.0
tags:
- trl
+- ppo
- transformers
- reinforcement-learning
---
@@ -1445,13 +1443,3 @@ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
model.eval()
return model

-    @wraps(Trainer.push_to_hub)
-    def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
-        """
-        Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the
-        model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
-        """
-        kwargs = trl_sanitze_kwargs_for_tagging(tag_names=self._tag_names, kwargs=kwargs)
-
-        return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
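The override above is deleted rather than repaired because PPOTrainer does not inherit from `transformers.Trainer`: its `BaseTrainer` parent gets `push_to_hub` from a `huggingface_hub` mixin with a different signature, so forwarding `Trainer`-style arguments upward fails. A minimal, assumed reproduction of that failure mode (class names hypothetical, not TRL code):

```python
# Hypothetical reduction: the parent's push_to_hub expects repo_id, not the
# transformers.Trainer-style kwargs the removed override forwarded.
class HubMixin:
    def push_to_hub(self, repo_id: str, *, private: bool = False) -> str:
        return f"pushed {repo_id} (private={private})"


class BrokenTrainer(HubMixin):
    def push_to_hub(self, commit_message="End of training", blocking=True, **kwargs):
        # Mirrors the removed override: forward Trainer-style kwargs upward.
        return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)


trainer = BrokenTrainer()
try:
    trainer.push_to_hub(commit_message="done")
except TypeError as err:
    print(err)  # ...got an unexpected keyword argument 'commit_message'
```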
