Skip to content

Commit

Permalink
feat(wandb): save model as artifact (#8119)
Browse files Browse the repository at this point in the history
* feat(wandb): log artifacts

* fix: typo

* feat(wandb): ensure name is allowed

* feat(wandb): log artifact

* feat(wandb): saving logic

* style: improve formatting

* fix: unrelated typo

* feat: use a fake trainer

* fix: simplify

* feat(wandb): log model files as artifact

* style: fix style

* docs(wandb): correct description

* feat: unpack model + allow env Truethy values

* feat: TrainerCallback can access tokenizer

* style: fix style

* feat(wandb): log more interesting metadata

* feat: unpack tokenizer

* feat(wandb): metadata with load_best_model_at_end

* feat(wandb): more robust metadata

* style(wandb): fix formatting
  • Loading branch information
borisdayma authored Jan 5, 2021
1 parent 143289d commit 30fa0b7
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 3 deletions.
41 changes: 40 additions & 1 deletion src/transformers/integrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,13 @@
Integrations with other Python libraries.
"""
import math
import numbers
import os
import re
import tempfile
from pathlib import Path

from .file_utils import ENV_VARS_TRUE_VALUES
from .trainer_utils import EvaluationStrategy
from .utils import logging

Expand Down Expand Up @@ -369,6 +374,8 @@ def setup(self, args, state, model, reinit, **kwargs):
<https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:
Environment:
WANDB_LOG_MODEL (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to log model as artifact at the end of training.
WANDB_WATCH (:obj:`str`, `optional` defaults to :obj:`"gradients"`):
Can be :obj:`"gradients"`, :obj:`"all"` or :obj:`"false"`. Set to :obj:`"false"` to disable gradient
logging or :obj:`"all"` to log gradients and parameters.
Expand Down Expand Up @@ -407,12 +414,44 @@ def setup(self, args, state, model, reinit, **kwargs):
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
wandb.watch(model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps))

# log outputs
self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})

def on_train_begin(self, args, state, control, model=None, **kwargs):
hp_search = state.is_hyper_param_search
if not self._initialized or hp_search:
print(args.run_name)
self.setup(args, state, model, reinit=hp_search, **kwargs)

def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
# commit last step
wandb.log({})
if self._log_model and self._initialized and state.is_world_process_zero:
from .trainer import Trainer

fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
with tempfile.TemporaryDirectory() as temp_dir:
fake_trainer.save_model(temp_dir)
# use run name and ensure it's a valid Artifact name
artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", wandb.run.name)
metadata = (
{
k: v
for k, v in dict(wandb.summary).items()
if isinstance(v, numbers.Number) and not k.startswith("_")
}
if not args.load_best_model_at_end
else {
f"eval/{args.metric_for_best_model}": state.best_metric,
"train/total_floss": state.total_flos,
}
)
artifact = wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata)
for f in Path(temp_dir).glob("*"):
if f.is_file():
with artifact.new_file(f.name, mode="wb") as fa:
fa.write(f.read_bytes())
wandb.run.log_artifact(artifact)

def on_log(self, args, state, control, model=None, logs=None, **kwargs):
if not self._initialized:
self.setup(args, state, model, reinit=False)
Expand Down
4 changes: 3 additions & 1 deletion src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,9 @@ def __init__(
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
)
callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
self.callback_handler = CallbackHandler(callbacks, self.model, self.optimizer, self.lr_scheduler)
self.callback_handler = CallbackHandler(
callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
)
self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)

# Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
Expand Down
6 changes: 5 additions & 1 deletion src/transformers/trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ class TrainerCallback:
The object that is returned to the :class:`~transformers.Trainer` and can be used to make some decisions.
model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`):
The model being trained.
tokenizer (:class:`~transformers.PreTrainedTokenizer`):
The tokenizer used for encoding the data.
optimizer (:obj:`torch.optim.Optimizer`):
The optimizer used for the training steps.
lr_scheduler (:obj:`torch.optim.lr_scheduler.LambdaLR`):
Expand Down Expand Up @@ -274,11 +276,12 @@ def on_prediction_step(self, args: TrainingArguments, state: TrainerState, contr
class CallbackHandler(TrainerCallback):
""" Internal class that just calls the list of callbacks in order. """

def __init__(self, callbacks, model, optimizer, lr_scheduler):
def __init__(self, callbacks, model, tokenizer, optimizer, lr_scheduler):
self.callbacks = []
for cb in callbacks:
self.add_callback(cb)
self.model = model
self.tokenizer = tokenizer
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.train_dataloader = None
Expand Down Expand Up @@ -376,6 +379,7 @@ def call_event(self, event, args, state, control, **kwargs):
state,
control,
model=self.model,
tokenizer=self.tokenizer,
optimizer=self.optimizer,
lr_scheduler=self.lr_scheduler,
train_dataloader=self.train_dataloader,
Expand Down

0 comments on commit 30fa0b7

Please sign in to comment.