[bugfix] TPU test hangs to barrier on 1 process #6272

Merged (10 commits, Mar 2, 2021)
Changes from 6 commits
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -80,6 +80,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Prevent `WandbLogger` from dropping values ([#5931](https://github.com/PyTorchLightning/pytorch-lightning/pull/5931))


- Fixed `trainer.test` from `best_path` hanging after calling `trainer.fit` ([#6272](https://github.com/PyTorchLightning/pytorch-lightning/pull/6272))


## [1.2.1] - 2021-02-23

### Fixed
6 changes: 2 additions & 4 deletions pytorch_lightning/accelerators/tpu.py
@@ -31,10 +31,8 @@ def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")
return super().setup(trainer, model)

-    def run_optimizer_step(
-        self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
-    ) -> None:
-        xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure, **kwargs})
+    def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
+        xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs})

def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
"""
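Note on the hunk above: in `torch_xla`, the `barrier` argument of `xm.optimizer_step` controls whether an `xm.mark_step()` (a device synchronization point) is also issued as part of the step; here it is now passed explicitly as `False`. A minimal sketch of the call pattern, assuming `torch_xla` is installed (the standalone function name is illustrative, not part of this diff):

```python
import torch_xla.core.xla_model as xm


def run_optimizer_step(optimizer, closure):
    # The closure re-evaluates the loss and gradients; torch_xla forwards it
    # to optimizer.step(closure=closure). With barrier=False, no extra
    # mark_step / synchronization is forced after the optimizer step.
    xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': closure})
```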
18 changes: 16 additions & 2 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -14,6 +14,7 @@
import io
import os
import re
+from time import sleep
from typing import Any, Dict, Iterable, List, Optional, Union

import torch
@@ -50,6 +51,7 @@ def __init__(
)
self.tpu_local_core_rank = 0
self.start_method = None
+        self._repeat_save_on_fail = 3

def connect(self, model: torch.nn.Module) -> torch.nn.Module:
self.create_mp_queue()
@@ -139,14 +141,26 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
# TODO: is there a better way than accessing trainer through model -> trainer?
if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
-                xm.save(self.lightning_module.state_dict(), last_path)
+                self.try_save(self.lightning_module.state_dict(), last_path)

if self.global_rank == 0:
# todo, pass complete checkpoint as state dictionary
self.mp_queue.put(best_model_path)
self.mp_queue.put(last_path)
self.mp_queue.put(results)

+    def try_save(self, state_dict: Dict, path: str) -> None:
+        """
+        Saving with ``xm.save`` can fail to meet the rendez-vous,
+        so retry it a few times before giving up.
+        """
+        for _ in range(self._repeat_save_on_fail):
+            try:
+                xm.save(state_dict, path)
+                break
+            except RuntimeError:
+                sleep(0.001)

def broadcast(self, obj: object, src: int = 0) -> object:
buffer = io.BytesIO()
torch.save(obj, buffer)
Expand Down Expand Up @@ -294,4 +308,4 @@ def save_checkpoint(self, filepath, weights_only: bool = False):
# dump states as a checkpoint dictionary object
_checkpoint = self.lightning_module.trainer.checkpoint_connector.dump_checkpoint(weights_only)
# Todo: TypeError: 'mappingproxy' object does not support item assignment
-        xm.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath)
+        self.try_save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath)
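The `try_save` helper exists because `xm.save` performs a rendez-vous across TPU processes before writing, and that rendez-vous can fail with a `RuntimeError` when a peer process is not yet at the same point. The retry pattern, reduced to a standalone sketch (names and the re-raise behaviour here are illustrative, not taken from the PR, which silently gives up after the last attempt):

```python
import time
from typing import Callable


def retry_on_runtime_error(save_fn: Callable[[], None], attempts: int = 3, delay: float = 0.001) -> None:
    """Call save_fn, retrying on RuntimeError up to `attempts` times."""
    for i in range(attempts):
        try:
            save_fn()
            return
        except RuntimeError:
            if i == attempts - 1:
                raise  # the PR's try_save swallows the final failure instead
            time.sleep(delay)
```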
pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -494,7 +494,7 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
# define the max CPU available
self.num_processes = os.cpu_count()
# special case with TPUs
-        elif self.distributed_backend == 'tpu':
+        elif self.distributed_backend == 'tpu' or self.tpu_cores is not None:
self._device_type = DeviceType.TPU
elif self.distributed_backend and self._distrib_type is None:
self._distrib_type = DistributedType(self.distributed_backend)
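With this change the connector selects the TPU device type whenever `tpu_cores` is given, even if the backend string is not explicitly `'tpu'`. A minimal usage sketch against the Lightning 1.2.x API, assuming it runs on a TPU machine:

```python
from pytorch_lightning import Trainer

# Passing tpu_cores alone is now enough for the accelerator connector to pick
# DeviceType.TPU; distributed_backend='tpu' does not need to be set as well.
trainer = Trainer(tpu_cores=8, max_epochs=2)
```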
7 changes: 4 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -56,7 +56,7 @@
from pytorch_lightning.trainer.training_loop import TrainLoop
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.tuner.tuning import Tuner
-from pytorch_lightning.utilities import DeviceType, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -911,8 +911,9 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
f'specify a path for a checkpoint .test(ckpt_path=PATH)'
)
return {}
-        if not self._device_type == DeviceType.TPU:
-            self.accelerator.barrier()

+        if torch.distributed.is_available() and torch.distributed.is_initialized():
+            self.training_type_plugin.barrier()

ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt['state_dict'])
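The new guard ties the barrier to the actual state of the process group rather than to the device type: it only synchronizes when `torch.distributed` has been initialized, so a single-process TPU run, where no peer would ever reach the barrier, no longer hangs. A minimal sketch of the same guard, assuming a plugin object exposing a `barrier()` method:

```python
import torch


def barrier_if_distributed(plugin) -> None:
    # No-op on single-process runs (plain CPU/GPU or a single TPU core);
    # only synchronizes when a distributed process group actually exists.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        plugin.barrier()
```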
5 changes: 2 additions & 3 deletions tests/models/test_tpu.py
@@ -177,8 +177,6 @@ def test_model_16bit_tpu_cores_8(tmpdir):
def test_model_tpu_early_stop(tmpdir):
"""Test if single TPU core training works"""

-    # todo: Test on 8 cores - hanging.

class CustomBoringModel(BoringModel):

def validation_step(self, *args, **kwargs):
@@ -195,9 +193,10 @@ def validation_step(self, *args, **kwargs):
max_epochs=2,
limit_train_batches=2,
limit_val_batches=2,
-        tpu_cores=[1],
+        tpu_cores=8,
)
trainer.fit(model)
+    trainer.test(test_dataloaders=DataLoader(RandomDataset(32, 2000), batch_size=32))


@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")