From e9becd4015e88b4d7b74a6aa65e1a372a7e4823d Mon Sep 17 00:00:00 2001
From: Your Name
Date: Mon, 1 Mar 2021 19:18:08 +0000
Subject: [PATCH 1/8] update

---
 pytorch_lightning/accelerators/tpu.py                | 6 ++----
 pytorch_lightning/plugins/training_type/tpu_spawn.py | 3 +++
 .../trainer/connectors/accelerator_connector.py      | 2 +-
 pytorch_lightning/trainer/trainer.py                 | 5 +++--
 tests/models/test_tpu.py                             | 5 ++---
 5 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py
index bbadd571d1b92..8f98cb8ac5a20 100644
--- a/pytorch_lightning/accelerators/tpu.py
+++ b/pytorch_lightning/accelerators/tpu.py
@@ -31,10 +31,8 @@ def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
             raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")
         return super().setup(trainer, model)
 
-    def run_optimizer_step(
-        self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
-    ) -> None:
-        xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure, **kwargs})
+    def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
+        xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs})
 
     def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
         """
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 4232cba485414..74fa88dae059b 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -141,6 +141,9 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
             last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
             xm.save(self.lightning_module.state_dict(), last_path)
 
+            # this barrier seems to make xm.save fail less often
+            self.barrier("rdz")
+
         if self.global_rank == 0:
             # todo, pass complete checkpoint as state dictionary
             self.mp_queue.put(best_model_path)
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index 4f942f9b35e5d..0ba63fd3c7ddc 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -494,7 +494,7 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
             # define the max CPU available
             self.num_processes = os.cpu_count()
         # special case with TPUs
-        elif self.distributed_backend == 'tpu':
+        elif self.distributed_backend == 'tpu' or self.tpu_cores is not None:
             self._device_type = DeviceType.TPU
         elif self.distributed_backend and self._distrib_type is None:
             self._distrib_type = DistributedType(self.distributed_backend)
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 68453811da203..2751728582dc7 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -911,8 +911,9 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
                 f'specify a path for a checkpoint .test(ckpt_path=PATH)'
             )
             return {}
-        if not self._device_type == DeviceType.TPU:
-            self.accelerator.barrier()
+
+        if torch.distributed.is_available() and torch.distributed.is_initialized():
+            self.training_type_plugin.barrier()
 
         ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
         model.load_state_dict(ckpt['state_dict'])
diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index 6a4605b3e2b36..a7e7f933d934e 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -177,8 +177,6 @@ def test_model_16bit_tpu_cores_8(tmpdir):
 def test_model_tpu_early_stop(tmpdir):
     """Test if single TPU core training works"""
 
-    # todo: Test on 8 cores - hanging.
-
     class CustomBoringModel(BoringModel):
 
         def validation_step(self, *args, **kwargs):
@@ -195,9 +193,10 @@ def validation_step(self, *args, **kwargs):
         max_epochs=2,
         limit_train_batches=2,
         limit_val_batches=2,
-        tpu_cores=[1],
+        tpu_cores=8,
     )
     trainer.fit(model)
+    trainer.test(test_dataloaders=DataLoader(RandomDataset(32, 2000), batch_size=32))
 
 
 @pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")

From 69be46c6cf5234d71cc61d1724c1df223c723934 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Mon, 1 Mar 2021 19:19:54 +0000
Subject: [PATCH 2/8] resolve flake8

---
 pytorch_lightning/trainer/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 2751728582dc7..05d971c40d7aa 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -56,7 +56,7 @@
 from pytorch_lightning.trainer.training_loop import TrainLoop
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
 from pytorch_lightning.tuner.tuning import Tuner
-from pytorch_lightning.utilities import DeviceType, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

From eb7a7627458bdd3ea2b6179e03c6f9f13948f97f Mon Sep 17 00:00:00 2001
From: Your Name
Date: Mon, 1 Mar 2021 19:27:05 +0000
Subject: [PATCH 3/8] update

---
 .../plugins/training_type/tpu_spawn.py | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 74fa88dae059b..ae4483b2c03ee 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -14,6 +14,7 @@
 import io
 import os
 import re
+from time import sleep
 from typing import Any, Dict, Iterable, List, Optional, Union
 
 import torch
@@ -50,6 +51,7 @@ def __init__(
         )
         self.tpu_local_core_rank = 0
         self.start_method = None
+        self._repeat_save_on_fail = 3
 
     def connect(self, model: torch.nn.Module) -> torch.nn.Module:
         self.create_mp_queue()
@@ -139,10 +141,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
         # TODO: is there a better way than accessing trainer through model -> trainer?
         if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
             last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
-            xm.save(self.lightning_module.state_dict(), last_path)
-
-            # this barrier seems to make xm.save fail less often
-            self.barrier("rdz")
+            self.try_save(self.lightning_module.state_dict(), last_path)
 
         if self.global_rank == 0:
             # todo, pass complete checkpoint as state dictionary
@@ -150,6 +149,16 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
             self.mp_queue.put(last_path)
         self.mp_queue.put(results)
 
+    def try_save(self, state_dict: Dict, path: str):
+        # saving can randomly fail,
+        # therefore we try several times
+        for _ in range(self._repeat_save_on_fail):
+            try:
+                xm.save(state_dict, path)
+                break
+            except RuntimeError:
+                sleep(0.001)
+
     def broadcast(self, obj: object, src: int = 0) -> object:
         buffer = io.BytesIO()
         torch.save(obj, buffer)
@@ -297,4 +306,4 @@ def save_checkpoint(self, filepath, weights_only: bool = False):
         # dump states as a checkpoint dictionary object
         _checkpoint = self.lightning_module.trainer.checkpoint_connector.dump_checkpoint(weights_only)
         # Todo: TypeError: 'mappingproxy' object does not support item assignment
-        xm.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath)
+        self.try_save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath)

From f98008f345fc2a22bf0737b7701e408e1286dba4 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Mon, 1 Mar 2021 19:31:05 +0000
Subject: [PATCH 4/8] update

---
 pytorch_lightning/plugins/training_type/tpu_spawn.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index ae4483b2c03ee..15be79111c4ab 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -149,9 +149,11 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
             self.mp_queue.put(last_path)
         self.mp_queue.put(results)
 
-    def try_save(self, state_dict: Dict, path: str):
-        # saving can randomly fail,
-        # therefore we try several times
+    def try_save(self, state_dict: Dict, path: str) -> None:
+        """
+        Saving with xm.save can fail to meet the rendez-vous.
+        Therefore, we will try several times to do so.
+        """
         for _ in range(self._repeat_save_on_fail):
             try:
                 xm.save(state_dict, path)

From cf2833d97dde17bcf49d0a2aa9210517bb3338a0 Mon Sep 17 00:00:00 2001
From: tchaton
Date: Mon, 1 Mar 2021 19:35:06 +0000
Subject: [PATCH 5/8] update changelog

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9cf9b731c27fd..040b4391c86ca 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -80,6 +80,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Prevent `WandbLogger` from dropping values ([#5931](https://github.com/PyTorchLightning/pytorch-lightning/pull/5931))
 
 
+- Fixed `trainer.test` with `best_path` hanging after calling `trainer.fit` ([#6272](https://github.com/PyTorchLightning/pytorch-lightning/pull/6272))
+
+
 ## [1.2.1] - 2021-02-23
 
 ### Fixed

From b49dceda7260002e6e302436bf19a6af98ca56ac Mon Sep 17 00:00:00 2001
From: tchaton
Date: Tue, 2 Mar 2021 09:01:42 +0000
Subject: [PATCH 6/8] update

---
 .../plugins/training_type/horovod.py   |  4 ++-
 .../plugins/training_type/tpu_spawn.py  | 29 ++++++++++---------
 pytorch_lightning/trainer/trainer.py    |  3 +-
 3 files changed, 20 insertions(+), 16 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py
index e940cb1d7229b..8fe52190fd7bb 100644
--- a/pytorch_lightning/plugins/training_type/horovod.py
+++ b/pytorch_lightning/plugins/training_type/horovod.py
@@ -15,6 +15,7 @@
 from typing import Any, List, Optional, Union
 
 import torch
+import torch.distributed as torch_distrib
 from torch.optim.lr_scheduler import _LRScheduler, Optimizer
 
 from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -116,7 +117,8 @@ def start_predicting(self, trainer):
         hvd.join()
 
     def barrier(self, *args, **kwargs):
-        hvd.join()
+        if torch_distrib.is_initialized():
+            hvd.join()
 
     def broadcast(self, obj: object, src: int = 0) -> object:
         obj = hvd.broadcast_object(obj, src)
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 15be79111c4ab..b56aab3fa41f1 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -18,6 +18,7 @@
 from typing import Any, Dict, Iterable, List, Optional, Union
 
 import torch
+import torch.distributed as torch_distrib
 import torch.multiprocessing as mp
 
 from pytorch_lightning.core.lightning import LightningModule
@@ -51,7 +52,6 @@ def __init__(
         )
         self.tpu_local_core_rank = 0
         self.start_method = None
-        self._repeat_save_on_fail = 3
 
     def connect(self, model: torch.nn.Module) -> torch.nn.Module:
         self.create_mp_queue()
@@ -127,7 +127,8 @@ def model_to_device(self) -> None:
         self._model.to(xm.xla_device())
 
     def barrier(self, name: Optional[str] = None) -> None:
-        rendezvous(f"pl.Trainer.{name}")
+        if torch_distrib.is_initialized():
+            rendezvous(f"pl.Trainer.{name}")
 
     def transfer_distrib_spawn_state_on_fit_end(self, results):
         # TODO: is there a better way than accessing callback through model -> trainer -> callback?
@@ -141,7 +142,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
         # TODO: is there a better way than accessing trainer through model -> trainer?
         if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
             last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
-            self.try_save(self.lightning_module.state_dict(), last_path)
+            self.save(self.lightning_module.state_dict(), last_path)
 
         if self.global_rank == 0:
             # todo, pass complete checkpoint as state dictionary
@@ -149,17 +150,19 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
             self.mp_queue.put(last_path)
         self.mp_queue.put(results)
 
-    def try_save(self, state_dict: Dict, path: str) -> None:
+    def save(self, state_dict: Dict, path: str) -> None:
         """
-        Saving with xm.save can fail to meet the rendez-vous.
-        Therefore, we will try several times to do so.
+        Saving with ``xm.save`` can be unstable and miss the rendez-vous after ``torch.save``.
+        The rendez-vous doesn't directly affect saving.
+        We can ignore the ``RuntimeError`` to reduce friction with TPUs.
         """
-        for _ in range(self._repeat_save_on_fail):
-            try:
-                xm.save(state_dict, path)
-                break
-            except RuntimeError:
-                sleep(0.001)
+        try:
+            xm.save(state_dict, path)
+        except RuntimeError as e:
+            if "Failed to meet rendezvous" in str(e):
+                pass
+            else:
+                raise e
 
     def broadcast(self, obj: object, src: int = 0) -> object:
         buffer = io.BytesIO()
         torch.save(obj, buffer)
@@ -308,4 +311,4 @@ def save_checkpoint(self, filepath, weights_only: bool = False):
         # dump states as a checkpoint dictionary object
         _checkpoint = self.lightning_module.trainer.checkpoint_connector.dump_checkpoint(weights_only)
         # Todo: TypeError: 'mappingproxy' object does not support item assignment
-        xm.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath)
+        self.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath)
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 05d971c40d7aa..7d9fdee263294 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -912,8 +912,7 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
             )
             return {}
 
-        if torch.distributed.is_available() and torch.distributed.is_initialized():
-            self.training_type_plugin.barrier()
+        self.training_type_plugin.barrier()
 
         ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
         model.load_state_dict(ckpt['state_dict'])

From 7df3aae8ee58592e20894dd76e463752422241af Mon Sep 17 00:00:00 2001
From: tchaton
Date: Tue, 2 Mar 2021 09:31:24 +0000
Subject: [PATCH 7/8] resolve flake8

---
 pytorch_lightning/plugins/training_type/tpu_spawn.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index b56aab3fa41f1..98d164c0b8b4e 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -14,7 +14,6 @@
 import io
 import os
 import re
-from time import sleep
 from typing import Any, Dict, Iterable, List, Optional, Union
 
 import torch

From 6f3a72f38352b888dfb68a26dd7f0eedf6e82cb1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Tue, 2 Mar 2021 18:47:57 +0100
Subject: [PATCH 8/8] Update pytorch_lightning/plugins/training_type/tpu_spawn.py

---
 pytorch_lightning/plugins/training_type/tpu_spawn.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index 98d164c0b8b4e..9639a17e637bb 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -158,9 +158,7 @@ def save(self, state_dict: Dict, path: str) -> None:
         try:
             xm.save(state_dict, path)
         except RuntimeError as e:
-            if "Failed to meet rendezvous" in str(e):
-                pass
-            else:
+            if "Failed to meet rendezvous" not in str(e):
                 raise e