From 30fa0b780f30efacdfe3f0964bb1b941b22aa8d5 Mon Sep 17 00:00:00 2001
From: Boris Dayma
Date: Tue, 5 Jan 2021 02:30:46 -0600
Subject: [PATCH] feat(wandb): save model as artifact (#8119)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* feat(wandb): log artifacts

* fix: typo

* feat(wandb): ensure name is allowed

* feat(wandb): log artifact

* feat(wandb): saving logic

* style: improve formatting

* fix: unrelated typo

* feat: use a fake trainer

* fix: simplify

* feat(wandb): log model files as artifact

* style: fix style

* docs(wandb): correct description

* feat: unpack model + allow env Truethy values

* feat: TrainerCallback can access tokenizer

* style: fix style

* feat(wandb): log more interesting metadata

* feat: unpack tokenizer

* feat(wandb): metadata with load_best_model_at_end

* feat(wandb): more robust metadata

* style(wandb): fix formatting
---
 src/transformers/integrations.py     | 41 +++++++++++++++++++++++++++-
 src/transformers/trainer.py          |  4 ++-
 src/transformers/trainer_callback.py |  6 +++-
 3 files changed, 48 insertions(+), 3 deletions(-)

diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py
index 2d673087e832dd..4053582d3aef2d 100644
--- a/src/transformers/integrations.py
+++ b/src/transformers/integrations.py
@@ -15,8 +15,13 @@
 Integrations with other Python libraries.
 """
 import math
+import numbers
 import os
+import re
+import tempfile
+from pathlib import Path
 
+from .file_utils import ENV_VARS_TRUE_VALUES
 from .trainer_utils import EvaluationStrategy
 from .utils import logging
 
@@ -369,6 +374,8 @@ def setup(self, args, state, model, reinit, **kwargs):
         `__. You can also override the following environment variables:
 
         Environment:
+            WANDB_LOG_MODEL (:obj:`bool`, `optional`, defaults to :obj:`False`):
+                Whether or not to log model as artifact at the end of training.
             WANDB_WATCH (:obj:`str`, `optional` defaults to :obj:`"gradients"`):
                 Can be :obj:`"gradients"`, :obj:`"all"` or :obj:`"false"`. Set to :obj:`"false"` to disable gradient
                 logging or :obj:`"all"` to log gradients and parameters.
@@ -407,12 +414,44 @@ def setup(self, args, state, model, reinit, **kwargs):
             if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
                 wandb.watch(model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps))
 
+        # log outputs
+        self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})
+
     def on_train_begin(self, args, state, control, model=None, **kwargs):
         hp_search = state.is_hyper_param_search
         if not self._initialized or hp_search:
-            print(args.run_name)
             self.setup(args, state, model, reinit=hp_search, **kwargs)
 
+    def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
+        # commit last step
+        wandb.log({})
+        if self._log_model and self._initialized and state.is_world_process_zero:
+            from .trainer import Trainer
+
+            fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
+            with tempfile.TemporaryDirectory() as temp_dir:
+                fake_trainer.save_model(temp_dir)
+                # use run name and ensure it's a valid Artifact name
+                artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", wandb.run.name)
+                metadata = (
+                    {
+                        k: v
+                        for k, v in dict(wandb.summary).items()
+                        if isinstance(v, numbers.Number) and not k.startswith("_")
+                    }
+                    if not args.load_best_model_at_end
+                    else {
+                        f"eval/{args.metric_for_best_model}": state.best_metric,
+                        "train/total_floss": state.total_flos,
+                    }
+                )
+                artifact = wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata)
+                for f in Path(temp_dir).glob("*"):
+                    if f.is_file():
+                        with artifact.new_file(f.name, mode="wb") as fa:
+                            fa.write(f.read_bytes())
+                wandb.run.log_artifact(artifact)
+
     def on_log(self, args, state, control, model=None, logs=None, **kwargs):
         if not self._initialized:
             self.setup(args, state, model, reinit=False)
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index d76ba6be7b7625..f3c21e2d617a32 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -261,7 +261,9 @@ def __init__(
                 "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
             )
         callbacks = DEFAULT_CALLBACKS if callbacks is None else DEFAULT_CALLBACKS + callbacks
-        self.callback_handler = CallbackHandler(callbacks, self.model, self.optimizer, self.lr_scheduler)
+        self.callback_handler = CallbackHandler(
+            callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
+        )
         self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
 
         # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
diff --git a/src/transformers/trainer_callback.py b/src/transformers/trainer_callback.py
index 1ad546bc4f1ee6..ea2ea3cd82d684 100644
--- a/src/transformers/trainer_callback.py
+++ b/src/transformers/trainer_callback.py
@@ -168,6 +168,8 @@ class TrainerCallback:
             The object that is returned to the :class:`~transformers.Trainer` and can be used to make some decisions.
         model (:class:`~transformers.PreTrainedModel` or :obj:`torch.nn.Module`):
             The model being trained.
+        tokenizer (:class:`~transformers.PreTrainedTokenizer`):
+            The tokenizer used for encoding the data.
         optimizer (:obj:`torch.optim.Optimizer`):
             The optimizer used for the training steps.
         lr_scheduler (:obj:`torch.optim.lr_scheduler.LambdaLR`):
@@ -274,11 +276,12 @@ def on_prediction_step(self, args: TrainingArguments, state: TrainerState, contr
 class CallbackHandler(TrainerCallback):
     """ Internal class that just calls the list of callbacks in order. """
""" - def __init__(self, callbacks, model, optimizer, lr_scheduler): + def __init__(self, callbacks, model, tokenizer, optimizer, lr_scheduler): self.callbacks = [] for cb in callbacks: self.add_callback(cb) self.model = model + self.tokenizer = tokenizer self.optimizer = optimizer self.lr_scheduler = lr_scheduler self.train_dataloader = None @@ -376,6 +379,7 @@ def call_event(self, event, args, state, control, **kwargs): state, control, model=self.model, + tokenizer=self.tokenizer, optimizer=self.optimizer, lr_scheduler=self.lr_scheduler, train_dataloader=self.train_dataloader,