From aa5bf1f03ce219704ee45a98051821ac4a2e55b0 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Sun, 11 Apr 2021 02:20:44 +0530 Subject: [PATCH 01/11] Fix sync_dist for tpus --- pytorch_lightning/core/step_result.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index eb0f26cec2bbc..4af9c0fdd24be 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -21,8 +21,6 @@ from torch import Tensor from torchmetrics import Metric -from pytorch_lightning.utilities.distributed import sync_ddp_if_available - class Result(Dict): @@ -103,8 +101,6 @@ def log( if not enable_graph and isinstance(value, torch.Tensor): value = value.detach() - # sync across workers when using distributed training - sync_fn = sync_fn or sync_ddp_if_available if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)): is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized() # TODO: Find a way to make the reduction only once, so we don't need to clone. From dbe634f5999b6b9d351487e6983a4604a7849402 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Sun, 11 Apr 2021 02:28:39 +0530 Subject: [PATCH 02/11] Fix typo --- pytorch_lightning/accelerators/accelerator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index fea4ae725b5e4..972ec275a5d6e 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -106,7 +106,7 @@ def pre_dispatch(self, trainer: 'pl.Trainer') -> None: self.precision_plugin.pre_dispatch() def post_dispatch(self, trainer: 'pl.Trainer') -> None: - """Hook to do something before the training/evaluation/prediction starts.""" + """Hook to do something after the training/evaluation/prediction starts.""" self.training_type_plugin.post_dispatch() self.precision_plugin.post_dispatch() From 53b6950635f6cdad63697f34e7f2433af0a4b0ae Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Sun, 11 Apr 2021 20:30:19 +0530 Subject: [PATCH 03/11] add tpu distributed --- pytorch_lightning/core/step_result.py | 4 ++- pytorch_lightning/utilities/__init__.py | 3 --- pytorch_lightning/utilities/distributed.py | 29 +++++++++++----------- pytorch_lightning/utilities/imports.py | 4 +++ 4 files changed, 22 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 4af9c0fdd24be..8cc174ab3aed2 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -21,6 +21,8 @@ from torch import Tensor from torchmetrics import Metric +from pytorch_lightning.utilities.distributed import tpu_distributed + class Result(Dict): @@ -104,7 +106,7 @@ def log( if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)): is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized() # TODO: Find a way to make the reduction only once, so we don't need to clone. - if is_dist_initialized and isinstance(value, torch.Tensor): + if (is_dist_initialized or tpu_distributed) and isinstance(value, torch.Tensor): value = value.clone() else: value = torch.tensor(value, device=device, dtype=torch.float) diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 398e3782bef8a..200bad0ef07aa 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -56,9 +56,6 @@ _XLA_AVAILABLE, ) from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable # noqa: F401 -from pytorch_lightning.utilities.xla_device import XLADeviceUtils # noqa: F401 - -_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 018d83a93a7a9..a3208e122b79a 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -17,16 +17,14 @@ import warnings from functools import partial, wraps from typing import Any, Optional, Union -from pytorch_lightning.utilities.imports import ( - _TORCH_GREATER_EQUAL_1_8, - _TORCH_GREATER_EQUAL_1_9, -) import torch - from torch.nn.parallel.distributed import DistributedDataParallel -log = logging.getLogger(__name__) +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, _TPU_AVAILABLE + +if _TPU_AVAILABLE: + import torch_xla.core.xla_model as xm if torch.distributed.is_available(): from torch.distributed import group, ReduceOp @@ -40,6 +38,9 @@ class group: WORLD = None +log = logging.getLogger(__name__) + + def rank_zero_only(fn): @wraps(fn) @@ -294,19 +295,13 @@ def register_ddp_comm_hook( ) """ if not _TORCH_GREATER_EQUAL_1_8: - rank_zero_warn( - "Not registering DDP comm hook. " - "To use communication hooks, please use pytorch>=1.8.0." - ) + rank_zero_warn("Not registering DDP comm hook. " "To use communication hooks, please use pytorch>=1.8.0.") return if ddp_comm_hook is None: return if ddp_comm_wrapper is not None: if not _TORCH_GREATER_EQUAL_1_9: - rank_zero_warn( - "Not applying DDP comm wrapper. " - "To use communication wrapper, please use pytorch>=1.9.0." - ) + rank_zero_warn("Not applying DDP comm wrapper. " "To use communication wrapper, please use pytorch>=1.9.0.") else: rank_zero_info( f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})." @@ -318,3 +313,9 @@ def register_ddp_comm_hook( state=ddp_comm_state, hook=ddp_comm_hook, ) + + +def tpu_distributed(): + if _TPU_AVAILABLE: + return xm.xrt_world_size() > 1 + return False diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 621e0d17d2ea7..e0b10a0a9eae1 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -88,3 +88,7 @@ def _compare_version(package: str, op, version) -> bool: _TORCHTEXT_AVAILABLE = _module_available("torchtext") _TORCHVISION_AVAILABLE = _module_available('torchvision') _XLA_AVAILABLE = _module_available("torch_xla") + +from pytorch_lightning.utilities.xla_device import XLADeviceUtils # noqa: F401 + +_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() From 6d404819c775d13a2cc4f69f39e760ae450255cd Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Sun, 11 Apr 2021 21:49:45 +0530 Subject: [PATCH 04/11] Add TPU Available --- pytorch_lightning/utilities/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 200bad0ef07aa..3c1108b535f05 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -53,6 +53,7 @@ _TORCH_QUANTIZE_AVAILABLE, _TORCHTEXT_AVAILABLE, _TORCHVISION_AVAILABLE, + _TPU_AVAILABLE, _XLA_AVAILABLE, ) from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable # noqa: F401 From af1df437e2f67614431eab127f691f99adff3d9f Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Sun, 11 Apr 2021 21:53:47 +0530 Subject: [PATCH 05/11] fix flake8 --- pytorch_lightning/utilities/imports.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index e0b10a0a9eae1..5cfa1495c52ac 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -89,6 +89,6 @@ def _compare_version(package: str, op, version) -> bool: _TORCHVISION_AVAILABLE = _module_available('torchvision') _XLA_AVAILABLE = _module_available("torch_xla") -from pytorch_lightning.utilities.xla_device import XLADeviceUtils # noqa: F401 +from pytorch_lightning.utilities.xla_device import XLADeviceUtils # noqa: E402 _TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists() From 20760d652033b4a31f93e0bcb853bd134b67a5a6 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Mon, 12 Apr 2021 00:39:05 +0530 Subject: [PATCH 06/11] add sync ddp --- pytorch_lightning/core/step_result.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 8cc174ab3aed2..ade51b8775043 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -21,7 +21,7 @@ from torch import Tensor from torchmetrics import Metric -from pytorch_lightning.utilities.distributed import tpu_distributed +from pytorch_lightning.utilities.distributed import sync_ddp_if_available, tpu_distributed class Result(Dict): @@ -103,6 +103,8 @@ def log( if not enable_graph and isinstance(value, torch.Tensor): value = value.detach() + sync_fn = sync_fn or sync_ddp_if_available + if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)): is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized() # TODO: Find a way to make the reduction only once, so we don't need to clone. From 9ff17ecb98cc1cac6e56cf884a56761b176cffec Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Mon, 12 Apr 2021 01:08:03 +0530 Subject: [PATCH 07/11] Add test --- tests/models/test_tpu.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index 6409f2ef4bcbf..8c7a47eeb379a 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -15,14 +15,17 @@ from argparse import ArgumentParser from unittest import mock +import numpy as np import pytest +import torch from torch.utils.data import DataLoader import tests.helpers.pipelines as tpipes import tests.helpers.utils as tutils -from pytorch_lightning import Trainer +from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.accelerators import TPUAccelerator from pytorch_lightning.callbacks import EarlyStopping +from pytorch_lightning.core.step_result import Result from pytorch_lightning.plugins import TPUSpawnPlugin from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import _TPU_AVAILABLE @@ -416,3 +419,19 @@ def test_if_test_works_with_checkpoint_false(tmpdir): trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True, checkpoint_callback=False) trainer.fit(model) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" + + +@RunIf(tpu=True) +@pl_multi_process_test +def test_tpu_sync_dist(): + """Test tpu spawn sync dist operation """ + + def test_sync_dist(rank): + tensor = torch.tensor([1.0]) + + res = Result() + res.log("test_tensor", tensor, sync_dist=True, sync_dist_op=torch.distributed.ReduceOp.SUM) + + assert res["test_tensor"].item() == 1, "Result-Log does not work properly with TPU Spawn and Tensors" + + xmp.spawn(test_sync_dist, nprocs=8, start_method='fork') From 2c8f4a45925dc8656851ca8b066054e883113ef9 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Mon, 12 Apr 2021 01:15:12 +0530 Subject: [PATCH 08/11] fix --- pytorch_lightning/core/step_result.py | 1 + pytorch_lightning/plugins/training_type/tpu_spawn.py | 7 ++++--- tests/models/test_tpu.py | 3 +-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index ade51b8775043..7a193662b597b 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -103,6 +103,7 @@ def log( if not enable_graph and isinstance(value, torch.Tensor): value = value.detach() + # sync across workers when using distributed training sync_fn = sync_fn or sync_ddp_if_available if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)): diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index b072a29c7fbc6..5bcfd093aef96 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -15,7 +15,7 @@ import os import re import time -from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union import torch import torch.multiprocessing as mp @@ -41,7 +41,6 @@ if _OMEGACONF_AVAILABLE: from omegaconf import DictConfig, ListConfig, OmegaConf - if TYPE_CHECKING: from torch.nn import Module from torch.utils.data import DataLoader @@ -278,4 +277,6 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra Return: A tensor of shape (world_size, batch, ...) """ - return xm.all_gather(tensor.unsqueeze(0)) + if isinstance(tensor, torch.Tensor) and tensor.dim() == 0: + tensor = tensor.unsqueeze(0) + return xm.all_gather(tensor) diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index 8c7a47eeb379a..77fd93cb74ecf 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -15,14 +15,13 @@ from argparse import ArgumentParser from unittest import mock -import numpy as np import pytest import torch from torch.utils.data import DataLoader import tests.helpers.pipelines as tpipes import tests.helpers.utils as tutils -from pytorch_lightning import seed_everything, Trainer +from pytorch_lightning import Trainer from pytorch_lightning.accelerators import TPUAccelerator from pytorch_lightning.callbacks import EarlyStopping from pytorch_lightning.core.step_result import Result From b59d91ef6b16450b0c2d2382eb239f3a6982e65c Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Mon, 12 Apr 2021 01:41:40 +0530 Subject: [PATCH 09/11] Update test --- tests/models/test_tpu.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index 77fd93cb74ecf..e623c480882a3 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -427,10 +427,17 @@ def test_tpu_sync_dist(): def test_sync_dist(rank): tensor = torch.tensor([1.0]) + training_type_plugin = TPUSpawnPlugin() res = Result() - res.log("test_tensor", tensor, sync_dist=True, sync_dist_op=torch.distributed.ReduceOp.SUM) - - assert res["test_tensor"].item() == 1, "Result-Log does not work properly with TPU Spawn and Tensors" + res.log( + "test_tensor", + tensor, + sync_fn=training_type_plugin.reduce, + sync_dist=True, + sync_dist_op=torch.distributed.ReduceOp.SUM + ) + + assert res["test_tensor"].item() == 8, "Result-Log does not work properly with TPU Spawn and Tensors" xmp.spawn(test_sync_dist, nprocs=8, start_method='fork') From 548cad429135abd3e49873d5c98ff568252323a0 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Mon, 12 Apr 2021 16:48:55 +0530 Subject: [PATCH 10/11] Update changelog --- CHANGELOG.md | 3 +++ pytorch_lightning/utilities/distributed.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0f96664db72b0..18f2f2c47fcd7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -240,6 +240,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `AttributeError for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915)) +- Fixed `sync_dist` for tpus ([#6950](https://github.com/PyTorchLightning/pytorch-lightning/pull/6950)) + + ## [1.2.7] - 2021-04-06 ### Fixed diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index a3208e122b79a..4f65c998cf2bb 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -295,13 +295,13 @@ def register_ddp_comm_hook( ) """ if not _TORCH_GREATER_EQUAL_1_8: - rank_zero_warn("Not registering DDP comm hook. " "To use communication hooks, please use pytorch>=1.8.0.") + rank_zero_warn("Not registering DDP comm hook. To use communication hooks, please use pytorch>=1.8.0.") return if ddp_comm_hook is None: return if ddp_comm_wrapper is not None: if not _TORCH_GREATER_EQUAL_1_9: - rank_zero_warn("Not applying DDP comm wrapper. " "To use communication wrapper, please use pytorch>=1.9.0.") + rank_zero_warn("Not applying DDP comm wrapper. To use communication wrapper, please use pytorch>=1.9.0.") else: rank_zero_info( f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})." From 9a402c30e2dd85a614b8fd8441a63e99992731b9 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Tue, 13 Apr 2021 12:04:55 +0530 Subject: [PATCH 11/11] Update pytorch_lightning/utilities/distributed.py Co-authored-by: ananthsub --- pytorch_lightning/utilities/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 4f65c998cf2bb..a54d00a983d9e 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -315,7 +315,7 @@ def register_ddp_comm_hook( ) -def tpu_distributed(): +def tpu_distributed() -> bool: if _TPU_AVAILABLE: return xm.xrt_world_size() > 1 return False