diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py
index 399c9f60cfa685..5bdb21d67323e3 100755
--- a/src/transformers/integrations/integration_utils.py
+++ b/src/transformers/integrations/integration_utils.py
@@ -26,6 +26,7 @@
import sys
import tempfile
from dataclasses import asdict, fields
+from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
@@ -726,6 +727,35 @@ def print_to_file(s):
print(model, file=f)
+class WandbLogModel(str, Enum):
+ """Enum of possible log model values in W&B."""
+
+ CHECKPOINT = "checkpoint"
+ END = "end"
+ FALSE = "false"
+
+ @property
+ def is_enabled(self) -> bool:
+ """Check if the value corresponds to a state where the `WANDB_LOG_MODEL` setting is enabled."""
+ return self in (WandbLogModel.CHECKPOINT, WandbLogModel.END)
+
+ @classmethod
+ def _missing_(cls, value: Any) -> "WandbLogModel":
+ if not isinstance(value, str):
+            raise ValueError(f"Expected a string value for the `WANDB_LOG_MODEL` setting, but got {type(value)}")
+ if value.upper() in ENV_VARS_TRUE_VALUES:
+            logger.warning(
+                f"Setting `WANDB_LOG_MODEL` as {value} is deprecated and will be removed in "
+                "version 5 of transformers. Use one of `'end'` or `'checkpoint'` instead."
+            )
+            logger.info(f"Setting `WANDB_LOG_MODEL` from {value} to `end` instead")
+            return WandbLogModel.END
+        logger.warning(
+            f"Received unrecognized `WANDB_LOG_MODEL` setting value={value}; disabling `WANDB_LOG_MODEL`"
+        )
+ return WandbLogModel.FALSE
+
+
class WandbCallback(TrainerCallback):
"""
     A [`TrainerCallback`] that logs metrics, media, and model checkpoints to [Weights & Biases](https://www.wandb.com/).
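
A minimal sketch of how the new enum resolves typical `WANDB_LOG_MODEL` values, assuming this patch is applied so that `WandbLogModel` is importable from `transformers.integrations.integration_utils`, and assuming `ENV_VARS_TRUE_VALUES` contains the usual truthy strings such as "TRUE":

# Illustrative sketch only; relies on the WandbLogModel class added above.
from transformers.integrations.integration_utils import WandbLogModel

# Canonical values resolve directly to their members.
assert WandbLogModel("checkpoint") is WandbLogModel.CHECKPOINT
assert WandbLogModel("end") is WandbLogModel.END

# Legacy truthy values fall through to _missing_ and are mapped to END.
assert WandbLogModel("TRUE") is WandbLogModel.END

# Unrecognized values are downgraded to FALSE (with a warning) instead of raising.
assert WandbLogModel("definitely-not-a-mode") is WandbLogModel.FALSE

The `_missing_` hook is what keeps the deprecated truthy settings working while funnelling everything else into the disabled state.
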
@@ -740,16 +770,7 @@ def __init__(self):
self._wandb = wandb
self._initialized = False
- # log model
- if os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"}):
- DeprecationWarning(
- f"Setting `WANDB_LOG_MODEL` as {os.getenv('WANDB_LOG_MODEL')} is deprecated and will be removed in "
- "version 5 of transformers. Use one of `'end'` or `'checkpoint'` instead."
- )
- logger.info(f"Setting `WANDB_LOG_MODEL` from {os.getenv('WANDB_LOG_MODEL')} to `end` instead")
- self._log_model = "end"
- else:
- self._log_model = os.getenv("WANDB_LOG_MODEL", "false").lower()
+ self._log_model = WandbLogModel(os.getenv("WANDB_LOG_MODEL", "false"))
def setup(self, args, state, model, **kwargs):
"""
@@ -834,37 +855,38 @@ def setup(self, args, state, model, **kwargs):
logger.info("Could not log the number of model parameters in Weights & Biases.")
# log the initial model architecture to an artifact
- with tempfile.TemporaryDirectory() as temp_dir:
- model_name = (
- f"model-{self._wandb.run.id}"
- if (args.run_name is None or args.run_name == args.output_dir)
- else f"model-{self._wandb.run.name}"
- )
- model_artifact = self._wandb.Artifact(
- name=model_name,
- type="model",
- metadata={
- "model_config": model.config.to_dict() if hasattr(model, "config") else None,
- "num_parameters": self._wandb.config.get("model/num_parameters"),
- "initial_model": True,
- },
- )
- # add the architecture to a separate text file
- save_model_architecture_to_file(model, temp_dir)
-
- for f in Path(temp_dir).glob("*"):
- if f.is_file():
- with model_artifact.new_file(f.name, mode="wb") as fa:
- fa.write(f.read_bytes())
- self._wandb.run.log_artifact(model_artifact, aliases=["base_model"])
-
- badge_markdown = (
- f'[]({self._wandb.run.get_url()})'
- )
+ if self._log_model.is_enabled:
+ with tempfile.TemporaryDirectory() as temp_dir:
+ model_name = (
+ f"model-{self._wandb.run.id}"
+ if (args.run_name is None or args.run_name == args.output_dir)
+ else f"model-{self._wandb.run.name}"
+ )
+ model_artifact = self._wandb.Artifact(
+ name=model_name,
+ type="model",
+ metadata={
+ "model_config": model.config.to_dict() if hasattr(model, "config") else None,
+ "num_parameters": self._wandb.config.get("model/num_parameters"),
+ "initial_model": True,
+ },
+ )
+ # add the architecture to a separate text file
+ save_model_architecture_to_file(model, temp_dir)
+
+ for f in Path(temp_dir).glob("*"):
+ if f.is_file():
+ with model_artifact.new_file(f.name, mode="wb") as fa:
+ fa.write(f.read_bytes())
+ self._wandb.run.log_artifact(model_artifact, aliases=["base_model"])
+
+ badge_markdown = (
+ f'[]({self._wandb.run.get_url()})'
+ )
- modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
+ modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
def on_train_begin(self, args, state, control, model=None, **kwargs):
if self._wandb is None:
@@ -880,7 +902,7 @@ def on_train_begin(self, args, state, control, model=None, **kwargs):
def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
if self._wandb is None:
return
- if self._log_model in ("end", "checkpoint") and self._initialized and state.is_world_process_zero:
+ if self._log_model.is_enabled and self._initialized and state.is_world_process_zero:
from ..trainer import Trainer
fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
@@ -938,7 +960,7 @@ def on_log(self, args, state, control, model=None, logs=None, **kwargs):
self._wandb.log({**non_scalar_logs, "train/global_step": state.global_step})
def on_save(self, args, state, control, **kwargs):
- if self._log_model == "checkpoint" and self._initialized and state.is_world_process_zero:
+ if self._log_model == WandbLogModel.CHECKPOINT and self._initialized and state.is_world_process_zero:
checkpoint_metadata = {
k: v
for k, v in dict(self._wandb.summary).items()
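
Because `WandbLogModel` mixes in `str`, the enum comparison used in `on_save` stays interchangeable with the legacy plain-string check; a short sketch, again assuming the patched module is importable:

# Sketch only; assumes the patch above is applied.
from transformers.integrations.integration_utils import WandbLogModel

log_model = WandbLogModel("checkpoint")

# The new comparison in on_save ...
assert log_model == WandbLogModel.CHECKPOINT
# ... is equivalent to the old string comparison, thanks to the str mixin.
assert log_model == "checkpoint"

# setup and on_train_end gate on is_enabled, which is true for both logging modes.
assert WandbLogModel.END.is_enabled and WandbLogModel.CHECKPOINT.is_enabled
assert not WandbLogModel.FALSE.is_enabled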