[bugfix] TPU test hangs to barrier on 1 process #6272

Merged 10 commits on Mar 2, 2021 · changes shown from 9 commits
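The fix targets single-process TPU runs in which `trainer.test` called after `trainer.fit` hung while loading the best checkpoint behind a rendezvous that no peer would ever join. A minimal repro sketch of the reported behaviour (the toy model and data are assumptions for illustration, not taken from this PR, and it needs a TPU environment to actually run):

```python
import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl

# Hypothetical stand-in model and data; any LightningModule reproduces it.
class ToyDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.randn(32), torch.randn(2)

class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self(x), y)

    def test_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

trainer = pl.Trainer(tpu_cores=1, max_epochs=1)
trainer.fit(ToyModel(), DataLoader(ToyDataset()))
# Before this PR, loading the best checkpoint inside `test` sat behind a
# rendezvous that waited for peers that never arrive on a single process.
trainer.test(test_dataloaders=DataLoader(ToyDataset()))
```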
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -83,6 +83,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Prevent `WandbLogger` from dropping values ([#5931](https://github.com/PyTorchLightning/pytorch-lightning/pull/5931))


+- Fixed `trainer.test` hanging when loading from `best_path` after calling `trainer.fit` ([#6272](https://github.com/PyTorchLightning/pytorch-lightning/pull/6272))


 - Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275))


6 changes: 2 additions & 4 deletions pytorch_lightning/accelerators/tpu.py
@@ -31,10 +31,8 @@ def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
             raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")
         return super().setup(trainer, model)

-    def run_optimizer_step(
-        self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
-    ) -> None:
-        xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure, **kwargs})
+    def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
+        xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs})

     def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
         """
4 changes: 3 additions & 1 deletion pytorch_lightning/plugins/training_type/horovod.py
@@ -15,6 +15,7 @@
 from typing import Any, List, Optional, Union

 import torch
+import torch.distributed as torch_distrib
 from torch.optim.lr_scheduler import _LRScheduler, Optimizer

 from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -116,7 +117,8 @@ def start_predicting(self, trainer):
         hvd.join()

     def barrier(self, *args, **kwargs):
-        hvd.join()
+        if torch_distrib.is_initialized():
+            hvd.join()

     def broadcast(self, obj: object, src: int = 0) -> object:
         obj = hvd.broadcast_object(obj, src)
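For context, `torch.distributed.is_initialized()` stays False until a process group is created, which is what turns the guarded barrier above into a safe no-op for single-process runs. A minimal check, assuming no prior `init_process_group` call:

```python
import torch.distributed as torch_distrib

# No default process group exists yet, so the guarded `hvd.join()` above
# would fall through instead of blocking on peers that never arrive.
print(torch_distrib.is_initialized())  # False
```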
22 changes: 19 additions & 3 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -17,6 +17,7 @@
 from typing import Any, Dict, Iterable, List, Optional, Union

 import torch
+import torch.distributed as torch_distrib
 import torch.multiprocessing as mp

 from pytorch_lightning.core.lightning import LightningModule
@@ -125,7 +126,8 @@ def model_to_device(self) -> None:
         self._model.to(xm.xla_device())

     def barrier(self, name: Optional[str] = None) -> None:
-        rendezvous(f"pl.Trainer.{name}")
+        if torch_distrib.is_initialized():
+            rendezvous(f"pl.Trainer.{name}")

     def transfer_distrib_spawn_state_on_fit_end(self, results):
         # TODO: is there a better way than accessing callback through model -> trainer -> callback?
@@ -139,14 +141,28 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
         # TODO: is there a better way than accessing trainer through model -> trainer?
         if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
             last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
-            xm.save(self.lightning_module.state_dict(), last_path)
+            self.save(self.lightning_module.state_dict(), last_path)

         if self.global_rank == 0:
             # todo, pass complete checkpoint as state dictionary
             self.mp_queue.put(best_model_path)
             self.mp_queue.put(last_path)
             self.mp_queue.put(results)

+    def save(self, state_dict: Dict, path: str) -> None:

Review comment (Member): shall this be rather protected?

Suggested change:
-    def save(self, state_dict: Dict, path: str) -> None:
+    def _save(self, state_dict: Dict, path: str) -> None:
"""
Saving with ``xm.save`` can be unstable and miss the rendez-vous after ``torch.save``.
The rendez-vous doesn't affect directly saving.
We can ignore the ``RuntimeError`` to reduce friction with TPUs.
"""
try:
xm.save(state_dict, path)
except RuntimeError as e:
if "Failed to meet rendezvous" in str(e):
pass
else:
raise e

     def broadcast(self, obj: object, src: int = 0) -> object:
         buffer = io.BytesIO()
         torch.save(obj, buffer)
@@ -294,4 +310,4 @@ def save_checkpoint(self, filepath, weights_only: bool = False):
         # dump states as a checkpoint dictionary object
         _checkpoint = self.lightning_module.trainer.checkpoint_connector.dump_checkpoint(weights_only)
         # Todo: TypeError: 'mappingproxy' object does not support item assignment
-        xm.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath)
+        self.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath)
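All TPU checkpoint writes now funnel through `save`, which swallows only the rendezvous failure. As a side note on the except-branch above, a bare `raise` would preserve the original traceback, whereas `raise e` re-raises the same object slightly less idiomatically. A standalone sketch of the pattern (`flaky_save` is a hypothetical stand-in for `xm.save`; the error text is the one matched above):

```python
def flaky_save(state_dict, path):
    # Hypothetical stand-in that fails the way ``xm.save`` can on TPUs.
    raise RuntimeError("Failed to meet rendezvous 'pl.Trainer.save'")

try:
    flaky_save({}, "model.ckpt")
except RuntimeError as e:
    # Swallow only the rendezvous failure; let every other error surface.
    if "Failed to meet rendezvous" not in str(e):
        raise
```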
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -496,7 +496,7 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
             # define the max CPU available
             self.num_processes = os.cpu_count()
         # special case with TPUs
-        elif self.distributed_backend == 'tpu':
+        elif self.distributed_backend == 'tpu' or self.tpu_cores is not None:
             self._device_type = DeviceType.TPU
         elif self.distributed_backend and self._distrib_type is None:
             self._distrib_type = DistributedType(self.distributed_backend)
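The widened condition means passing `tpu_cores` alone is now enough to flag the run as a TPU run; no explicit `'tpu'` backend string is required. A minimal sketch of the two entry points (Trainer arguments only; constructing these requires a TPU environment):

```python
from pytorch_lightning import Trainer

# Both now end up with `_device_type = DeviceType.TPU`:
Trainer(tpu_cores=1)                             # newly covered by `self.tpu_cores is not None`
Trainer(distributed_backend='tpu', tpu_cores=1)  # previously the only covered path
```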
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -56,7 +56,7 @@
 from pytorch_lightning.trainer.training_loop import TrainLoop
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
 from pytorch_lightning.tuner.tuning import Tuner
-from pytorch_lightning.utilities import DeviceType, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -912,8 +912,8 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
                 f'specify a path for a checkpoint .test(ckpt_path=PATH)'
             )
             return {}
-        if not self._device_type == DeviceType.TPU:
-            self.accelerator.barrier()
+
+        self.training_type_plugin.barrier()

         ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
         model.load_state_dict(ckpt['state_dict'])
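With the plugins guarding their own rendezvous, the trainer no longer needs a TPU special case here: `__test_using_best_weights` now always delegates to `self.training_type_plugin.barrier()`, and the TPU spawn and Horovod plugins make that call a no-op when `torch.distributed` is uninitialized, as shown in the hunks above.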
5 changes: 2 additions & 3 deletions tests/models/test_tpu.py
@@ -177,8 +177,6 @@ def test_model_16bit_tpu_cores_8(tmpdir):
 def test_model_tpu_early_stop(tmpdir):
     """Test if single TPU core training works"""

-    # todo: Test on 8 cores - hanging.
-
     class CustomBoringModel(BoringModel):

         def validation_step(self, *args, **kwargs):
@@ -195,9 +193,10 @@ def validation_step(self, *args, **kwargs):
         max_epochs=2,
         limit_train_batches=2,
         limit_val_batches=2,
-        tpu_cores=[1],
+        tpu_cores=8,
     )
     trainer.fit(model)
+    trainer.test(test_dataloaders=DataLoader(RandomDataset(32, 2000), batch_size=32))


@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")