Add predict step accumulation #7767

Merged · 10 commits · Oct 14, 2020
6 changes: 6 additions & 0 deletions docs/source/internal/trainer_utils.rst
@@ -19,3 +19,9 @@ Callbacks internals
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.trainer_callback.CallbackHandler

Distributed Evaluation
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.trainer_pt_utils.DistributedTensorGatherer
:members:
2 changes: 1 addition & 1 deletion examples/seq2seq/seq2seq_trainer.py
@@ -174,7 +174,7 @@ def prediction_step(
# Call forward again to get loss # TODO: avoidable?
outputs = model(**inputs, use_cache=False)
loss = self._compute_loss(outputs[1], labels_out)
loss = loss.mean().item()
loss = loss.mean().detach()
if self.args.prediction_loss_only:
return (loss, None, None)

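Why ``.item()`` became ``.detach()``: ``.item()`` converts the loss to a Python float, which forces a
device-to-host sync at every step and leaves nothing that can be concatenated or gathered later. A
detached tensor, by contrast, stays on device and can be repeated per sample and gathered across
processes by the new accumulation logic. A minimal sketch of the difference (illustrative only, not
part of the diff):

import torch

loss = torch.randn(8).mean()     # stand-in for a per-batch loss from the model

as_float = loss.item()           # Python float: syncs with the device, can't be torch.cat'ed
as_tensor = loss.detach()        # 0-dim tensor, cut from the autograd graph, stays on device

# The eval loop can now expand the per-batch loss to one entry per sample
# and concatenate across steps before a single gather at the end:
batch_size = 4
losses = as_tensor.repeat(batch_size)    # shape (4,), one copy of the loss per sample
print(losses.shape, as_float)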
11 changes: 11 additions & 0 deletions examples/test_xla_examples.py
@@ -81,3 +81,14 @@ def test_run_glue(self):

# Assert that the script takes less than 300 seconds to make sure it doesn't hang.
self.assertLess(end - start, 300)

def test_trainer_tpu(self):
import xla_spawn

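# Note: the first entry plays the role of sys.argv[0] (ignored by argparse);
# xla_spawn then reads --num_cores and the positional script path after it.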
testargs = """
transformers/tests/test_trainer_tpu.py
--num_cores=8
transformers/tests/test_trainer_tpu.py
""".split()
with patch.object(sys, "argv", testargs):
xla_spawn.main()
100 changes: 58 additions & 42 deletions src/transformers/trainer.py
@@ -59,6 +59,7 @@
TrainerState,
)
from .trainer_pt_utils import (
DistributedTensorGatherer,
SequentialDistributedSampler,
distributed_broadcast_scalars,
distributed_concat,
@@ -1249,18 +1250,29 @@ def prediction_loop(
# multi-gpu eval
if self.args.n_gpu > 1:
model = torch.nn.DataParallel(model)
else:
model = self.model
# Note: in torch.distributed mode, there's no point in wrapping the model
# inside a DistributedDataParallel as we'll be under `no_grad` anyway.

batch_size = dataloader.batch_size
num_examples = self.num_examples(dataloader)
logger.info("***** Running %s *****", description)
logger.info(" Num examples = %d", self.num_examples(dataloader))
logger.info(" Num examples = %d", num_examples)
logger.info(" Batch size = %d", batch_size)
eval_losses: List[float] = []
preds: torch.Tensor = None
label_ids: torch.Tensor = None
losses_host: torch.Tensor = None
preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
labels_host: Union[torch.Tensor, List[torch.Tensor]] = None

world_size = 1
if is_torch_tpu_available():
world_size = xm.xrt_world_size()
elif self.args.local_rank != -1:
world_size = torch.distributed.get_world_size()
world_size = max(1, world_size)

eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
preds_gatherer = DistributedTensorGatherer(world_size, num_examples)
labels_gatherer = DistributedTensorGatherer(world_size, num_examples)

model.eval()

if is_torch_tpu_available():
@@ -1271,55 +1283,46 @@ def prediction_loop(

self.callback_handler.eval_dataloader = dataloader

for inputs in dataloader:
for step, inputs in enumerate(dataloader):
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
batch_size = inputs[list(inputs.keys())[0]].shape[0]
if loss is not None:
eval_losses.extend([loss] * batch_size)
losses = loss.repeat(batch_size)
losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
if logits is not None:
preds = logits if preds is None else nested_concat(preds, logits, dim=0)
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, dim=0)
if labels is not None:
label_ids = labels if label_ids is None else nested_concat(label_ids, labels, dim=0)
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, dim=0)
self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)

# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0:
eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))

# Set back to None to begin a new accumulation
losses_host, preds_host, labels_host = None, None, None

if self.args.past_index and hasattr(self, "_past"):
# Clean the state at the end of the evaluation loop
delattr(self, "_past")

if self.args.local_rank != -1:
# In distributed mode, concatenate all results from all nodes:
if preds is not None:
preds = distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
if label_ids is not None:
label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
elif is_torch_tpu_available():
# tpu-comment: Get all predictions and labels from all worker shards of eval dataset
if preds is not None:
preds = nested_xla_mesh_reduce(preds, "eval_preds")
if label_ids is not None:
label_ids = nested_xla_mesh_reduce(label_ids, "eval_label_ids")
if eval_losses is not None:
eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()

# Finally, turn the aggregated tensors into numpy arrays.
if preds is not None:
preds = nested_numpify(preds)
if label_ids is not None:
label_ids = nested_numpify(label_ids)
# Gather all remaining tensors and put them back on the CPU
eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))

eval_loss = eval_losses_gatherer.finalize()
preds = preds_gatherer.finalize()
label_ids = labels_gatherer.finalize()

if self.compute_metrics is not None and preds is not None and label_ids is not None:
metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
else:
metrics = {}
if len(eval_losses) > 0:
if self.args.local_rank != -1:
metrics["eval_loss"] = (
distributed_broadcast_scalars(eval_losses, num_total_examples=self.num_examples(dataloader))
.mean()
.item()
)
else:
metrics["eval_loss"] = np.mean(eval_losses)

if eval_loss is not None:
metrics["eval_loss"] = eval_loss.mean().item()

# Prefix all keys with eval_
for key in list(metrics.keys()):
@@ -1328,6 +1331,20 @@ def prediction_loop(

return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

def _gather_and_numpify(self, tensors, name):
"""
Gather the value of `tensors` (a tensor or nested list/tuple of tensors) across processes and
convert the result to numpy arrays.
"""
if tensors is None:
return
if is_torch_tpu_available():
tensors = nested_xla_mesh_reduce(tensors, name)
elif self.args.local_rank != -1:
tensors = distributed_concat(tensors)

return nested_numpify(tensors)

def prediction_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
@@ -1357,8 +1374,7 @@ def prediction_step(
with torch.no_grad():
outputs = model(**inputs)
if has_labels:
# The .mean() is to reduce in case of distributed training
loss = outputs[0].mean().item()
loss = outputs[0].mean().detach()
logits = outputs[1:]
else:
loss = None
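The heart of the change is this accumulation pattern: per-batch tensors pile up on the device in the
``*_host`` variables, and every ``eval_accumulation_steps`` batches they are gathered across processes
and offloaded to CPU numpy storage, capping device memory during prediction. A stripped-down,
single-process sketch of the pattern (hypothetical helper, not the actual Trainer API):

import numpy as np
import torch

def eval_loop(batches, accumulation_steps):
    """Accumulate per-batch logits on device, offload to CPU every few steps."""
    preds_host = None   # device-side accumulator
    all_preds = []      # CPU-side storage (numpy chunks)
    for step, logits in enumerate(batches):
        preds_host = logits if preds_host is None else torch.cat((preds_host, logits), dim=0)
        if (step + 1) % accumulation_steps == 0:
            all_preds.append(preds_host.cpu().numpy())   # offload and free device memory
            preds_host = None
    if preds_host is not None:   # flush the last partial accumulation
        all_preds.append(preds_host.cpu().numpy())
    return np.concatenate(all_preds, axis=0)

# Example: 5 batches of shape (2, 3), offloaded every 2 steps -> (10, 3) predictions.
batches = [torch.randn(2, 3) for _ in range(5)]
print(eval_loop(batches, accumulation_steps=2).shape)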
120 changes: 119 additions & 1 deletion src/transformers/trainer_pt_utils.py
@@ -21,18 +21,22 @@
from contextlib import contextmanager
from typing import List, Optional, Union

import numpy as np
import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, Sampler

from .file_utils import is_torch_tpu_available
from .utils import logging


if is_torch_tpu_available():
import torch_xla.core.xla_model as xm

PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimizer when saving or loading the scheduler."

logger = logging.get_logger(__name__)


def nested_concat(tensors, new_tensors, dim=0):
"Concat the `new_tensors` to `tensors` on `dim`. Works for tensors or nested list/tuples of tensors."
@@ -41,7 +45,12 @@ def nested_concat(tensors, new_tensors, dim=0):
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_concat(t, n, dim) for t, n in zip(tensors, new_tensors))
return torch.cat((tensors, new_tensors), dim=dim)
elif isinstance(tensors, torch.Tensor):
return torch.cat((tensors, new_tensors), dim=dim)
elif isinstance(tensors, np.ndarray):
return np.concatenate((tensors, new_tensors), axis=dim)
else:
raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}")


def nested_numpify(tensors):
@@ -177,3 +186,112 @@ def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset):
if xm.xrt_world_size() <= 1:
return RandomSampler(dataset)
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())


def nested_new_like(arrays, num_samples):
""" Create the same nested structure as `arrays` with a first dimension always at `num_samples`."""
if isinstance(arrays, (list, tuple)):
return type(arrays)(nested_new_like(x, num_samples) for x in arrays)
return np.zeros((num_samples, *arrays.shape[1:]), dtype=arrays.dtype)


def nested_truncate(tensors, limit):
"Truncate `tensors` at `limit` (even if it's a nested list/tuple of tensors)."
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_truncate(t, limit) for t in tensors)
return tensors[:limit]


class DistributedTensorGatherer:
"""
A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU
in chunks.

If our dataset has 16 samples with a batch size of 2 on 3 processes, and we gather and then move to
CPU at every step, our sampler will generate the following indices:

:obj:`[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1]`

to get a total size that is a multiple of 3 (so that each process gets a dataset of the same length). Then
processes 0, 1, and 2 will be responsible for making predictions for the following samples:

- P0: :obj:`[0, 1, 2, 3, 4, 5]`
- P1: :obj:`[6, 7, 8, 9, 10, 11]`
- P2: :obj:`[12, 13, 14, 15, 0, 1]`

The first batch handled by each process will then be

- P0: :obj:`[0, 1]`
- P1: :obj:`[6, 7]`
- P2: :obj:`[12, 13]`

So if we gather at the end of the first batch, we will get a tensor (or nested list/tuple of tensors)
corresponding to the following indices:

:obj:`[0, 1, 6, 7, 12, 13]`

If we directly concatenate our results without taking any precautions, the user will then get
the predictions for the indices in this order at the end of the prediction loop:

:obj:`[0, 1, 6, 7, 12, 13, 2, 3, 8, 9, 14, 15, 4, 5, 10, 11, 0, 1]`

Understandably, that's not going to float their boat. This class is here to solve that problem.

Args:

world_size (:obj:`int`):
The number of processes used in the distributed training.
num_samples (:obj:`int`):
The number of samples in our dataset.
make_multiple_of (:obj:`int`, `optional`):
If passed, the class assumes the dataset length on each process has been made a multiple of this
argument (by adding extra samples).
"""

def __init__(self, world_size, num_samples, make_multiple_of=None):
self.world_size = world_size
self.num_samples = num_samples
total_size = world_size if make_multiple_of is None else world_size * make_multiple_of
self.total_samples = int(np.ceil(num_samples / total_size)) * total_size
self.process_length = self.total_samples // world_size
self._storage = None
self._offsets = None

def add_arrays(self, arrays):
"""
Add :obj:`arrays` to the internal storage. The storage will be initialized to the full size at the
first arrays passed, so that if we're bound to get an OOM, it happens at the beginning.
"""
if arrays is None:
return
if self._storage is None:
self._storage = nested_new_like(arrays, self.total_samples)
self._offsets = list(range(0, self.total_samples, self.process_length))
slice_len = self._nested_set_tensors(self._storage, arrays)
for i in range(self.world_size):
self._offsets[i] += slice_len

def _nested_set_tensors(self, storage, arrays):
if isinstance(arrays, (list, tuple)):
for x, y in zip(storage, arrays):
slice_len = self._nested_set_tensors(x, y)
return slice_len
assert (
arrays.shape[0] % self.world_size == 0
), f"Arrays passed should all have a first dimension multiple of {self.world_size}, found {arrays.shape[0]}."

slice_len = arrays.shape[0] // self.world_size
for i in range(self.world_size):
storage[self._offsets[i] : self._offsets[i] + slice_len] = arrays[i * slice_len : (i + 1) * slice_len]
return slice_len

def finalize(self):
"""
Return the properly gathered arrays and truncate to the number of samples (since the sampler added some extras
to give each process a dataset of the same length).
"""
if self._storage is None:
return
if self._offsets[0] != self.process_length:
logger.warn("Not all data has been set. Are you sure you passed all values?")
return nested_truncate(self._storage, self.num_samples)
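
To see the gatherer outside of an actual distributed run, one can simulate the three-process example
from the docstring above: each ``add_arrays`` call receives the interleaved result of one gather (one
slice per process), and ``finalize`` restores sample order and drops the padding. A single-process
simulation (assuming the module path ``transformers.trainer_pt_utils`` added in this PR):

import numpy as np
from transformers.trainer_pt_utils import DistributedTensorGatherer

# 16 samples, 3 processes, batch size 2: each process predicts 6 samples (P2 wraps around to 0, 1).
predictions = np.arange(18) % 16              # [0, 1, ..., 15, 0, 1]
per_process = predictions.reshape(3, 6)       # P0: 0-5, P1: 6-11, P2: 12-15, 0, 1

gatherer = DistributedTensorGatherer(world_size=3, num_samples=16, make_multiple_of=2)
for start in range(0, 6, 2):
    # What a real gather yields at each step: the current 2-sample slice from every process.
    gathered = np.concatenate([per_process[p, start:start + 2] for p in range(3)])
    gatherer.add_arrays(gathered)

print(gatherer.finalize())   # [0 1 2 ... 15] -- original order, padding truncated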
10 changes: 9 additions & 1 deletion src/transformers/training_args.py
@@ -67,14 +67,18 @@ class TrainingArguments:
The batch size per GPU/TPU core/CPU for training.
per_device_eval_batch_size (:obj:`int`, `optional`, defaults to 8):
The batch size per GPU/TPU core/CPU for evaluation.
gradient_accumulation_steps: (:obj:`int`, `optional`, defaults to 1):
gradient_accumulation_steps (:obj:`int`, `optional`, defaults to 1):
Number of update steps to accumulate the gradients for, before performing a backward/update pass.

.. warning::

When using gradient accumulation, one step is counted as one step with a backward pass. Therefore,
logging, evaluation, and saving will be conducted every ``gradient_accumulation_steps * xxx_step``
training examples.
eval_accumulation_steps (:obj:`int`, `optional`):
Number of prediction steps to accumulate the output tensors for, before moving the results to the CPU. If
left unset, all predictions are accumulated on the GPU/TPU before being moved to the CPU (faster but
requires more memory).
learning_rate (:obj:`float`, `optional`, defaults to 5e-5):
The initial learning rate for Adam.
weight_decay (:obj:`float`, `optional`, defaults to 0):
@@ -225,6 +229,10 @@ class TrainingArguments:
default=1,
metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."},
)
eval_accumulation_steps: Optional[int] = field(
default=None,
metadata={"help": "Number of predictions steps to accumulate before moving the tensors to the CPU."},
)

learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for Adam."})
weight_decay: float = field(default=0.0, metadata={"help": "Weight decay if we apply some."})
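With the new field in place, opting into accumulation at evaluation time is a single argument. A usage
sketch (``output_dir`` and the commented-out names are placeholders):

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="out",
    per_device_eval_batch_size=8,
    eval_accumulation_steps=4,   # move eval tensors to CPU every 4 prediction steps
)
# trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset)
# metrics = trainer.evaluate()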
15 changes: 15 additions & 0 deletions tests/test_trainer.py
@@ -1,3 +1,18 @@
# coding=utf-8
# Copyright 2018 the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import os
import tempfile