From 01d340adfaea70ea0997931c92c498a5a18c7435 Mon Sep 17 00:00:00 2001 From: Teven Date: Tue, 8 Sep 2020 16:00:56 +0200 Subject: [PATCH] Floating-point operations logging in trainer (#6768) * neFLOs calculation, logging, and reloading (#1) * testing distributed consecutive batches * fixed AttributeError from DataParallel * removed verbosity * rotate with use_mtime=True * removed print * fixed interaction with gradient accumulation * indent formatting * distributed neflo counting * fixed typo * fixed typo * mean distributed losses * exporting log history * moved a few functions * floating_point_ops clarification for transformers with parameter-reuse * code quality * double import * made flo estimation more task-agnostic * only logging flos if computed * code quality * unused import * Update src/transformers/trainer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Sylvain review * Update src/transformers/modeling_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * black Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/modeling_utils.py | 88 +++++++++++++++++++----- src/transformers/trainer.py | 106 +++++++++++++++++++++++------ src/transformers/trainer_utils.py | 32 ++++++++- 3 files changed, 187 insertions(+), 39 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 41c1de3eec6068..964461d8d1cf87 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -17,8 +17,9 @@ import inspect import os import re +import warnings from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import torch from torch import Tensor, device, dtype, nn @@ -45,7 +46,6 @@ logger = logging.get_logger(__name__) - try: from torch.nn import Identity except ImportError: @@ -91,20 +91,6 @@ class ModuleUtilsMixin: A few utilities for :obj:`torch.nn.Modules`, to be used as a mixin. """ - def num_parameters(self, only_trainable: bool = False) -> int: - """ - Get the number of (optionally, trainable) parameters in the model. - - Args: - only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether or not to return only the number of trainable parameters - - Returns: - :obj:`int`: The number of parameters. - """ - params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters() - return sum(p.numel() for p in params) - @staticmethod def _hook_rss_memory_pre_forward(module, *args, **kwargs): try: @@ -307,9 +293,77 @@ def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers): elif head_mask.dim() == 2: head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}" - head_mask = head_mask.to(dtype=self.dtype) # switch to fload if need + fp16 compatibility + head_mask = head_mask.to(dtype=self.dtype) # switch to float if need + fp16 compatibility return head_mask + def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int: + """ + Get number of (optionally, trainable or non-embeddings) parameters in the module. + + Args: + only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to return only the number of trainable parameters + + exclude_embeddings (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not to return only the number of non-embeddings parameters + + Returns: + :obj:`int`: The number of parameters. + """ + + def parameter_filter(x): + return (x.requires_grad or not only_trainable) and not ( + isinstance(x, torch.nn.Embedding) and exclude_embeddings + ) + + params = filter(parameter_filter, self.parameters()) if only_trainable else self.parameters() + return sum(p.numel() for p in params) + + def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int: + """ + Helper function to estimate the total number of tokens from the model inputs. + + Args: + inputs (:obj:`dict`): The model inputs. + + Returns: + :obj:`int`: The total number of tokens. + """ + token_inputs = [tensor for key, tensor in input_dict.items() if "input" in key] + if token_inputs: + return sum([token_input.numel() for token_input in token_inputs]) + else: + warnings.warn( + "Could not estimate the number of tokens of the input, floating-point operations will not be computed" + ) + return 0 + + def floating_point_ops( + self, input_dict: Dict[str, Union[torch.Tensor, Any]], exclude_embeddings: bool = True + ) -> int: + """ + Get number of (optionally, non-embeddings) floating-point operations for the forward and backward passes of a + batch with this transformer model. Default approximation neglects the quadratic dependency on the number of + tokens (valid if :obj:`12 * d_model << sequence_length`) as laid out in `this paper `__ section + 2.1. Should be overriden for transformers with parameter re-use e.g. Albert or Universal Transformers, or + if doing long-range modeling with very high sequence lengths. + + Args: + batch_size (:obj:`int`): + The batch size for the forward pass. + + sequence_length (:obj:`int`): + The number of tokens in each line of the batch. + + exclude_embeddings (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to count embedding and softmax operations. + + Returns: + :obj:`int`: The number of floating-point operations. + """ + + return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings) + class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): r""" diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index aa320773adf424..68c049643f5cef 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1,4 +1,5 @@ import inspect +import json import math import os import re @@ -42,6 +43,8 @@ TrainOutput, default_compute_objective, default_hp_space, + distributed_broadcast_scalars, + distributed_concat, set_seed, ) from .training_args import TrainingArguments @@ -146,7 +149,7 @@ def __iter__(self): indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples] assert ( len(indices) == self.num_samples - ), f"Indices length {len(indices)} and and sample number {self.num_samples} mismatched" + ), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched" return iter(indices) @@ -241,6 +244,7 @@ def __init__( "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." ) self.tb_writer = tb_writer + self.log_history = [] if "prediction_loss_only" in kwargs: warnings.warn( "Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.", @@ -284,6 +288,7 @@ def __init__( self.global_step = None self.epoch = None + self.total_flos = None if self.args.fp16 and _use_native_amp: self.scaler = torch.cuda.amp.GradScaler() self.hp_search_backend = None @@ -461,7 +466,11 @@ def setup_wandb(self): logger.info( 'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"' ) - combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()} + try: + combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()} + except AttributeError: + # in case the model has no config + combined_dict = {**self.args.to_sanitized_dict()} wandb.init( project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name ) @@ -663,6 +672,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D self.global_step = 0 self.epoch = 0 + self.total_flos = 0 epochs_trained = 0 steps_trained_in_current_epoch = 0 # Check if continuing training from a checkpoint @@ -670,6 +680,8 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D # set global_step to global_step of last saved checkpoint from model path try: self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0]) + self.total_flos = getattr(model.config, "total_flos", 0) + epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps) steps_trained_in_current_epoch = self.global_step % ( len(train_dataloader) // self.args.gradient_accumulation_steps @@ -678,9 +690,11 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D logger.info(" Continuing training from checkpoint, will skip to saved global_step") logger.info(" Continuing training from epoch %d", epochs_trained) logger.info(" Continuing training from global step %d", self.global_step) + logger.info(" Continuing training from %d non-embedding floating-point operations", self.total_flos) logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch) except ValueError: self.global_step = 0 + self.total_flos = 0 logger.info(" Starting fine-tuning.") tr_loss = torch.tensor(0.0).to(self.args.device) @@ -714,6 +728,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D continue tr_loss += self.training_step(model, inputs) + self.total_flos += self.floating_point_ops(inputs) if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( # last step in epoch but step is always smaller than gradient_accumulation_steps @@ -784,7 +799,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D self.save_model(output_dir) if self.is_world_process_zero(): - self._rotate_checkpoints() + self._rotate_checkpoints(use_mtime=True) if is_torch_tpu_available(): xm.rendezvous("saving_optimizer_states") @@ -924,6 +939,13 @@ def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None: if self.epoch is not None: logs["epoch"] = self.epoch + if self.total_flos is not None: + if self.args.local_rank != -1: + total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item() + else: + total_flos = self.total_flos + if total_flos > 0: + logs["total_flos"] = self.total_flos if self.global_step is None: # when logging evaluation metrics without training self.global_step = 0 @@ -951,6 +973,8 @@ def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None: if experiment is not None: experiment._log_metrics(logs, step=self.global_step, epoch=self.epoch, framework="transformers") output = {**logs, **{"step": self.global_step}} + if self.is_world_process_zero(): + self.log_history.append(output) if iterator is not None: iterator.write(output) else: @@ -1089,6 +1113,9 @@ def _save_tpu(self, output_dir: Optional[str] = None): if xm.is_master_ordinal(): os.makedirs(output_dir, exist_ok=True) torch.save(self.args, os.path.join(output_dir, "training_args.bin")) + json.dump( + self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False + ) # Save a trained model and configuration using `save_pretrained()`. # They can then be reloaded using `from_pretrained()` @@ -1096,6 +1123,7 @@ def _save_tpu(self, output_dir: Optional[str] = None): raise ValueError("Trainer.model appears to not be a PreTrainedModel") xm.rendezvous("saving_checkpoint") + self._store_flos() self.model.save_pretrained(output_dir) if self.tokenizer is not None: self.tokenizer.save_pretrained(output_dir) @@ -1108,12 +1136,26 @@ def _save(self, output_dir: Optional[str] = None): # They can then be reloaded using `from_pretrained()` if not isinstance(self.model, PreTrainedModel): raise ValueError("Trainer.model appears to not be a PreTrainedModel") + self._store_flos() self.model.save_pretrained(output_dir) if self.tokenizer is not None: self.tokenizer.save_pretrained(output_dir) # Good practice: save your training arguments together with the trained model torch.save(self.args, os.path.join(output_dir, "training_args.bin")) + json.dump( + self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False + ) + + def _store_flos(self): + # Storing the number of floating-point operations that went into the model + if self.total_flos is not None: + if self.args.local_rank != -1: + total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item() + else: + total_flos = self.total_flos + if total_flos > 0: + self.model.config.total_flos = total_flos def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]: ordering_and_checkpoint_path = [] @@ -1245,13 +1287,11 @@ def prediction_loop( self._past = None disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm - samples_count = 0 for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm): loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only) batch_size = inputs[list(inputs.keys())[0]].shape[0] - samples_count += batch_size if loss is not None: - eval_losses.append(loss * batch_size) + eval_losses.extend([loss] * batch_size) if logits is not None: preds = logits if preds is None else torch.cat((preds, logits), dim=0) if labels is not None: @@ -1264,9 +1304,9 @@ def prediction_loop( if self.args.local_rank != -1: # In distributed mode, concatenate all results from all nodes: if preds is not None: - preds = self.distributed_concat(preds, num_total_examples=self.num_examples(dataloader)) + preds = distributed_concat(preds, num_total_examples=self.num_examples(dataloader)) if label_ids is not None: - label_ids = self.distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader)) + label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader)) elif is_torch_tpu_available(): # tpu-comment: Get all predictions and labels from all worker shards of eval dataset if preds is not None: @@ -1289,7 +1329,14 @@ def prediction_loop( else: metrics = {} if len(eval_losses) > 0: - metrics["eval_loss"] = np.sum(eval_losses) / samples_count + if self.args.local_rank != -1: + metrics["eval_loss"] = ( + distributed_broadcast_scalars(eval_losses, num_total_examples=self.num_examples(dataloader)) + .mean() + .item() + ) + else: + metrics["eval_loss"] = np.mean(eval_losses) # Prefix all keys with eval_ for key in list(metrics.keys()): @@ -1298,18 +1345,6 @@ def prediction_loop( return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics) - def distributed_concat(self, tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor: - assert self.args.local_rank != -1 - - output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())] - torch.distributed.all_gather(output_tensors, tensor) - - concat = torch.cat(output_tensors, dim=0) - - # truncate the dummy elements added by SequentialDistributedSampler - output = concat[:num_total_examples] - return output - def prediction_step( self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: @@ -1355,3 +1390,32 @@ def prediction_step( if labels is not None: labels = labels.detach() return (loss, logits.detach(), labels) + + def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]): + """ + For models that inherit from :class:`~transformers.PretrainedModel`, uses + that method to compute the number of floating point operations for every backward + forward pass. If using + another model, either implement such a method in the model or subclass and override this method. + + Args: + model (:obj:`nn.Module`): + The model to evaluate. + inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + Returns: + :obj:`int`: The number of floating-point operations. + """ + + if isinstance(self.model, torch.nn.DataParallel) or isinstance( + self.model, torch.nn.parallel.DistributedDataParallel + ): + model = self.model.module + else: + model = self.model + + if hasattr(model, "floating_point_ops"): + return model.floating_point_ops(inputs) + + else: + return 0 diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index f9b560b4026a0e..09149dc963891b 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -1,7 +1,8 @@ import random -from typing import Any, Dict, NamedTuple, Optional +from typing import Any, Dict, List, NamedTuple, Optional, Union import numpy as np +import torch from .file_utils import is_tf_available, is_torch_available from .tokenization_utils_base import ExplicitEnum @@ -126,3 +127,32 @@ class HPSearchBackend(ExplicitEnum): HPSearchBackend.OPTUNA: default_hp_space_optuna, HPSearchBackend.RAY: default_hp_space_ray, } + + +def distributed_concat(self, tensor: torch.Tensor, num_total_examples: Optional[int] = None) -> torch.Tensor: + assert self.args.local_rank != -1 + + output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(output_tensors, tensor) + concat = torch.cat(output_tensors, dim=0) + + # truncate the dummy elements added by SequentialDistributedSampler + if num_total_examples is not None: + concat = concat[:num_total_examples] + return concat + + +def distributed_broadcast_scalars( + self, scalars: List[Union[int, float]], num_total_examples: Optional[int] = None +) -> torch.Tensor: + assert self.args.local_rank != -1 + + tensorized_scalar = torch.Tensor(scalars).cuda() + output_tensors = [tensorized_scalar.clone() for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(output_tensors, tensorized_scalar) + concat = torch.cat(output_tensors, dim=0) + + # truncate the dummy elements added by SequentialDistributedSampler + if num_total_examples is not None: + concat = concat[:num_total_examples] + return concat