Floating-point operations logging in trainer (huggingface#6768)
* neFLOs calculation, logging, and reloading (#1)

* testing distributed consecutive batches

* fixed AttributeError from DataParallel

* removed verbosity

* rotate with use_mtime=True

* removed print

* fixed interaction with gradient accumulation

* indent formatting

* distributed neflo counting

* fixed typo

* fixed typo

* mean distributed losses

* exporting log history

* moved a few functions

* floating_point_ops clarification for transformers with parameter-reuse

* code quality

* double import

* made flo estimation more task-agnostic

* only logging flos if computed

* code quality

* unused import

* Update src/transformers/trainer.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Sylvain review

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* black

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
TevenLeScao and sgugger authored Sep 8, 2020
1 parent d155b38 commit 01d340a
Showing 3 changed files with 187 additions and 39 deletions.
88 changes: 71 additions & 17 deletions src/transformers/modeling_utils.py
@@ -17,8 +17,9 @@
import inspect
import os
import re
import warnings
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import torch
from torch import Tensor, device, dtype, nn
@@ -45,7 +46,6 @@

logger = logging.get_logger(__name__)


try:
from torch.nn import Identity
except ImportError:
@@ -91,20 +91,6 @@ class ModuleUtilsMixin:
A few utilities for :obj:`torch.nn.Modules`, to be used as a mixin.
"""

def num_parameters(self, only_trainable: bool = False) -> int:
"""
Get the number of (optionally, trainable) parameters in the model.
Args:
only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to return only the number of trainable parameters
Returns:
:obj:`int`: The number of parameters.
"""
params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
return sum(p.numel() for p in params)

@staticmethod
def _hook_rss_memory_pre_forward(module, *args, **kwargs):
try:
@@ -307,9 +293,77 @@ def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
head_mask = head_mask.to(dtype=self.dtype) # switch to fload if need + fp16 compatibility
head_mask = head_mask.to(dtype=self.dtype) # switch to float if need + fp16 compatibility
return head_mask

def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
"""
Get the number of (optionally, trainable or non-embedding) parameters in the module.
Args:
only_trainable (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to return only the number of trainable parameters
exclude_embeddings (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to return only the number of non-embedding parameters
Returns:
:obj:`int`: The number of parameters.
"""

def parameter_filter(x):
return (x.requires_grad or not only_trainable) and not (
isinstance(x, torch.nn.Embedding) and exclude_embeddings
)

params = filter(parameter_filter, self.parameters()) if only_trainable else self.parameters()
return sum(p.numel() for p in params)

def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int:
"""
Helper function to estimate the total number of tokens from the model inputs.
Args:
input_dict (:obj:`Dict[str, Union[torch.Tensor, Any]]`): The model inputs.
Returns:
:obj:`int`: The total number of tokens.
"""
token_inputs = [tensor for key, tensor in input_dict.items() if "input" in key]
if token_inputs:
return sum([token_input.numel() for token_input in token_inputs])
else:
warnings.warn(
"Could not estimate the number of tokens of the input, floating-point operations will not be computed"
)
return 0

def floating_point_ops(
self, input_dict: Dict[str, Union[torch.Tensor, Any]], exclude_embeddings: bool = True
) -> int:
"""
Get number of (optionally, non-embeddings) floating-point operations for the forward and backward passes of a
batch with this transformer model. Default approximation neglects the quadratic dependency on the number of
tokens (valid if :obj:`12 * d_model << sequence_length`) as laid out in `this paper <https://arxiv.org/pdf/2001.08361.pdf>`__ section
2.1. Should be overridden for transformers with parameter re-use e.g. Albert or Universal Transformers, or
if doing long-range modeling with very high sequence lengths.
Args:
input_dict (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
The model inputs.
exclude_embeddings (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to exclude embedding and softmax operations from the count.
Returns:
:obj:`int`: The number of floating-point operations.
"""

return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)


class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
r"""
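In practice the new estimate boils down to 6 * tokens * parameters: roughly 2 floating-point operations per parameter per token for the forward pass plus 4 for the backward pass, following section 2.1 of the scaling-laws paper linked in the docstring. A minimal usage sketch for illustration only, with a deliberately tiny, made-up config and batch:

import torch
from transformers import BertConfig, BertModel

# Hypothetical toy configuration, purely for illustration.
model = BertModel(BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=2))

# A fake batch of 8 sequences with 64 tokens each, i.e. 512 tokens in total.
inputs = {"input_ids": torch.ones(8, 64, dtype=torch.long)}

tokens = model.estimate_tokens(inputs)                  # sums numel() over every tensor whose key contains "input" -> 512
params = model.num_parameters(exclude_embeddings=True)  # parameter count fed into the estimate
flos = model.floating_point_ops(inputs)

# Default approximation: ~2 forward + ~4 backward FLOs per parameter per token.
assert flos == 6 * tokens * params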
106 changes: 85 additions & 21 deletions src/transformers/trainer.py
@@ -1,4 +1,5 @@
import inspect
import json
import math
import os
import re
@@ -42,6 +43,8 @@
TrainOutput,
default_compute_objective,
default_hp_space,
distributed_broadcast_scalars,
distributed_concat,
set_seed,
)
from .training_args import TrainingArguments
@@ -146,7 +149,7 @@ def __iter__(self):
indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
assert (
len(indices) == self.num_samples
), f"Indices length {len(indices)} and and sample number {self.num_samples} mismatched"
), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"

return iter(indices)

@@ -241,6 +244,7 @@ def __init__(
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
)
self.tb_writer = tb_writer
self.log_history = []
if "prediction_loss_only" in kwargs:
warnings.warn(
"Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.",
@@ -284,6 +288,7 @@ def __init__(

self.global_step = None
self.epoch = None
self.total_flos = None
if self.args.fp16 and _use_native_amp:
self.scaler = torch.cuda.amp.GradScaler()
self.hp_search_backend = None
@@ -461,7 +466,11 @@ def setup_wandb(self):
logger.info(
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
)
combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
try:
combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
except AttributeError:
# in case the model has no config
combined_dict = {**self.args.to_sanitized_dict()}
wandb.init(
project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
)
@@ -663,13 +672,16 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D

self.global_step = 0
self.epoch = 0
self.total_flos = 0
epochs_trained = 0
steps_trained_in_current_epoch = 0
# Check if continuing training from a checkpoint
if model_path is not None:
# set global_step to global_step of last saved checkpoint from model path
try:
self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
self.total_flos = getattr(model.config, "total_flos", 0)

epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
steps_trained_in_current_epoch = self.global_step % (
len(train_dataloader) // self.args.gradient_accumulation_steps
@@ -678,9 +690,11 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(" Continuing training from epoch %d", epochs_trained)
logger.info(" Continuing training from global step %d", self.global_step)
logger.info(" Continuing training from %d non-embedding floating-point operations", self.total_flos)
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
except ValueError:
self.global_step = 0
self.total_flos = 0
logger.info(" Starting fine-tuning.")

tr_loss = torch.tensor(0.0).to(self.args.device)
@@ -714,6 +728,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
continue

tr_loss += self.training_step(model, inputs)
self.total_flos += self.floating_point_ops(inputs)

if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps
@@ -784,7 +799,7 @@ def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", D
self.save_model(output_dir)

if self.is_world_process_zero():
self._rotate_checkpoints()
self._rotate_checkpoints(use_mtime=True)

if is_torch_tpu_available():
xm.rendezvous("saving_optimizer_states")
@@ -924,6 +939,13 @@ def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:

if self.epoch is not None:
logs["epoch"] = self.epoch
if self.total_flos is not None:
if self.args.local_rank != -1:
total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
else:
total_flos = self.total_flos
if total_flos > 0:
logs["total_flos"] = self.total_flos
if self.global_step is None:
# when logging evaluation metrics without training
self.global_step = 0
@@ -951,6 +973,8 @@ def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
if experiment is not None:
experiment._log_metrics(logs, step=self.global_step, epoch=self.epoch, framework="transformers")
output = {**logs, **{"step": self.global_step}}
if self.is_world_process_zero():
self.log_history.append(output)
if iterator is not None:
iterator.write(output)
else:
@@ -1089,13 +1113,17 @@ def _save_tpu(self, output_dir: Optional[str] = None):
if xm.is_master_ordinal():
os.makedirs(output_dir, exist_ok=True)
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
json.dump(
self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
)

# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, PreTrainedModel):
raise ValueError("Trainer.model appears to not be a PreTrainedModel")

xm.rendezvous("saving_checkpoint")
self._store_flos()
self.model.save_pretrained(output_dir)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
@@ -1108,12 +1136,26 @@ def _save(self, output_dir: Optional[str] = None):
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, PreTrainedModel):
raise ValueError("Trainer.model appears to not be a PreTrainedModel")
self._store_flos()
self.model.save_pretrained(output_dir)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)

# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
json.dump(
self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
)

def _store_flos(self):
# Storing the number of floating-point operations that went into the model
if self.total_flos is not None:
if self.args.local_rank != -1:
total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
else:
total_flos = self.total_flos
if total_flos > 0:
self.model.config.total_flos = total_flos

def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
ordering_and_checkpoint_path = []
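Because _save and _save_tpu now write a log_history.json next to training_args.bin, the logged losses and total_flos can be inspected after training without TensorBoard. A short sketch, assuming a hypothetical checkpoint directory produced by the methods above:

import json
import os

# Hypothetical output path; any directory written by Trainer._save works the same way.
checkpoint_dir = "output/checkpoint-500"

with open(os.path.join(checkpoint_dir, "log_history.json")) as f:
    log_history = json.load(f)

# Each entry is a logged dict such as {"loss": ..., "epoch": ..., "total_flos": ..., "step": ...}.
for entry in log_history:
    if "total_flos" in entry:
        print(entry["step"], entry.get("loss"), entry["total_flos"])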
@@ -1245,13 +1287,11 @@ def prediction_loop(
self._past = None

disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm
samples_count = 0
for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
batch_size = inputs[list(inputs.keys())[0]].shape[0]
samples_count += batch_size
if loss is not None:
eval_losses.append(loss * batch_size)
eval_losses.extend([loss] * batch_size)
if logits is not None:
preds = logits if preds is None else torch.cat((preds, logits), dim=0)
if labels is not None:
@@ -1264,9 +1304,9 @@
if self.args.local_rank != -1:
# In distributed mode, concatenate all results from all nodes:
if preds is not None:
preds = self.distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
preds = distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
if label_ids is not None:
label_ids = self.distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
elif is_torch_tpu_available():
# tpu-comment: Get all predictions and labels from all worker shards of eval dataset
if preds is not None:
@@ -1289,7 +1329,14 @@
else:
metrics = {}
if len(eval_losses) > 0:
metrics["eval_loss"] = np.sum(eval_losses) / samples_count
if self.args.local_rank != -1:
metrics["eval_loss"] = (
distributed_broadcast_scalars(eval_losses, num_total_examples=self.num_examples(dataloader))
.mean()
.item()
)
else:
metrics["eval_loss"] = np.mean(eval_losses)

# Prefix all keys with eval_
for key in list(metrics.keys()):
@@ -1298,18 +1345,6 @@

return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

def distributed_concat(self, tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor:
assert self.args.local_rank != -1

output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensor)

concat = torch.cat(output_tensors, dim=0)

# truncate the dummy elements added by SequentialDistributedSampler
output = concat[:num_total_examples]
return output

def prediction_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
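The per-instance distributed_concat removed above now comes from trainer_utils (see the new import near the top of this file). A sketch of the moved helper, assuming it keeps the same all-gather-then-truncate logic as the deleted method; it requires an initialized torch.distributed process group:

import torch

def distributed_concat(tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor:
    # Gather this process's shard from every process in the group.
    output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(output_tensors, tensor)
    concat = torch.cat(output_tensors, dim=0)
    # Drop the dummy tail added by SequentialDistributedSampler so the result
    # matches the original dataset length.
    return concat[:num_total_examples]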
Expand Down Expand Up @@ -1355,3 +1390,32 @@ def prediction_step(
if labels is not None:
labels = labels.detach()
return (loss, logits.detach(), labels)

def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
"""
For models that inherit from :class:`~transformers.PreTrainedModel`, uses the model's own floating_point_ops
method to compute the number of floating-point operations for every backward + forward pass. If using
another model, either implement such a method in the model or subclass and override this method.
Args:
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
Returns:
:obj:`int`: The number of floating-point operations.
"""

if isinstance(self.model, torch.nn.DataParallel) or isinstance(
self.model, torch.nn.parallel.DistributedDataParallel
):
model = self.model.module
else:
model = self.model

if hasattr(model, "floating_point_ops"):
return model.floating_point_ops(inputs)

else:
return 0
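Since the trainer falls back to 0 for models that do not expose a floating_point_ops method, a plain nn.Module can opt into FLO logging simply by defining one. A hedged sketch with a hypothetical model (name and sizes are made up):

import torch
from torch import nn

class TinyRegressor(nn.Module):
    """Hypothetical model that is not a PreTrainedModel but still reports its FLOs."""

    def __init__(self, hidden_size: int = 32):
        super().__init__()
        self.proj = nn.Linear(hidden_size, 1)

    def forward(self, input_values: torch.Tensor) -> torch.Tensor:
        return self.proj(input_values)

    def floating_point_ops(self, inputs: dict) -> int:
        # Same 6 * tokens * parameters heuristic as ModuleUtilsMixin, counted by
        # hand here since this module does not inherit the mixin.
        tokens = sum(tensor.numel() for key, tensor in inputs.items() if "input" in key)
        params = sum(p.numel() for p in self.parameters())
        return 6 * tokens * params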