diff --git a/trl/trainer/ppo_config.py b/trl/trainer/ppo_config.py index ed2305a7d9..abb2651fb6 100644 --- a/trl/trainer/ppo_config.py +++ b/trl/trainer/ppo_config.py @@ -11,21 +11,61 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging +import os +import subprocess import warnings from dataclasses import dataclass, field from typing import Optional import numpy as np +import requests from ..core import flatten_dict +def autotag() -> str: + wandb_tag = "" + logging.info("autotag feature is enabled") + try: + git_tag = subprocess.check_output(["git", "describe", "--tags"]).decode("ascii").strip() + wandb_tag = f"{git_tag}" + logging.info(f"identified git tag: {git_tag}") + except subprocess.CalledProcessError: + return wandb_tag + + git_commit = subprocess.check_output(["git", "rev-parse", "--verify", "HEAD"]).decode("ascii").strip() + try: + # if the current branch is not main, try find the PR number + git_branch = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"]).decode("ascii").strip() + if git_branch != "main": + # try finding the pull request number on github + prs = requests.get(f"https://api.github.com/search/issues?q=repo:lvwerra/trl+is:pr+{git_commit}") + if prs.status_code == 200: + prs = prs.json() + if len(prs["items"]) > 0: + pr = prs["items"][0] + pr_number = pr["number"] + wandb_tag += f",pr-{pr_number}" + logging.info(f"identified github pull request: {pr_number}") + else: + logging.info("current branch is main, not searching for pull request") + except Exception as e: + logging.warning(f"Automatic autotag failed with the following error: {e}") + + return wandb_tag + + @dataclass class PPOConfig(object): """ Configuration class for PPOTrainer """ + task_name: Optional[str] = field( + default=None, + metadata={"help": "Name of task to use - used only for tracking purposes"}, + ) model_name: Optional[str] = field( default=None, metadata={"help": "Name of model to use - used only for tracking purposes"}, @@ -119,6 +159,15 @@ def __post_init__(self): # raise error if wandb is not installed try: import wandb # noqa: F401 + + existing_wandb_tag = os.environ.get("WANDB_TAGS", "") + wandb_tag = autotag() + if len(wandb_tag) > 0: + if len(existing_wandb_tag) > 0: + os.environ["WANDB_TAGS"] = ",".join([existing_wandb_tag, wandb_tag]) + else: + os.environ["WANDB_TAGS"] = wandb_tag + logging.info(f"the following tags will be used for wandb logging: {os.environ['WANDB_TAGS']}") except ImportError: raise ImportError( "Please install wandb to use wandb logging. You can do this by running `pip install wandb`."