Commit

[wip] Fix some bugs for TPU [skip ci] (#5878)
* fixed for single tpu

* fixed spawn

* fixed spawn

* update

* update

* wip

* resolve bugs

* resolve bug

* update on comment

* removed decorator

* resolve comments

* set to 4

* update

* update

* need cleaning

* update

* update

* update

* resolve flake8

* resolve bugs

* exclude broadcast

* resolve bugs

* change test

* update

* update

* skip if meet fails

* properly raise trace

* update

* add catch

* wrap test

* resolve typo

* update

* typo

Co-authored-by: Lezwon Castelino <lezwon@gmail.com>
Co-authored-by: Your Name <you@example.com>
3 people authored Feb 11, 2021
1 parent 236009e commit aace276
Showing 20 changed files with 201 additions and 108 deletions.
2 changes: 1 addition & 1 deletion dockers/tpu-tests/tpu_test_cases.jsonnet
@@ -21,7 +21,7 @@ local tputests = base.BaseTest {
command: utils.scriptCommand(
|||
cd pytorch-lightning
- coverage run --source=pytorch_lightning -m pytest -v \
+ coverage run --source=pytorch_lightning -m pytest -v --capture=no \
pytorch_lightning/utilities/xla_device_utils.py \
tests/accelerators/legacy/test_tpu_backend.py \
tests/models/test_tpu.py
6 changes: 3 additions & 3 deletions pytorch_lightning/accelerators/accelerator.py
@@ -76,7 +76,7 @@ def setup(self, trainer: "Trainer", model: LightningModule) -> None:
model: the model to train
"""
self.connect_training_type_plugin(self.training_type_plugin, model)
- self.setup_optimizers(trainer, model)
+ self.setup_optimizers(trainer)
self.connect_precision_plugin(self.precision_plugin)

@property
@@ -306,7 +306,7 @@ def on_train_end(self) -> None:
"""Hook to do something at the end of the training"""
pass

- def setup_optimizers(self, trainer: "Trainer", model: LightningModule):
+ def setup_optimizers(self, trainer: "Trainer"):
"""creates optimizers and schedulers
Args:
@@ -315,7 +315,7 @@ def setup_optimizers(self, trainer: "Trainer", model: LightningModule):
"""
if trainer.testing is True:
return
- optimizers, lr_schedulers, optimizer_frequencies = trainer.init_optimizers(model)
+ optimizers, lr_schedulers, optimizer_frequencies = trainer.init_optimizers(self.lightning_module)
self.optimizers = optimizers
self.lr_schedulers = lr_schedulers
self.optimizer_frequencies = optimizer_frequencies
7 changes: 5 additions & 2 deletions pytorch_lightning/accelerators/accelerator_connector.py
@@ -227,7 +227,7 @@ def on_tpu(self):

@property
def tpu_id(self):
- if self.on_tpu:
+ if self.on_tpu and isinstance(self.tpu_cores, list):
return self.tpu_cores[0]

return None
@@ -380,7 +380,10 @@ def select_training_type_plugin(self):
elif self.use_horovod:
plugin = HorovodPlugin(parallel_devices=self.parallel_devices)
elif self.on_tpu:
- plugin = SingleTPUPlugin(self.tpu_id)
+ if isinstance(self.tpu_cores, list):
+     plugin = SingleTPUPlugin(self.tpu_id)
+ else:
+     plugin = TPUSpawnPlugin(parallel_devices=list(range(self.tpu_cores)))
else:
single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids)
plugin = SingleDevicePlugin(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.on_gpu else "cpu"))
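To make the change above concrete, here is a minimal, hedged sketch of how the connector now picks a TPU training-type plugin from the `tpu_cores` argument. The helper name and import path are illustrative only; the real selection lives inside select_training_type_plugin:

    from pytorch_lightning.plugins import SingleTPUPlugin, TPUSpawnPlugin

    def select_tpu_plugin(tpu_cores):
        # tpu_cores=[1] -> train on that specific core with SingleTPUPlugin
        # tpu_cores=8   -> spawn one process per core with TPUSpawnPlugin
        if isinstance(tpu_cores, list):
            return SingleTPUPlugin(tpu_cores[0])
        return TPUSpawnPlugin(parallel_devices=list(range(tpu_cores)))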
25 changes: 0 additions & 25 deletions pytorch_lightning/accelerators/legacy/tpu_accelerator.py
@@ -13,7 +13,6 @@
# limitations under the License.
import io
import os
- import re
from typing import Any, Callable, Optional, Union

import torch
@@ -31,7 +30,6 @@
rank_zero_only,
rank_zero_warn,
)
- from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _TPU_AVAILABLE:
@@ -307,29 +305,6 @@ def load_spawn_weights(self, original_model):

return loaded_model

- def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
-     if self.trainer.distributed_backend not in ("ddp_spawn", "ddp_cpu", "tpu"):
-         return
-
-     # track the best model path
-     best_model_path = None
-     if self.trainer.checkpoint_callback is not None:
-         best_model_path = self.trainer.checkpoint_callback.best_model_path
-
-     if self.trainer.global_rank == 0 and mp_queue is not None:
-         rank_zero_warn('cleaning up ddp environment...')
-         # todo, pass complete checkpoint as state dictionary
-         mp_queue.put(best_model_path)
-         mp_queue.put(results)
-
-         # save the last weights
-         last_path = None
-         if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
-             last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
-             state_dict = move_data_to_device(model.state_dict(), torch.device("cpu"))
-             atomic_save(state_dict, last_path)
-         mp_queue.put(last_path)

def broadcast(self, obj, src=0):
if self.trainer.tpu_id is not None:
# running on a single core
6 changes: 2 additions & 4 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -520,11 +520,9 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
trainer,
)

- accelerator_backend = trainer.accelerator_backend
-
- if accelerator_backend.training_type_plugin.rpc_enabled:
+ if trainer.training_type_plugin.rpc_enabled:
# RPCPlugin manages saving all model states
- accelerator_backend.training_type_plugin.rpc_save_model(self._save_model, last_filepath, trainer, pl_module)
+ trainer.training_type_plugin.rpc_save_model(self._save_model, last_filepath, trainer, pl_module)
else:
self._save_model(last_filepath, trainer, pl_module)
if (
3 changes: 3 additions & 0 deletions pytorch_lightning/core/step_result.py
@@ -148,6 +148,9 @@ def log(
value = torch.tensor(value, device=device, dtype=torch.float)
value = sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op)

+ if value.device.type == "xla":
+     value = value.cpu()

if 'meta' not in self:
self.__setitem__('meta', {})

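The two lines added above pull a logged value back to the CPU as soon as it has been synced. A hedged, simplified illustration of the same guard outside of the Result plumbing (the helper name is made up for this sketch):

    import torch

    def detach_from_xla(value: torch.Tensor) -> torch.Tensor:
        # On a TPU, tensors logged from a step live on an "xla" device;
        # moving them to CPU keeps metric aggregation and checkpointing
        # free of device-bound XLA tensors.
        if value.device.type == "xla":
            value = value.cpu()
        return value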
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/precision/tpu_bfloat.py
@@ -25,4 +25,4 @@ class TPUHalfPrecisionPlugin(PrecisionPlugin):

def connect(self, model: torch.nn.Module, optimizers, lr_schedulers):
os.environ["XLA_USE_BF16"] = str(1)
- return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers)
\ No newline at end of file
+ return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers)
17 changes: 13 additions & 4 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -95,13 +95,20 @@ def set_world_ranks(self, process_idx):
self.global_rank = self.node_rank * self.num_processes + self.local_rank
self.world_size = self.num_nodes * self.num_processes

+ @property
+ def mp_spawn_kwargs(self):
+     return {
+         "args": (self.lightning_module.trainer, self.mp_queue),
+         "nprocs": self.num_processes,
+     }

def start_training(self, trainer):
- mp.spawn(self.new_process, nprocs=self.num_processes, args=(trainer, self.mp_queue))
+ mp.spawn(self.new_process, **self.mp_spawn_kwargs)
# reset optimizers, since main process is never used for training and thus does not have a valid optim state
trainer.optimizers = []

def start_testing(self, trainer):
- mp.spawn(self.new_process, nprocs=self.num_processes, args=(trainer, self.mp_queue))
+ mp.spawn(self.new_process, **self.mp_spawn_kwargs)

def new_process(self, process_idx, trainer, mp_queue):
self.mp_queue = mp_queue
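Factoring the spawn arguments into the `mp_spawn_kwargs` property lets a subclass change how worker processes are launched without duplicating `start_training`/`start_testing`. A hedged sketch of such an override; the subclass name and the extra `start_method` key are illustrative assumptions, not part of this diff:

    from pytorch_lightning.plugins import DDPSpawnPlugin

    class ForkingSpawnPlugin(DDPSpawnPlugin):
        @property
        def mp_spawn_kwargs(self):
            # torch.multiprocessing.spawn accepts start_method, so a plugin
            # that needs a different launch strategy can override only this
            # property instead of re-implementing the start_* methods.
            kwargs = super().mp_spawn_kwargs
            kwargs["start_method"] = "fork"
            return kwargs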
@@ -173,7 +180,6 @@ def pre_configure_ddp(self):
self._ddp_kwargs["find_unused_parameters"] = True

def configure_ddp(self):

self.pre_configure_ddp()
self._model = DistributedDataParallel(
LightningDistributedModule(self.model),
@@ -197,6 +203,9 @@ def determine_ddp_device_ids(self):
return None
return [self.root_device.index]

+ def on_save(self, checkpoint: dict) -> dict:
+     return checkpoint

def transfer_distrib_spawn_state_on_fit_end(self, results):
# TODO: is there a better way than accessing callback through model -> trainer -> callback?
checkpoint_callback = self.lightning_module.trainer.checkpoint_callback
@@ -210,7 +219,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
# TODO: is there a better way than accessing trainer through model -> trainer?
if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
- atomic_save(self.lightning_module.state_dict(), last_path)
+ atomic_save(self.on_save(self.lightning_module.state_dict()), last_path)

# todo, pass complete checkpoint as state dictionary
self.mp_queue.put(best_model_path)
32 changes: 30 additions & 2 deletions pytorch_lightning/plugins/training_type/single_tpu.py
@@ -1,11 +1,13 @@
import io
import os
- from typing import Optional
+ from typing import Optional, Union

import torch

+ from pytorch_lightning import LightningModule
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
+ from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn

if _TPU_AVAILABLE:
@@ -15,7 +17,9 @@

class SingleTPUPlugin(SingleDevicePlugin):

- def __init__(self, device: torch.device):
+ def __init__(self, device: Union[torch.device, int]):
+     if isinstance(device, int):
+         device = xm.xla_device(device)
super().__init__(device)

self.tpu_local_core_rank = 0
@@ -24,6 +28,14 @@ def __init__(self, device: torch.device):
def on_tpu(self) -> bool:
return True

+ def connect(self, model: torch.nn.Module) -> torch.nn.Module:
+     self._model = model
+     self.model_to_device()
+     return self._model
+
+ def model_to_device(self) -> None:
+     self._model.to(self.root_device)

def pre_training(self) -> None:
if isinstance(self.device, int):
self.device = xm.xla_device(self.device)
@@ -37,3 +49,19 @@ def post_training(self) -> None:
if on_colab_kaggle():
rank_zero_warn("cleaning up... please do not interrupt")
self.save_spawn_weights(model)

+ def save_spawn_weights(self, model: LightningModule) -> Optional[str]:
+     """
+     Dump a temporary checkpoint after ddp ends to get weights out of the process
+     """
+     path = os.path.join(model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
+     model.trainer.save_checkpoint(path)
+     return path
+
+ def on_save(self, checkpoint: dict) -> dict:
+     """
+     Move XLA tensors to CPU before saving
+     Recommended on XLA Guide:
+     https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors
+     """
+     return move_data_to_device(checkpoint, torch.device("cpu"))
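The new `on_save` hook implements the pattern recommended in the XLA guide: bring tensors back to the CPU before serializing them. A minimal, hedged sketch of that pattern outside the plugin; the model and file name here are placeholders:

    import torch
    from pytorch_lightning.utilities.apply_func import move_data_to_device

    model = torch.nn.Linear(2, 2)  # stand-in; on a TPU run this would live on an "xla" device
    checkpoint = {"state_dict": model.state_dict(), "epoch": 3}
    # Move any XLA tensors to CPU before writing the file, as the hook above does.
    checkpoint = move_data_to_device(checkpoint, torch.device("cpu"))
    torch.save(checkpoint, "model.ckpt")  # loadable on machines without a TPU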
(Diffs for the remaining changed files are not shown.)
