feat(wandb): save model as artifact #8119

Merged · 22 commits · Jan 5, 2021
41 changes: 40 additions & 1 deletion src/transformers/integrations.py
@@ -15,8 +15,13 @@
Integrations with other Python libraries.
"""
import math
import numbers
import os
import re
import tempfile
from pathlib import Path

from .file_utils import ENV_VARS_TRUE_VALUES
from .trainer_utils import EvaluationStrategy
from .utils import logging

@@ -346,6 +351,8 @@ def setup(self, args, state, model, reinit, **kwargs):
<https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:

Environment:
WANDB_LOG_MODEL (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to log the model as an artifact at the end of training.
WANDB_WATCH (:obj:`str`, `optional`, defaults to :obj:`"gradients"`):
Can be :obj:`"gradients"`, :obj:`"all"` or :obj:`"false"`. Set to :obj:`"false"` to disable gradient
logging or :obj:`"all"` to log gradients and parameters.
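
For context (not part of the diff): a minimal sketch of how a user opts in to the new flag, assuming the variable is set before `WandbCallback.setup` reads it.

import os

# Any value recognised by ENV_VARS_TRUE_VALUES (e.g. "1", "yes", "true"; matched
# case-insensitively) enables the upload; the default "FALSE" keeps it disabled.
os.environ["WANDB_LOG_MODEL"] = "true"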
@@ -384,12 +391,44 @@ def setup(self, args, state, model, reinit, **kwargs):
if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
wandb.watch(model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps))

# whether to upload the final model as a W&B Artifact at the end of training
self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})

def on_train_begin(self, args, state, control, model=None, **kwargs):
hp_search = state.is_hyper_param_search
if not self._initialized or hp_search:
self.setup(args, state, model, reinit=hp_search, **kwargs)

def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
# commit last step
wandb.log({})
if self._log_model and self._initialized and state.is_world_process_zero:
from .trainer import Trainer

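# a throwaway Trainer is created only to reuse its save_model serialization logic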
fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
with tempfile.TemporaryDirectory() as temp_dir:
fake_trainer.save_model(temp_dir)
# use run name and ensure it's a valid Artifact name
artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", wandb.run.name)
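# metadata: scalar values from the W&B run summary, or the best-model metrics when load_best_model_at_end is set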
metadata = (
{
k: v
for k, v in dict(wandb.summary).items()
if isinstance(v, numbers.Number) and not k.startswith("_")
}
if not args.load_best_model_at_end
else {
f"eval/{args.metric_for_best_model}": state.best_metric,
"train/total_flos": state.total_flos,
}
)
artifact = wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata)
for f in Path(temp_dir).glob("*"):
if f.is_file():
with artifact.new_file(f.name, mode="wb") as fa:
fa.write(f.read_bytes())
wandb.run.log_artifact(artifact)

def on_log(self, args, state, control, model=None, logs=None, **kwargs):
if not self._initialized:
self.setup(args, state, model, reinit=False)
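For readability (not part of the diff): the core of the new on_train_end upload, extracted into a standalone sketch. The helper name is hypothetical, an active wandb run is assumed, and checkpoint_dir stands for the directory written by save_model.

import re
from pathlib import Path

import wandb


def log_dir_as_model_artifact(checkpoint_dir, metadata):
    # strip characters that are not valid in a W&B Artifact name, as the callback does
    artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", wandb.run.name)
    artifact = wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata)
    # copy every file written by save_model into the artifact
    for f in Path(checkpoint_dir).glob("*"):
        if f.is_file():
            with artifact.new_file(f.name, mode="wb") as fa:
                fa.write(f.read_bytes())
    wandb.run.log_artifact(artifact)

Writing the bytes through artifact.new_file, rather than referencing the temporary path with add_file, presumably keeps the contents valid after the TemporaryDirectory is cleaned up and before the asynchronous upload completes.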
4 changes: 3 additions & 1 deletion src/transformers/trainer.py
@@ -261,7 +261,9 @@ def __init__(
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
)
callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
self.callback_handler = CallbackHandler(callbacks, self.model, self.optimizer, self.lr_scheduler)
self.callback_handler = CallbackHandler(
callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
)
self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)

# Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
6 changes: 5 additions & 1 deletion src/transformers/trainer_callback.py
@@ -168,6 +168,8 @@ class TrainerCallback:
The object that is returned to the :class:`~transformers.Trainer` and can be used to make some decisions.
model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`):
The model being trained.
tokenizer (:class:`~transformers.PreTrainedTokenizer`):
The tokenizer used for encoding the data.
optimizer (:obj:`torch.optim.Optimizer`):
The optimizer used for the training steps.
lr_scheduler (:obj:`torch.optim.lr_scheduler.LambdaLR`):
@@ -274,11 +276,12 @@ def on_prediction_step(self, args: TrainingArguments, state: TrainerState, contr
class CallbackHandler(TrainerCallback):
""" Internal class that just calls the list of callbacks in order. """

def __init__(self, callbacks, model, optimizer, lr_scheduler):
def __init__(self, callbacks, model, tokenizer, optimizer, lr_scheduler):
self.callbacks = []
for cb in callbacks:
self.add_callback(cb)
self.model = model
self.tokenizer = tokenizer
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
self.train_dataloader = None
@@ -376,6 +379,7 @@ def call_event(self, event, args, state, control, **kwargs):
state,
control,
model=self.model,
tokenizer=self.tokenizer,
optimizer=self.optimizer,
lr_scheduler=self.lr_scheduler,
train_dataloader=self.train_dataloader,
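Since CallbackHandler now forwards the tokenizer to every event, third-party callbacks can read it from kwargs. A minimal sketch (the callback name and behaviour are illustrative, not part of this PR):

from transformers import TrainerCallback


class SaveTokenizerCallback(TrainerCallback):
    """Example callback using the tokenizer argument forwarded by call_event."""

    def on_train_end(self, args, state, control, tokenizer=None, **kwargs):
        # only the main process writes to disk
        if tokenizer is not None and state.is_world_process_zero:
            tokenizer.save_pretrained(args.output_dir)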