From 6912a7947cfce53e8f33495068431edfa1124f1f Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 7 Jun 2023 15:29:34 +0000 Subject: [PATCH 1/4] Enable autotag feature --- trl/trainer/ppo_config.py | 48 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/trl/trainer/ppo_config.py b/trl/trainer/ppo_config.py index ed2305a7d9..712780fad4 100644 --- a/trl/trainer/ppo_config.py +++ b/trl/trainer/ppo_config.py @@ -11,21 +11,60 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import subprocess import warnings from dataclasses import dataclass, field from typing import Optional import numpy as np +import requests from ..core import flatten_dict +def autotag() -> str: + wandb_tag = "" + print("autotag feature is enabled") + try: + git_tag = subprocess.check_output(["git", "describe", "--tags"]).decode("ascii").strip() + wandb_tag = f"{git_tag}" + print(f"identified git tag: {git_tag}") + except subprocess.CalledProcessError: + return wandb_tag + + git_commit = subprocess.check_output(["git", "rev-parse", "--verify", "HEAD"]).decode("ascii").strip() + try: + # if the current branch is not main, try find the PR number + git_branch = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"]).decode("ascii").strip() + if git_branch != "main": + # try finding the pull request number on github + prs = requests.get(f"https://api.github.com/search/issues?q=repo:lvwerra/trl+is:pr+{git_commit}") + if prs.status_code == 200: + prs = prs.json() + if len(prs["items"]) > 0: + pr = prs["items"][0] + pr_number = pr["number"] + wandb_tag += f",pr-{pr_number}" + print(f"identified github pull request: {pr_number}") + else: + print("current branch is main, not searching for pull request") + except Exception as e: + print(e) + + return wandb_tag + + @dataclass class PPOConfig(object): """ Configuration class for PPOTrainer """ + task_name: Optional[str] = field( + default=None, + metadata={"help": "Name of task to use - used only for tracking purposes"}, + ) model_name: Optional[str] = field( default=None, metadata={"help": "Name of model to use - used only for tracking purposes"}, @@ -119,6 +158,15 @@ def __post_init__(self): # raise error if wandb is not installed try: import wandb # noqa: F401 + + existing_wandb_tag = os.environ.get("WANDB_TAGS", "") + wandb_tag = autotag() + if len(wandb_tag) > 0: + if len(existing_wandb_tag) > 0: + os.environ["WANDB_TAGS"] = ",".join([existing_wandb_tag, wandb_tag]) + else: + os.environ["WANDB_TAGS"] = wandb_tag + print(os.environ["WANDB_TAGS"]) except ImportError: raise ImportError( "Please install wandb to use wandb logging. You can do this by running `pip install wandb`." From d076b07c334d31a650fb98e62d62567d8fda7fb5 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 7 Jun 2023 15:43:58 +0000 Subject: [PATCH 2/4] use `logging.info` --- trl/trainer/ppo_config.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/trl/trainer/ppo_config.py b/trl/trainer/ppo_config.py index 712780fad4..aacbb55cd8 100644 --- a/trl/trainer/ppo_config.py +++ b/trl/trainer/ppo_config.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import os import subprocess import warnings @@ -25,11 +26,11 @@ def autotag() -> str: wandb_tag = "" - print("autotag feature is enabled") + logging.info("autotag feature is enabled") try: git_tag = subprocess.check_output(["git", "describe", "--tags"]).decode("ascii").strip() wandb_tag = f"{git_tag}" - print(f"identified git tag: {git_tag}") + logging.info(f"identified git tag: {git_tag}") except subprocess.CalledProcessError: return wandb_tag @@ -46,9 +47,9 @@ def autotag() -> str: pr = prs["items"][0] pr_number = pr["number"] wandb_tag += f",pr-{pr_number}" - print(f"identified github pull request: {pr_number}") + logging.info(f"identified github pull request: {pr_number}") else: - print("current branch is main, not searching for pull request") + logging.info("current branch is main, not searching for pull request") except Exception as e: print(e) @@ -166,7 +167,7 @@ def __post_init__(self): os.environ["WANDB_TAGS"] = ",".join([existing_wandb_tag, wandb_tag]) else: os.environ["WANDB_TAGS"] = wandb_tag - print(os.environ["WANDB_TAGS"]) + logging.info(f"the following tags will be used for wandb logging: {os.environ['WANDB_TAGS']}") except ImportError: raise ImportError( "Please install wandb to use wandb logging. You can do this by running `pip install wandb`." From bbbdca120a6707d0d6d88bd3b9fdb5b6675901ee Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 7 Jun 2023 11:48:01 -0400 Subject: [PATCH 3/4] Update trl/trainer/ppo_config.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --- trl/trainer/ppo_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/ppo_config.py b/trl/trainer/ppo_config.py index aacbb55cd8..f7479dd7f4 100644 --- a/trl/trainer/ppo_config.py +++ b/trl/trainer/ppo_config.py @@ -51,7 +51,7 @@ def autotag() -> str: else: logging.info("current branch is main, not searching for pull request") except Exception as e: - print(e) + logger.warning(f"Automatic autotag failed with the following error: {e}") return wandb_tag From 8ed4927fcd21e3af2691fab57d033834f2fd37a4 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 7 Jun 2023 11:52:53 -0400 Subject: [PATCH 4/4] Update trl/trainer/ppo_config.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --- trl/trainer/ppo_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/ppo_config.py b/trl/trainer/ppo_config.py index f7479dd7f4..abb2651fb6 100644 --- a/trl/trainer/ppo_config.py +++ b/trl/trainer/ppo_config.py @@ -51,7 +51,7 @@ def autotag() -> str: else: logging.info("current branch is main, not searching for pull request") except Exception as e: - logger.warning(f"Automatic autotag failed with the following error: {e}") + logging.warning(f"Automatic autotag failed with the following error: {e}") return wandb_tag