From be255de3063547df3cedbb0b98da1729a3b615f8 Mon Sep 17 00:00:00 2001
From: chaton
Date: Sat, 9 Jan 2021 13:37:44 +0100
Subject: [PATCH] Bugfix/all gather (#5221)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* resolve bug
* add tests
* add tests
* resolve flake8
* update
* update
* remove globals
* typo
* Update pytorch_lightning/utilities/distributed.py

Co-authored-by: Jirka Borovec

* update
* update
* add suport int, float
* update
* resolve pep8
* Update pytorch_lightning/core/lightning.py

Co-authored-by: Adrian Wälchli

* Update tests/utilities/test_all_gather_grad.py

Co-authored-by: Adrian Wälchli

* update doc
* add bool and np.ndarray
* resolve conflicts
* resolve conflicts
* resolve pep8
* add changelog
* Update pytorch_lightning/core/lightning.py

Co-authored-by: Adrian Wälchli
Co-authored-by: Ubuntu
Co-authored-by: Jirka Borovec
Co-authored-by: Adrian Wälchli
---
 CHANGELOG.md                               |  3 ++
 pytorch_lightning/core/lightning.py        | 33 +++++++++----
 pytorch_lightning/utilities/apply_func.py  | 48 +++++++++++++++----
 pytorch_lightning/utilities/distributed.py |  2 +-
 tests/special_tests.sh                     |  2 +
 tests/utilities/test_all_gather_grad.py    | 56 +++++++++++++++++++++-
 6 files changed, 125 insertions(+), 19 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9c6a2c0ab289e..3983e3416e3a7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -32,6 +32,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `BackboneLambdaFinetuningCallback` ([#5377](https://github.com/PyTorchLightning/pytorch-lightning/pull/5377))
 
+- Accelerator `all_gather` supports collection ([#5221](https://github.com/PyTorchLightning/pytorch-lightning/pull/5221))
+
+
 - Added `image_gradients` functional metric to compute the image gradients of a given input image.
   ([#5056](https://github.com/PyTorchLightning/pytorch-lightning/pull/5056))
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 2f3e85955d82b..c2ec67819912e 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -14,15 +14,16 @@
 """nn.Module with additional great features."""
 
+from abc import ABC
+from argparse import Namespace
 import collections
 import copy
+from functools import partial
 import inspect
 import os
+from pathlib import Path
 import re
 import tempfile
-from abc import ABC
-from argparse import Namespace
-from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -35,10 +36,12 @@
 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
 from pytorch_lightning.core.memory import ModelSummary
 from pytorch_lightning.core.optimizer import LightningOptimizer
-from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
+from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
 from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
+from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args
@@ -364,7 +367,12 @@ def __auto_choose_log_on_epoch(self, on_epoch):
 
         return on_epoch
 
-    def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
+    def all_gather(
+        self,
+        data: Union[torch.Tensor, Dict, List, Tuple],
+        group: Optional[Any] = None,
+        sync_grads: bool = False,
+    ):
         r"""
         Allows users to call ``self.all_gather()`` from the LightningModule, thus making
         the ```all_gather``` operation accelerator agnostic.
@@ -373,14 +381,23 @@ def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, s
         distributed processes
 
         Args:
-            tensor: tensor of shape (batch, ...)
+            data: int, float, tensor of shape (batch, ...), or a (possibly nested) collection thereof.
             group: the process group to gather results from. Defaults to all processes (world)
             sync_grads: flag that allows users to synchronize gradients for all_gather op
 
         Return:
-            A tensor of shape (world_size, batch, ...)
+            A tensor of shape (world_size, batch, ...), or if the input was a collection
+            the output will also be a collection with tensors of this shape.
""" - return self.trainer.accelerator_backend.all_gather(tensor, group=group, sync_grads=sync_grads) + group = group if group is not None else torch.distributed.group.WORLD + if self.trainer.accelerator_backend is not None: + all_gather = self.trainer.accelerator_backend.all_gather + else: + all_gather = all_gather_ddp_if_available + + data = convert_to_tensors(data, device=self.device) + all_gather = partial(all_gather, group=group, sync_grads=sync_grads) + return apply_to_collection(data, torch.Tensor, all_gather) def forward(self, *args, **kwargs): r""" diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 95edb16c27b00..fd610ebcb0c8d 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -15,10 +15,13 @@ from abc import ABC from collections.abc import Mapping, Sequence from copy import copy -from typing import Any, Callable, Union, Optional +from functools import partial +from typing import Any, Callable, Optional, Union +import numpy as np import torch +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _TORCHTEXT_AVAILABLE if _TORCHTEXT_AVAILABLE: @@ -27,11 +30,35 @@ Batch = type(None) +def to_dtype_tensor(value, dtype:torch.dtype = None, device: torch.device = None): + if device is None: + raise MisconfigurationException( + "device (torch.device) should be provided." + ) + return torch.tensor(value, dtype=dtype, device=device) + + +def from_numpy(value, device: torch.device = None): + if device is None: + raise MisconfigurationException( + "device (torch.device) should be provided." + ) + return torch.from_numpy(value).to(device) + + +CONVERSION_DTYPES = [ + # bool -> uint8 as bool -> torch.bool triggers RuntimeError: Unsupported data type for NCCL process group + (bool, partial(to_dtype_tensor, dtype=torch.uint8)), + (int, partial(to_dtype_tensor, dtype=torch.int)), + (float, partial(to_dtype_tensor, dtype=torch.float)), + (np.ndarray, from_numpy), +] + + def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args, wrong_dtype: Optional[Union[type, tuple]] = None, **kwargs) -> Any: """ Recursively applies a function to all elements of a certain dtype. - Args: data: the collection to apply the function to dtype: the given function will be applied to all elements of this dtype @@ -40,10 +67,8 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable wrong_dtype: the given function won't be applied if this type is specified and the given collections is of the :attr:`wrong_type` even if it is of type :attr`dtype` **kwargs: keyword arguments (will be forwarded to calls of ``function``) - Returns: the resulting collection - """ elem_type = type(data) @@ -67,9 +92,7 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable class TransferableDataType(ABC): """ A custom type for data that can be moved to a torch device via `.to(...)`. - Example: - >>> isinstance(dict, TransferableDataType) False >>> isinstance(torch.rand(2, 3), TransferableDataType) @@ -96,15 +119,12 @@ def move_data_to_device(batch: Any, device: torch.device): """ Transfers a collection of data to the given device. Any object that defines a method ``to(device)`` will be moved and all other objects in the collection will be left untouched. - Args: batch: A tensor or collection of tensors or anything that has a method `.to(...)`. 
            See :func:`apply_to_collection` for a list of supported collection types.
         device: The device to which the data should be moved
-
     Return:
         the same collection but with all contained tensors residing on the new device.
-
     See Also:
-
         - :meth:`torch.Tensor.to`
         - :class:`torch.device`
@@ -128,3 +148,13 @@ def batch_to(data):
 
     dtype = (TransferableDataType, Batch) if _TORCHTEXT_AVAILABLE else TransferableDataType
     return apply_to_collection(batch, dtype=dtype, function=batch_to)
+
+
+def convert_to_tensors(data, device: torch.device = None):
+    if device is None:
+        raise MisconfigurationException(
+            "device (torch.device) should be provided."
+        )
+    for src_dtype, conversion_func in CONVERSION_DTYPES:
+        data = apply_to_collection(data, src_dtype, partial(conversion_func, device=device))
+    return data
diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index 2a0b989e9b9cd..3460d85c64131 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -22,7 +22,7 @@
 from pytorch_lightning import _logger as log
 
 if torch.distributed.is_available():
-    from torch.distributed import ReduceOp, group
+    from torch.distributed import group, ReduceOp
 else:
     class ReduceOp:
         SUM = None
diff --git a/tests/special_tests.sh b/tests/special_tests.sh
index 8d67cce28b39f..675f05cf787ff 100644
--- a/tests/special_tests.sh
+++ b/tests/special_tests.sh
@@ -20,4 +20,6 @@ python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp
 python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual
 python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual_amp
 python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_automatic
+python ${DEFAULTS} tests/utilities/test_all_gather_grad.py::test_all_gather_collection
+# python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance
 python ${DEFAULTS} tests/trainer/logging_tests/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp
diff --git a/tests/utilities/test_all_gather_grad.py b/tests/utilities/test_all_gather_grad.py
index faba88236afd0..9d0dc5cbc9481 100644
--- a/tests/utilities/test_all_gather_grad.py
+++ b/tests/utilities/test_all_gather_grad.py
@@ -1,9 +1,13 @@
 import os
-import pytest
 import sys
+
+import numpy as np
+import pytest
 import torch
 
+from pytorch_lightning import seed_everything, Trainer
 from pytorch_lightning.utilities import AllGatherGrad
+from tests.base.boring_model import BoringModel
 
 
 def setup_ddp(rank, world_size):
@@ -41,3 +45,53 @@ def _test_all_gather_ddp(rank, world_size):
 def test_all_gather_ddp():
     world_size = 3
     torch.multiprocessing.spawn(_test_all_gather_ddp, args=(world_size,), nprocs=world_size)
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
+                    reason="test should be run outside of pytest")
+def test_all_gather_collection(tmpdir):
+
+    class TestModel(BoringModel):
+
+        training_epoch_end_called = False
+
+        def training_epoch_end(self, outputs) -> None:
+            self.training_epoch_end_called = True
+            losses = torch.stack([x["loss"] for x in outputs])
+            gathered_loss = self.all_gather({
+                "losses_np_ndarray": np.array([1, 2, 3]),
+                "losses_bool": [True, False],
+                "losses_float": [0., 1., 2.],
+                "losses_int": [0, 1, 2],
+                "losses": losses,
+                "losses_list": [losses, losses]
+            })
+            assert gathered_loss["losses_np_ndarray"][0].dtype == torch.int64
+            # torch.bool can't be all_gathered
+            assert gathered_loss["losses_bool"][0].dtype == torch.uint8
+            assert gathered_loss["losses_float"][0].dtype == torch.float
+            assert gathered_loss["losses_int"][0].dtype == torch.int
+            assert gathered_loss["losses_list"][0].numel() == 2 * len(losses)
+            assert gathered_loss["losses"].numel() == 2 * len(losses)
+
+    seed_everything(42)
+
+    model = TestModel()
+
+    limit_train_batches = 8
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=2,
+        max_epochs=1,
+        log_every_n_steps=1,
+        accumulate_grad_batches=2,
+        enable_pl_optimizer=True,
+        gpus=2,
+        accelerator="ddp",
+    )
+
+    trainer.fit(model)
+    assert model.training_epoch_end_called
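
The special test above exercises the new collection support end to end. For reference, a minimal usage sketch of the gathered shapes follows. It is illustrative only: the `GatherExample` class, the `batch_count` key, and the 2-process DDP launch are assumptions, not part of this patch; `self.all_gather`, `convert_to_tensors`, `apply_to_collection`, `BoringModel`, and `Trainer` are the APIs shown in the diff.

import torch

from pytorch_lightning import Trainer
from tests.base.boring_model import BoringModel


class GatherExample(BoringModel):

    def training_epoch_end(self, outputs) -> None:
        losses = torch.stack([x["loss"] for x in outputs])
        # Non-tensor leaves (bool, int, float, np.ndarray) are converted to tensors
        # on self.device via convert_to_tensors, then apply_to_collection runs the
        # accelerator's all_gather on every tensor leaf.
        gathered = self.all_gather({"losses": losses, "batch_count": len(outputs)})
        # Every leaf gains a leading world_size dimension, e.g. with 2 processes:
        #   gathered["losses"].shape      -> (2, len(losses))
        #   gathered["batch_count"].shape -> (2,)


if __name__ == "__main__":
    # Assumed launch configuration, mirroring the special test above (2 GPUs, DDP).
    trainer = Trainer(max_epochs=1, limit_train_batches=8, gpus=2, accelerator="ddp")
    trainer.fit(GatherExample())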