diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py
index 399c9f60cfa685..5bdb21d67323e3 100755
--- a/src/transformers/integrations/integration_utils.py
+++ b/src/transformers/integrations/integration_utils.py
@@ -26,6 +26,7 @@
 import sys
 import tempfile
 from dataclasses import asdict, fields
+from enum import Enum
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
 
@@ -726,6 +727,35 @@ def print_to_file(s):
         print(model, file=f)
 
 
+class WandbLogModel(str, Enum):
+    """Enum of possible log model values in W&B."""
+
+    CHECKPOINT = "checkpoint"
+    END = "end"
+    FALSE = "false"
+
+    @property
+    def is_enabled(self) -> bool:
+        """Check if the value corresponds to a state where the `WANDB_LOG_MODEL` setting is enabled."""
+        return self in (WandbLogModel.CHECKPOINT, WandbLogModel.END)
+
+    @classmethod
+    def _missing_(cls, value: Any) -> "WandbLogModel":
+        if not isinstance(value, str):
+            raise ValueError(f"Expecting to have a string `WANDB_LOG_MODEL` setting, but got {type(value)}")
+        if value.upper() in ENV_VARS_TRUE_VALUES:
+            logger.warning(
+                f"Setting `WANDB_LOG_MODEL` as {value} is deprecated and will be removed in "
+                "version 5 of transformers. Use one of `'end'` or `'checkpoint'` instead."
+            )
+            logger.info(f"Setting `WANDB_LOG_MODEL` from {value} to `end` instead")
+            return WandbLogModel.END
+        logger.warning(
+            f"Received unrecognized `WANDB_LOG_MODEL` setting value={value}; so disabling `WANDB_LOG_MODEL`"
+        )
+        return WandbLogModel.FALSE
+
+
 class WandbCallback(TrainerCallback):
     """
     A [`TrainerCallback`] that logs metrics, media, model checkpoints to [Weight and Biases](https://www.wandb.com/).
@@ -740,16 +770,7 @@ def __init__(self):
         self._wandb = wandb
         self._initialized = False
 
-        # log model
-        if os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"}):
-            DeprecationWarning(
-                f"Setting `WANDB_LOG_MODEL` as {os.getenv('WANDB_LOG_MODEL')} is deprecated and will be removed in "
-                "version 5 of transformers. Use one of `'end'` or `'checkpoint'` instead."
- ) - logger.info(f"Setting `WANDB_LOG_MODEL` from {os.getenv('WANDB_LOG_MODEL')} to `end` instead") - self._log_model = "end" - else: - self._log_model = os.getenv("WANDB_LOG_MODEL", "false").lower() + self._log_model = WandbLogModel(os.getenv("WANDB_LOG_MODEL", "false")) def setup(self, args, state, model, **kwargs): """ @@ -834,37 +855,38 @@ def setup(self, args, state, model, **kwargs): logger.info("Could not log the number of model parameters in Weights & Biases.") # log the initial model architecture to an artifact - with tempfile.TemporaryDirectory() as temp_dir: - model_name = ( - f"model-{self._wandb.run.id}" - if (args.run_name is None or args.run_name == args.output_dir) - else f"model-{self._wandb.run.name}" - ) - model_artifact = self._wandb.Artifact( - name=model_name, - type="model", - metadata={ - "model_config": model.config.to_dict() if hasattr(model, "config") else None, - "num_parameters": self._wandb.config.get("model/num_parameters"), - "initial_model": True, - }, - ) - # add the architecture to a separate text file - save_model_architecture_to_file(model, temp_dir) - - for f in Path(temp_dir).glob("*"): - if f.is_file(): - with model_artifact.new_file(f.name, mode="wb") as fa: - fa.write(f.read_bytes()) - self._wandb.run.log_artifact(model_artifact, aliases=["base_model"]) - - badge_markdown = ( - f'[Visualize in Weights & Biases]({self._wandb.run.get_url()})' - ) + if self._log_model.is_enabled: + with tempfile.TemporaryDirectory() as temp_dir: + model_name = ( + f"model-{self._wandb.run.id}" + if (args.run_name is None or args.run_name == args.output_dir) + else f"model-{self._wandb.run.name}" + ) + model_artifact = self._wandb.Artifact( + name=model_name, + type="model", + metadata={ + "model_config": model.config.to_dict() if hasattr(model, "config") else None, + "num_parameters": self._wandb.config.get("model/num_parameters"), + "initial_model": True, + }, + ) + # add the architecture to a separate text file + save_model_architecture_to_file(model, temp_dir) + + for f in Path(temp_dir).glob("*"): + if f.is_file(): + with model_artifact.new_file(f.name, mode="wb") as fa: + fa.write(f.read_bytes()) + self._wandb.run.log_artifact(model_artifact, aliases=["base_model"]) + + badge_markdown = ( + f'[Visualize in Weights & Biases]({self._wandb.run.get_url()})' + ) - modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}" + modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}" def on_train_begin(self, args, state, control, model=None, **kwargs): if self._wandb is None: @@ -880,7 +902,7 @@ def on_train_begin(self, args, state, control, model=None, **kwargs): def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs): if self._wandb is None: return - if self._log_model in ("end", "checkpoint") and self._initialized and state.is_world_process_zero: + if self._log_model.is_enabled and self._initialized and state.is_world_process_zero: from ..trainer import Trainer fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer) @@ -938,7 +960,7 @@ def on_log(self, args, state, control, model=None, logs=None, **kwargs): self._wandb.log({**non_scalar_logs, "train/global_step": state.global_step}) def on_save(self, args, state, control, **kwargs): - if self._log_model == "checkpoint" and self._initialized and state.is_world_process_zero: + if self._log_model == WandbLogModel.CHECKPOINT and self._initialized and state.is_world_process_zero: checkpoint_metadata = { k: v for k, v in dict(self._wandb.summary).items()
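To sanity-check the parsing behaviour introduced by this diff without running a full `Trainer`, here is a minimal standalone sketch. It mirrors the `WandbLogModel` enum above rather than importing the patched module, drops the logging side effects, and assumes `ENV_VARS_TRUE_VALUES` matches the usual truthy set in `transformers.utils` (the set below is an assumption, not taken from this diff):

```python
from enum import Enum
from typing import Any

# Assumption: mirrors `transformers.utils.ENV_VARS_TRUE_VALUES`.
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}


class WandbLogModel(str, Enum):
    """Stripped-down mirror of the enum added in this diff (no logging side effects)."""

    CHECKPOINT = "checkpoint"
    END = "end"
    FALSE = "false"

    @property
    def is_enabled(self) -> bool:
        return self in (WandbLogModel.CHECKPOINT, WandbLogModel.END)

    @classmethod
    def _missing_(cls, value: Any) -> "WandbLogModel":
        # Called only when the plain value lookup fails, i.e. for anything
        # other than "checkpoint", "end" or "false".
        if not isinstance(value, str):
            raise ValueError(f"Expecting a string `WANDB_LOG_MODEL` setting, but got {type(value)}")
        if value.upper() in ENV_VARS_TRUE_VALUES:
            return WandbLogModel.END  # legacy truthy values fall back to "end"
        return WandbLogModel.FALSE  # anything unrecognized disables model logging


# Expected mapping of WANDB_LOG_MODEL values:
assert WandbLogModel("checkpoint") is WandbLogModel.CHECKPOINT
assert WandbLogModel("end").is_enabled
assert WandbLogModel("TRUE") is WandbLogModel.END  # deprecated truthy spelling
assert not WandbLogModel("whatever").is_enabled  # unknown value -> "false"
assert WandbLogModel.CHECKPOINT == "checkpoint"  # str mixin keeps plain-string comparisons working
```

Because the enum mixes in `str`, comparisons against plain strings such as `"checkpoint"` still evaluate as before, so existing callers that pass or compare raw strings keep working while new code can use `is_enabled` and the enum members directly.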