[wip] Fix some bugs for TPU [skip ci] #5878

Merged on Feb 11, 2021 · 37 commits

Commits
6c993c5  fixed for single tpu (lezwon, Feb 4, 2021)
18cee2a  fixed spawn (lezwon, Feb 5, 2021)
ec4636f  Merge branch 'accelerator-refactor-sharded' into accelerator-refactor… (tchaton, Feb 6, 2021)
027a151  fixed spawn (lezwon, Feb 7, 2021)
4f711e0  update (Feb 9, 2021)
2d72415  update (Feb 9, 2021)
a642b26  wip (Feb 9, 2021)
5677ee7  Merge branch 'accelerator-refactor-sharded' into tpu_fix (tchaton, Feb 9, 2021)
1cff0a9  resolve bugs (tchaton, Feb 9, 2021)
369de6c  resolve bug (tchaton, Feb 9, 2021)
f4797aa  update on comment (tchaton, Feb 9, 2021)
7395e03  removed decorator (tchaton, Feb 9, 2021)
0b7aa2f  resolve comments (tchaton, Feb 9, 2021)
9355e40  set to 4 (tchaton, Feb 9, 2021)
8a4925f  update (Feb 9, 2021)
5f14189  update (Feb 10, 2021)
69dafb6  need cleaning (Feb 10, 2021)
b046ec5  update (Feb 10, 2021)
e0dadda  update (tchaton, Feb 10, 2021)
0472b9d  update (tchaton, Feb 10, 2021)
5b3a381  resolve flake8 (tchaton, Feb 10, 2021)
843667f  resolve bugs (Feb 10, 2021)
be5711f  exclude broadcast (Feb 10, 2021)
3927d39  resolve bugs (Feb 10, 2021)
1ed9d26  change test (Feb 10, 2021)
f7bf098  update (Feb 10, 2021)
c2bc888  update (Feb 10, 2021)
4c50ef3  skip if meet fails (Feb 10, 2021)
24be82a  Merge branch 'accelerator-refactor-sharded' into tpu_fix (tchaton, Feb 10, 2021)
68474b7  properly raise trace (Feb 10, 2021)
253ea99  Merge branch 'tpu_fix' of https://github.com/PyTorchLightning/pytorch… (Feb 10, 2021)
aea078c  update (Feb 10, 2021)
e092c64  add catch (Feb 10, 2021)
a631273  wrap test (Feb 10, 2021)
5e6a6a1  resolve typo (Feb 11, 2021)
ffe820c  update (Feb 11, 2021)
c250faa  typo (Feb 11, 2021)
2 changes: 1 addition & 1 deletion dockers/tpu-tests/tpu_test_cases.jsonnet
@@ -21,7 +21,7 @@ local tputests = base.BaseTest {
   command: utils.scriptCommand(
     |||
       cd pytorch-lightning
-      coverage run --source=pytorch_lightning -m pytest -v \
+      coverage run --source=pytorch_lightning -m pytest -v --capture=no \
         pytorch_lightning/utilities/xla_device_utils.py \
         tests/accelerators/legacy/test_tpu_backend.py \
         tests/models/test_tpu.py
6 changes: 3 additions & 3 deletions pytorch_lightning/accelerators/accelerator.py
@@ -76,7 +76,7 @@ def setup(self, trainer: "Trainer", model: LightningModule) -> None:
            model: the model to train
        """
        self.connect_training_type_plugin(self.training_type_plugin, model)
-        self.setup_optimizers(trainer, model)
+        self.setup_optimizers(trainer)
        self.connect_precision_plugin(self.precision_plugin)

    @property
@@ -306,7 +306,7 @@ def on_train_end(self) -> None:
        """Hook to do something at the end of the training"""
        pass

-    def setup_optimizers(self, trainer: "Trainer", model: LightningModule):
+    def setup_optimizers(self, trainer: "Trainer"):
        """creates optimizers and schedulers

        Args:
@@ -315,7 +315,7 @@ def setup_optimizers(self, trainer: "Trainer", model: LightningModule):
        """
        if trainer.testing is True:
            return
-        optimizers, lr_schedulers, optimizer_frequencies = trainer.init_optimizers(model)
+        optimizers, lr_schedulers, optimizer_frequencies = trainer.init_optimizers(self.lightning_module)
        self.optimizers = optimizers
        self.lr_schedulers = lr_schedulers
        self.optimizer_frequencies = optimizer_frequencies
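For context on this change: setup_optimizers can drop its model argument because the accelerator already reaches the module through its training type plugin. A minimal sketch of that delegation, with illustrative names rather than the exact Lightning source:

# Sketch: the plugin owns the (possibly wrapped) model and the accelerator
# resolves it on demand, so `self.lightning_module` replaces the `model` arg.
class TrainingTypePluginSketch:
    def __init__(self, model):
        self._model = model

    @property
    def lightning_module(self):
        # Real plugins unwrap here (e.g. out of DistributedDataParallel).
        return self._model


class AcceleratorSketch:
    def __init__(self, training_type_plugin):
        self.training_type_plugin = training_type_plugin

    @property
    def lightning_module(self):
        return self.training_type_plugin.lightning_module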
7 changes: 5 additions & 2 deletions pytorch_lightning/accelerators/accelerator_connector.py
@@ -228,7 +228,7 @@ def on_tpu(self):

    @property
    def tpu_id(self):
-        if self.on_tpu:
+        if self.on_tpu and isinstance(self.tpu_cores, list):
            return self.tpu_cores[0]

        return None
@@ -373,7 +373,10 @@ def select_training_type_plugin(self):
        elif self.use_horovod:
            plugin = HorovodPlugin(parallel_devices=self.parallel_devices)
        elif self.on_tpu:
-            plugin = SingleTPUPlugin(self.tpu_id)
+            if isinstance(self.tpu_cores, list):
+                plugin = SingleTPUPlugin(self.tpu_id)
+            else:
+                plugin = TPUSpawnPlugin(parallel_devices=list(range(self.tpu_cores)))
        else:
            plugin = SingleDevicePlugin(device=torch.device(f"cuda:{self.root_gpu}" if self.on_gpu else "cpu"))
        return plugin

Review thread on the TPUSpawnPlugin line:

Contributor: this was just a temporary fix. I think TPUSpawnPlugin is meant to be called above, within the elif self.use_ddp block.

Contributor Author: I think we can keep both for now.
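To make the new branch concrete: assuming the connector has already normalized tpu_cores to either an int core count (e.g. 8) or a single-core list (e.g. [1]), the selection reduces to the following sketch, where tuples stand in for the real plugin instances:

def pick_tpu_plugin(tpu_cores):
    # A one-element list means "train on exactly this core".
    if isinstance(tpu_cores, list):
        return ("SingleTPUPlugin", tpu_cores[0])
    # An int core count means "spawn one process per core".
    return ("TPUSpawnPlugin", list(range(tpu_cores)))

assert pick_tpu_plugin([1]) == ("SingleTPUPlugin", 1)
assert pick_tpu_plugin(8) == ("TPUSpawnPlugin", [0, 1, 2, 3, 4, 5, 6, 7])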
25 changes: 0 additions & 25 deletions pytorch_lightning/accelerators/legacy/tpu_accelerator.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import io
 import os
-import re
 from typing import Any, Callable, Optional, Union

 import torch
@@ -31,7 +30,6 @@
    rank_zero_only,
    rank_zero_warn,
 )
-from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

 if _TPU_AVAILABLE:
@@ -307,29 +305,6 @@ def load_spawn_weights(self, original_model):

        return loaded_model

-    def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
-        if self.trainer.distributed_backend not in ("ddp_spawn", "ddp_cpu", "tpu"):
-            return
-
-        # track the best model path
-        best_model_path = None
-        if self.trainer.checkpoint_callback is not None:
-            best_model_path = self.trainer.checkpoint_callback.best_model_path
-
-        if self.trainer.global_rank == 0 and mp_queue is not None:
-            rank_zero_warn('cleaning up ddp environment...')
-            # todo, pass complete checkpoint as state dictionary
-            mp_queue.put(best_model_path)
-            mp_queue.put(results)
-
-            # save the last weights
-            last_path = None
-            if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
-                last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
-                state_dict = move_data_to_device(model.state_dict(), torch.device("cpu"))
-                atomic_save(state_dict, last_path)
-            mp_queue.put(last_path)
-
    def broadcast(self, obj, src=0):
        if self.trainer.tpu_id is not None:
            # running on a single core
6 changes: 2 additions & 4 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -520,11 +520,9 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
                trainer,
            )

-        accelerator_backend = trainer.accelerator_backend
-
-        if accelerator_backend.training_type_plugin.rpc_enabled:
+        if trainer.training_type_plugin.rpc_enabled:
            # RPCPlugin manages saving all model states
-            accelerator_backend.training_type_plugin.rpc_save_model(self._save_model, last_filepath, trainer, pl_module)
+            trainer.training_type_plugin.rpc_save_model(self._save_model, last_filepath, trainer, pl_module)
        else:
            self._save_model(last_filepath, trainer, pl_module)
        if (
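The callback now goes through trainer.training_type_plugin directly. A sketch of the forwarding property this assumes on Trainer; the name matches the diff, the body is the obvious guess, not verified source:

class TrainerSketch:
    def __init__(self, accelerator_backend):
        self.accelerator_backend = accelerator_backend

    @property
    def training_type_plugin(self):
        # Saves callbacks from reaching through trainer.accelerator_backend.
        return self.accelerator_backend.training_type_plugin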
3 changes: 3 additions & 0 deletions pytorch_lightning/core/step_result.py
@@ -148,6 +148,9 @@ def log(
            value = torch.tensor(value, device=device, dtype=torch.float)
            value = sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op)

+        if value.device.type == "xla":
+            value = value.cpu()
+
        if 'meta' not in self:
            self.__setitem__('meta', {})
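The added guard copies logged values off the XLA device: XLA tensors are lazy, so keeping them in the results dict between steps would pin device-side state, while .cpu() materializes the value once at logging time. A standalone restatement of the guard; it is a no-op for CPU tensors:

import torch

def to_host_if_xla(value: torch.Tensor) -> torch.Tensor:
    # On TPU, value.device.type == "xla"; copying to CPU forces the lazy
    # tensor to materialize so the stored metric no longer references the device.
    if value.device.type == "xla":
        value = value.cpu()
    return value

print(to_host_if_xla(torch.tensor(1.0)))  # passes through unchanged on CPU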
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/precision/tpu_bfloat.py
@@ -25,4 +25,4 @@ class TPUHalfPrecisionPlugin(PrecisionPlugin):

    def connect(self, model: torch.nn.Module, optimizers, lr_schedulers):
        os.environ["XLA_USE_BF16"] = str(1)
-        return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers)
+        return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers)
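For reference, XLA_USE_BF16 is read by torch_xla: when set to 1, float32 tensors on the XLA device are backed by bfloat16 storage. It has to be set before tensors are created on the device, which is why the plugin sets it in connect. A minimal usage sketch:

import os

# Must be set before the model or data reach the XLA device; torch_xla reads
# it when tensors are created, not retroactively.
os.environ["XLA_USE_BF16"] = "1"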
17 changes: 13 additions & 4 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -95,13 +95,20 @@ def set_world_ranks(self, process_idx):
        self.global_rank = self.node_rank * self.num_processes + self.local_rank
        self.world_size = self.num_nodes * self.num_processes

+    @property
+    def mp_spawn_kwargs(self):
+        return {
+            "args": (self.lightning_module.trainer, self.mp_queue),
+            "nprocs": self.num_processes,
+        }
+
    def start_training(self, trainer):
-        mp.spawn(self.new_process, nprocs=self.num_processes, args=(trainer, self.mp_queue))
+        mp.spawn(self.new_process, **self.mp_spawn_kwargs)
        # reset optimizers, since main process is never used for training and thus does not have a valid optim state
        trainer.optimizers = []

    def start_testing(self, trainer):
-        mp.spawn(self.new_process, nprocs=self.num_processes, args=(trainer, self.mp_queue))
+        mp.spawn(self.new_process, **self.mp_spawn_kwargs)

    def new_process(self, process_idx, trainer, mp_queue):
        self.mp_queue = mp_queue
@@ -173,7 +180,6 @@ def pre_configure_ddp(self):
            self._ddp_kwargs["find_unused_parameters"] = True

    def configure_ddp(self):
-
        self.pre_configure_ddp()
        self._model = DistributedDataParallel(
            LightningDistributedModule(self.model),
@@ -197,6 +203,9 @@ def determine_ddp_device_ids(self):
            return None
        return [self.root_device.index]

+    def on_save(self, checkpoint: dict) -> dict:
+        return checkpoint
+
    def transfer_distrib_spawn_state_on_fit_end(self, results):
        # TODO: is there a better way than accessing callback through model -> trainer -> callback?
        best_model_path = self.lightning_module.trainer.checkpoint_callback.best_model_path
@@ -209,7 +218,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
        # TODO: is there a better way than accessing trainer through model -> trainer?
        if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
            last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
-            atomic_save(self.lightning_module.state_dict(), last_path)
+            atomic_save(self.on_save(self.lightning_module.state_dict()), last_path)

            # todo, pass complete checkpoint as state dictionary
            self.mp_queue.put(best_model_path)
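Extracting mp_spawn_kwargs into a property is what lets subclasses change how workers launch without copying start_training/start_testing. An illustrative override, not the actual TPUSpawnPlugin source:

from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin

class PatchedSpawnPlugin(DDPSpawnPlugin):
    @property
    def mp_spawn_kwargs(self):
        kwargs = super().mp_spawn_kwargs
        # Example tweak: spawn a single worker for smoke tests while reusing
        # the inherited start_training/start_testing unchanged.
        kwargs["nprocs"] = 1
        return kwargs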
32 changes: 30 additions & 2 deletions pytorch_lightning/plugins/training_type/single_tpu.py
@@ -1,11 +1,13 @@
 import io
 import os
-from typing import Optional
+from typing import Optional, Union

 import torch

+from pytorch_lightning import LightningModule
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
 from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
+from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn

 if _TPU_AVAILABLE:
@@ -15,7 +17,9 @@

 class SingleTPUPlugin(SingleDevicePlugin):

-    def __init__(self, device: torch.device):
+    def __init__(self, device: Union[torch.device, int]):
+        if isinstance(device, int):
+            device = xm.xla_device(device)
        super().__init__(device)

        self.tpu_local_core_rank = 0
@@ -24,6 +28,14 @@ def __init__(self, device: torch.device):
    def on_tpu(self) -> bool:
        return True

+    def connect(self, model: torch.nn.Module) -> torch.nn.Module:
+        self._model = model
+        self.model_to_device()
+        return self._model
+
+    def model_to_device(self) -> None:
+        self._model.to(self.root_device)
+
    def pre_training(self) -> None:
        if isinstance(self.device, int):
            self.device = xm.xla_device(self.device)
@@ -37,3 +49,19 @@ def post_training(self) -> None:
        if on_colab_kaggle():
            rank_zero_warn("cleaning up... please do not interrupt")
            self.save_spawn_weights(model)
+
+    def save_spawn_weights(self, model: LightningModule) -> Optional[str]:
+        """
+        Dump a temporary checkpoint after ddp ends to get weights out of the process
+        """
+        path = os.path.join(model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
+        model.trainer.save_checkpoint(path)
+        return path
+
+    def on_save(self, checkpoint: dict) -> dict:
+        """
+        Move XLA tensors to CPU before saving
+        Recommended on XLA Guide:
+        https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors
+        """
+        return move_data_to_device(checkpoint, torch.device("cpu"))
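This on_save hook pairs with the atomic_save(self.on_save(...)) call added in ddp_spawn.py above: the base spawn plugin passes the checkpoint through unchanged, while this plugin host-copies it so nothing XLA-bound is serialized. A sketch of the composition, with torch.save standing in for Lightning's atomic_save:

import torch

def save_weights(plugin, model, path):
    # Identity on DDPSpawnPlugin; a CPU copy (via move_data_to_device) on the
    # TPU plugins, per the pytorch/xla saving guidance linked above.
    state_dict = plugin.on_save(model.state_dict())
    torch.save(state_dict, path)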