diff --git a/docs/source/internal/trainer_utils.rst b/docs/source/internal/trainer_utils.rst
index 48e8568b9530be..97bf5d1c8eaef1 100644
--- a/docs/source/internal/trainer_utils.rst
+++ b/docs/source/internal/trainer_utils.rst
@@ -19,3 +19,9 @@ Callbacks internals
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 .. autoclass:: transformers.trainer_callback.CallbackHandler
+
+Distributed Evaluation
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.trainer_pt_utils.DistributedTensorGatherer
+    :members:
\ No newline at end of file
diff --git a/examples/seq2seq/seq2seq_trainer.py b/examples/seq2seq/seq2seq_trainer.py
index c484499d26e467..bf002ee2b654df 100644
--- a/examples/seq2seq/seq2seq_trainer.py
+++ b/examples/seq2seq/seq2seq_trainer.py
@@ -174,7 +174,7 @@ def prediction_step(
             # Call forward again to get loss # TODO: avoidable?
             outputs = model(**inputs, use_cache=False)
             loss = self._compute_loss(outputs[1], labels_out)
-            loss = loss.mean().item()
+            loss = loss.mean().detach()

         if self.args.prediction_loss_only:
             return (loss, None, None)
diff --git a/examples/test_xla_examples.py b/examples/test_xla_examples.py
index 8e3aad7b988d8b..444884ddd86e26 100644
--- a/examples/test_xla_examples.py
+++ b/examples/test_xla_examples.py
@@ -81,3 +81,16 @@ def test_run_glue(self):

         # Assert that the script takes less than 300 seconds to make sure it doesn't hang.
         self.assertLess(end - start, 300)
+
+    def test_trainer_tpu(self):
+        import xla_spawn
+
+        # The first path below only fills the sys.argv[0] slot (the program
+        # name); xla_spawn parses the arguments that follow it.
+        testargs = """
+            transformers/tests/test_trainer_tpu.py
+            --num_cores=8
+            transformers/tests/test_trainer_tpu.py
+        """.split()
+        with patch.object(sys, "argv", testargs):
+            xla_spawn.main()
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 3656ee1bc3c063..620454f0eb02d3 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -59,6 +59,7 @@
     TrainerState,
 )
 from .trainer_pt_utils import (
+    DistributedTensorGatherer,
     SequentialDistributedSampler,
     distributed_broadcast_scalars,
     distributed_concat,
@@ -1249,18 +1250,29 @@ def prediction_loop(
         # multi-gpu eval
         if self.args.n_gpu > 1:
             model = torch.nn.DataParallel(model)
-        else:
-            model = self.model
         # Note: in torch.distributed mode, there's no point in wrapping the model
         # inside a DistributedDataParallel as we'll be under `no_grad` anyways.
batch_size = dataloader.batch_size + num_examples = self.num_examples(dataloader) logger.info("***** Running %s *****", description) - logger.info(" Num examples = %d", self.num_examples(dataloader)) + logger.info(" Num examples = %d", num_examples) logger.info(" Batch size = %d", batch_size) - eval_losses: List[float] = [] - preds: torch.Tensor = None - label_ids: torch.Tensor = None + losses_host: torch.Tensor = None + preds_host: Union[torch.Tensor, List[torch.Tensor]] = None + labels_host: Union[torch.Tensor, List[torch.Tensor]] = None + + world_size = 1 + if is_torch_tpu_available(): + world_size = xm.xrt_world_size() + elif self.args.local_rank != -1: + world_size = torch.distributed.get_world_size() + world_size = max(1, world_size) + + eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size) + preds_gatherer = DistributedTensorGatherer(world_size, num_examples) + labels_gatherer = DistributedTensorGatherer(world_size, num_examples) + model.eval() if is_torch_tpu_available(): @@ -1271,55 +1283,46 @@ def prediction_loop( self.callback_handler.eval_dataloader = dataloader - for inputs in dataloader: + for step, inputs in enumerate(dataloader): loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only) - batch_size = inputs[list(inputs.keys())[0]].shape[0] if loss is not None: - eval_losses.extend([loss] * batch_size) + losses = loss.repeat(batch_size) + losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) if logits is not None: - preds = logits if preds is None else nested_concat(preds, logits, dim=0) + preds_host = logits if preds_host is None else nested_concat(preds_host, logits, dim=0) if labels is not None: - label_ids = labels if label_ids is None else nested_concat(label_ids, labels, dim=0) + labels_host = labels if labels_host is None else nested_concat(labels_host, labels, dim=0) self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control) + # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. + if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0: + eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) + preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) + labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) + + # Set back to None to begin a new accumulation + losses_host, preds_host, labels_host = None, None, None + if self.args.past_index and hasattr(self, "_past"): # Clean the state at the end of the evaluation loop delattr(self, "_past") - if self.args.local_rank != -1: - # In distributed mode, concatenate all results from all nodes: - if preds is not None: - preds = distributed_concat(preds, num_total_examples=self.num_examples(dataloader)) - if label_ids is not None: - label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader)) - elif is_torch_tpu_available(): - # tpu-comment: Get all predictions and labels from all worker shards of eval dataset - if preds is not None: - preds = nested_xla_mesh_reduce(preds, "eval_preds") - if label_ids is not None: - label_ids = nested_xla_mesh_reduce(label_ids, "eval_label_ids") - if eval_losses is not None: - eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist() - - # Finally, turn the aggregated tensors into numpy arrays. 
-        if preds is not None:
-            preds = nested_numpify(preds)
-        if label_ids is not None:
-            label_ids = nested_numpify(label_ids)
+        # Gather all remaining tensors and put them back on the CPU
+        eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
+        preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
+        labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
+
+        eval_loss = eval_losses_gatherer.finalize()
+        preds = preds_gatherer.finalize()
+        label_ids = labels_gatherer.finalize()

         if self.compute_metrics is not None and preds is not None and label_ids is not None:
             metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
         else:
             metrics = {}
-        if len(eval_losses) > 0:
-            if self.args.local_rank != -1:
-                metrics["eval_loss"] = (
-                    distributed_broadcast_scalars(eval_losses, num_total_examples=self.num_examples(dataloader))
-                    .mean()
-                    .item()
-                )
-            else:
-                metrics["eval_loss"] = np.mean(eval_losses)
+
+        if eval_loss is not None:
+            metrics["eval_loss"] = eval_loss.mean().item()

         # Prefix all keys with eval_
         for key in list(metrics.keys()):
@@ -1328,6 +1331,20 @@ def prediction_loop(

         return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

+    def _gather_and_numpify(self, tensors, name):
+        """
+        Gather the values in :obj:`tensors` (tensor or nested list/tuple of tensors) across processes and
+        convert them to numpy arrays.
+        """
+        if tensors is None:
+            return
+        if is_torch_tpu_available():
+            tensors = nested_xla_mesh_reduce(tensors, name)
+        elif self.args.local_rank != -1:
+            tensors = distributed_concat(tensors)
+
+        return nested_numpify(tensors)
+
     def prediction_step(
         self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
     ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
@@ -1357,8 +1374,7 @@ def prediction_step(
         with torch.no_grad():
             outputs = model(**inputs)
             if has_labels:
-                # The .mean() is to reduce in case of distributed training
-                loss = outputs[0].mean().item()
+                loss = outputs[0].mean().detach()
                 logits = outputs[1:]
             else:
                 loss = None
diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index 74a93f8286b881..45ff9c8fdf0701 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -21,11 +21,13 @@
 from contextlib import contextmanager
 from typing import List, Optional, Union

+import numpy as np
 import torch
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import RandomSampler, Sampler

 from .file_utils import is_torch_tpu_available
+from .utils import logging


 if is_torch_tpu_available():
@@ -33,6 +35,8 @@
 PT_LR_SCHEDULER_WARNING = "Please also save or load the state of the optimzer when saving or loading the scheduler."

+logger = logging.get_logger(__name__)
+

 def nested_concat(tensors, new_tensors, dim=0):
     "Concat the `new_tensors` to `tensors` on `dim`. Works for tensors or nested list/tuples of tensors."
@@ -41,7 +45,12 @@ def nested_concat(tensors, new_tensors, dim=0):
     ), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
     if isinstance(tensors, (list, tuple)):
         return type(tensors)(nested_concat(t, n, dim) for t, n in zip(tensors, new_tensors))
-    return torch.cat((tensors, new_tensors), dim=dim)
+    elif isinstance(tensors, torch.Tensor):
+        return torch.cat((tensors, new_tensors), dim=dim)
+    elif isinstance(tensors, np.ndarray):
+        return np.concatenate((tensors, new_tensors), axis=dim)
+    else:
+        raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}")


 def nested_numpify(tensors):
@@ -177,3 +186,112 @@ def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset):
     if xm.xrt_world_size() <= 1:
         return RandomSampler(dataset)
     return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
+
+
+def nested_new_like(arrays, num_samples):
+    """ Create the same nested structure as `arrays` with a first dimension always at `num_samples`."""
+    if isinstance(arrays, (list, tuple)):
+        return type(arrays)(nested_new_like(x, num_samples) for x in arrays)
+    return np.zeros((num_samples, *arrays.shape[1:]), dtype=arrays.dtype)
+
+
+def nested_truncate(tensors, limit):
+    "Truncate `tensors` at `limit` (even if it's a nested list/tuple of tensors)."
+    if isinstance(tensors, (list, tuple)):
+        return type(tensors)(nested_truncate(t, limit) for t in tensors)
+    return tensors[:limit]
+
+
+class DistributedTensorGatherer:
+    """
+    A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU
+    in chunks.
+
+    If our dataset has 16 samples with a batch size of 2 on 3 processes, and we gather and then transfer
+    to the CPU at every step, our sampler will generate the following indices:
+
+    :obj:`[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1]`
+
+    to make the total size a multiple of 3 (so that each process gets a dataset of the same length).
+    Processes 0, 1 and 2 will then be responsible for making predictions on the following samples:
+
+    - P0: :obj:`[0, 1, 2, 3, 4, 5]`
+    - P1: :obj:`[6, 7, 8, 9, 10, 11]`
+    - P2: :obj:`[12, 13, 14, 15, 0, 1]`
+
+    The first batch processed on each process will then be:
+
+    - P0: :obj:`[0, 1]`
+    - P1: :obj:`[6, 7]`
+    - P2: :obj:`[12, 13]`
+
+    So if we gather at the end of the first batch, we will get a tensor (nested list/tuple of tensors)
+    corresponding to the following indices:
+
+    :obj:`[0, 1, 6, 7, 12, 13]`
+
+    If we directly concatenated our results without taking any precautions, the user would end up
+    with the predictions for the indices in this order at the end of the prediction loop:
+
+    :obj:`[0, 1, 6, 7, 12, 13, 2, 3, 8, 9, 14, 15, 4, 5, 10, 11, 0, 1]`
+
+    which is obviously not the order they expect. This class is there to solve that problem.
+
+    Args:
+
+        world_size (:obj:`int`):
+            The number of processes used in the distributed training.
+        num_samples (:obj:`int`):
+            The number of samples in our dataset.
+        make_multiple_of (:obj:`int`, `optional`):
+            If passed, the class assumes the datasets passed to each process are padded to be a multiple of
+            this argument (by adding extra samples).
+    """
+
+    def __init__(self, world_size, num_samples, make_multiple_of=None):
+        self.world_size = world_size
+        self.num_samples = num_samples
+        total_size = world_size if make_multiple_of is None else world_size * make_multiple_of
+        self.total_samples = int(np.ceil(num_samples / total_size)) * total_size
+        self.process_length = self.total_samples // world_size
+        self._storage = None
+        self._offsets = None
+
+    def add_arrays(self, arrays):
+        """
+        Add :obj:`arrays` to the internal storage. Will initialize the storage to the full size when the first
+        arrays are passed, so that if we're bound to get an OOM, it happens at the beginning.
+        """
+        if arrays is None:
+            return
+        if self._storage is None:
+            self._storage = nested_new_like(arrays, self.total_samples)
+            self._offsets = list(range(0, self.total_samples, self.process_length))
+        slice_len = self._nested_set_tensors(self._storage, arrays)
+        for i in range(self.world_size):
+            self._offsets[i] += slice_len
+
+    def _nested_set_tensors(self, storage, arrays):
+        if isinstance(arrays, (list, tuple)):
+            for x, y in zip(storage, arrays):
+                slice_len = self._nested_set_tensors(x, y)
+            return slice_len
+        assert (
+            arrays.shape[0] % self.world_size == 0
+        ), f"Arrays passed should all have a first dimension multiple of {self.world_size}, found {arrays.shape[0]}."
+
+        slice_len = arrays.shape[0] // self.world_size
+        for i in range(self.world_size):
+            storage[self._offsets[i] : self._offsets[i] + slice_len] = arrays[i * slice_len : (i + 1) * slice_len]
+        return slice_len
+
+    def finalize(self):
+        """
+        Return the properly gathered arrays and truncate to the number of samples (since the sampler added some extras
+        to get each process a dataset of the same length).
+        """
+        if self._storage is None:
+            return
+        if self._offsets[0] != self.process_length:
+            logger.warning("Not all data has been set. Are you sure you passed all values?")
+        return nested_truncate(self._storage, self.num_samples)
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 04a9b084346fc3..6cfdc15f079299 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -67,7 +67,7 @@ class TrainingArguments:
             The batch size per GPU/TPU core/CPU for training.
         per_device_eval_batch_size (:obj:`int`, `optional`, defaults to 8):
             The batch size per GPU/TPU core/CPU for evaluation.
-        gradient_accumulation_steps: (:obj:`int`, `optional`, defaults to 1):
+        gradient_accumulation_steps (:obj:`int`, `optional`, defaults to 1):
             Number of updates steps to accumulate the gradients for, before performing a backward/update pass.

             .. warning::

                 When using gradient accumulation, one step is counted as one step with backward pass. Therefore,
                 logging, evaluation, save will be conducted every ``gradient_accumulation_steps * xxx_step`` training
                 examples.
+        eval_accumulation_steps (:obj:`int`, `optional`):
+            Number of prediction steps to accumulate the output tensors for, before moving the results to the CPU. If
+            left unset, all predictions are accumulated on the GPU/TPU before being moved to the CPU (faster but
+            requires more memory).
         learning_rate (:obj:`float`, `optional`, defaults to 5e-5):
             The initial learning rate for Adam.
     weight_decay (:obj:`float`, `optional`, defaults to 0):
@@ -225,6 +229,10 @@ class TrainingArguments:
         default=1,
         metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."},
     )
+    eval_accumulation_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "Number of prediction steps to accumulate before moving the tensors to the CPU."},
+    )
     learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for Adam."})
     weight_decay: float = field(default=0.0, metadata={"help": "Weight decay if we apply some."})
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index eb23e7eb1afec0..fe24702c6e14f6 100755
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -1,3 +1,18 @@
+# coding=utf-8
+# Copyright 2018 the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import dataclasses
 import os
 import tempfile
diff --git a/tests/test_trainer_distributed.py b/tests/test_trainer_distributed.py
index cdc88f9d5765c4..d9b9f7418643a1 100644
--- a/tests/test_trainer_distributed.py
+++ b/tests/test_trainer_distributed.py
@@ -13,15 +13,14 @@
 #    CUDA_VISIBLE_DEVICES=-1 python ./tests/test_trainer_distributed.py
 #

-
-import logging
 import sys
 from typing import Dict

 from transformers import EvalPrediction, HfArgumentParser, TrainingArguments, is_torch_available
+from transformers.utils import logging

-logger = logging.getLogger(__name__)
+logger = logging.get_logger(__name__)


 if is_torch_available():
@@ -101,4 +100,20 @@ def compute_metrics(p: EvalPrediction) -> Dict:
         logger.error(p.metrics)
         exit(1)

+    trainer.args.eval_accumulation_steps = 2
+
+    metrics = trainer.evaluate()
+    logger.info(metrics)
+    if metrics["eval_success"] is not True:
+        logger.error(metrics)
+        exit(1)
+
+    p = trainer.predict(dataset)
+    logger.info(p.metrics)
+    if p.metrics["eval_success"] is not True:
+        logger.error(p.metrics)
+        exit(1)
+
+    trainer.args.eval_accumulation_steps = None
+
     logger.info("🔥 All distributed tests successful")
diff --git a/tests/test_trainer_tpu.py b/tests/test_trainer_tpu.py
new file mode 100644
index 00000000000000..6a522fc4480a4e
--- /dev/null
+++ b/tests/test_trainer_tpu.py
@@ -0,0 +1,119 @@
+# This test is meant to be run on an instance with TPUs like this:
+#
+#   python examples/xla_spawn.py --num_cores=8 tests/test_trainer_tpu.py
+#
+# Replace 8 with the number of TPU cores you have.
+# + +import sys +from typing import Dict + +from transformers import EvalPrediction, HfArgumentParser, TrainingArguments, is_torch_available +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +if is_torch_available(): + import torch + from torch import nn + from torch.utils.data.dataset import Dataset + + from transformers import Trainer + + class DummyDataset(Dataset): + def __init__(self, length: int = 101): + self.length = length + + def __len__(self): + return self.length + + def __getitem__(self, i) -> int: + return i + + class DummyDataCollator: + def __call__(self, features): + return {"input_ids": torch.tensor(features), "labels": torch.tensor(features)} + + class DummyModel(nn.Module): + def __init__(self): + super().__init__() + # Add some (unused) params otherwise DDP will complain. + self.fc = nn.Linear(120, 80) + + def forward(self, input_ids, labels=None): + if labels is not None: + return torch.tensor(0.0, device=input_ids.device), input_ids + else: + return input_ids + + +def main(): + parser = HfArgumentParser((TrainingArguments,)) + sys.argv += ["--output_dir", "./examples"] + training_args = parser.parse_args_into_dataclasses()[0] + + logger.warning( + "Process rank: %s, device: %s, tpu_num_cores: %s", + training_args.local_rank, + training_args.device, + training_args.tpu_num_cores, + ) + + # Essentially, what we want to verify in the distributed case is + # that we get all samples back, in the right order. + # (this is crucial for prediction for instance) + for dataset_length in [1001, 256, 15]: + dataset = DummyDataset(dataset_length) + + def compute_metrics(p: EvalPrediction) -> Dict: + sequential = list(range(len(dataset))) + success = p.predictions.tolist() == sequential and p.label_ids.tolist() == sequential + return {"success": success} + + trainer = Trainer( + model=DummyModel(), + args=training_args, + data_collator=DummyDataCollator(), + eval_dataset=dataset, + compute_metrics=compute_metrics, + ) + metrics = trainer.evaluate() + logger.info(metrics) + if metrics["eval_success"] is not True: + logger.error(metrics) + exit(1) + + p = trainer.predict(dataset) + logger.info(p.metrics) + if p.metrics["eval_success"] is not True: + logger.error(p.metrics) + exit(1) + + trainer.args.eval_accumulation_steps = 2 + + metrics = trainer.evaluate() + logger.info(metrics) + if metrics["eval_success"] is not True: + logger.error(metrics) + exit(1) + + p = trainer.predict(dataset) + logger.info(p.metrics) + if p.metrics["eval_success"] is not True: + logger.error(p.metrics) + exit(1) + + trainer.args.eval_accumulation_steps = None + + logger.info("🔥 All distributed tests successful") + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/tests/test_trainer_utils.py b/tests/test_trainer_utils.py new file mode 100644 index 00000000000000..91fe33fa478ddf --- /dev/null +++ b/tests/test_trainer_utils.py @@ -0,0 +1,58 @@ +# coding=utf-8 +# Copyright 2018 the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+from transformers.file_utils import is_torch_available
+from transformers.testing_utils import require_torch
+
+
+if is_torch_available():
+    from transformers.trainer_pt_utils import DistributedTensorGatherer
+
+
+@require_torch
+class TrainerUtilsTest(unittest.TestCase):
+    def test_distributed_tensor_gatherer(self):
+        # Simulate a result with a dataset of size 21, 4 processes and chunks of lengths 2, 3, 1
+        world_size = 4
+        num_samples = 21
+        input_indices = [
+            [0, 1, 6, 7, 12, 13, 18, 19],
+            [2, 3, 4, 8, 9, 10, 14, 15, 16, 20, 0, 1],
+            [5, 11, 17, 2],
+        ]
+
+        predictions = np.random.normal(size=(num_samples, 13))
+        gatherer = DistributedTensorGatherer(world_size=world_size, num_samples=num_samples)
+        for indices in input_indices:
+            gatherer.add_arrays(predictions[indices])
+        result = gatherer.finalize()
+        self.assertTrue(np.array_equal(result, predictions))
+
+        # With nested tensors
+        gatherer = DistributedTensorGatherer(world_size=world_size, num_samples=num_samples)
+        for indices in input_indices:
+            gatherer.add_arrays([predictions[indices], [predictions[indices], predictions[indices]]])
+        result = gatherer.finalize()
+        self.assertTrue(isinstance(result, list))
+        self.assertEqual(len(result), 2)
+        self.assertTrue(isinstance(result[1], list))
+        self.assertEqual(len(result[1]), 2)
+        self.assertTrue(np.array_equal(result[0], predictions))
+        self.assertTrue(np.array_equal(result[1][0], predictions))
+        self.assertTrue(np.array_equal(result[1][1], predictions))
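
Usage sketch (illustrative, not part of the patch): the snippet below replays the 16-sample, 3-process, batch-size-2 walk-through from the ``DistributedTensorGatherer`` docstring, feeding the gatherer the chunk each gather step would produce:

    import numpy as np

    from transformers.trainer_pt_utils import DistributedTensorGatherer

    predictions = np.arange(16)

    # Indices gathered across the 3 processes at each of the 3 eval steps,
    # as derived in the docstring walk-through.
    step_indices = [
        [0, 1, 6, 7, 12, 13],
        [2, 3, 8, 9, 14, 15],
        [4, 5, 10, 11, 0, 1],
    ]

    gatherer = DistributedTensorGatherer(world_size=3, num_samples=16)
    for indices in step_indices:
        gatherer.add_arrays(predictions[indices])

    # finalize() re-interleaves the per-process slices and truncates the two
    # extra samples (the repeated 0 and 1) added to even out the shards.
    result = gatherer.finalize()
    assert result.tolist() == list(range(16))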
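
Similarly, a sketch of the new ``eval_accumulation_steps`` argument in use (the output directory and values are illustrative):

    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="output",            # illustrative path
        per_device_eval_batch_size=8,
        # Move the accumulated prediction tensors to the CPU every 20 steps;
        # the default (None) keeps everything on the GPU/TPU until the end.
        eval_accumulation_steps=20,
    )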