ref: accelerator connector methods x/n #3470

Merged: 1 commit, Sep 12, 2020
27 changes: 26 additions & 1 deletion pytorch_lightning/accelerators/ddp2_backend.py
@@ -13,6 +13,7 @@
# limitations under the License

import os
import re

import torch

@@ -24,6 +25,8 @@
from pytorch_lightning.accelerators.base_backend import Accelerator
import torch.distributed as torch_distrib
import torch.distributed as dist
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.distributed import rank_zero_warn

try:
from hydra.utils import to_absolute_path, get_original_cwd
@@ -162,7 +165,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
model = self.trainer.get_model()

# persist info in ddp_spawn
- self.trainer.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)
+ self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)

# clean up memory
torch.cuda.empty_cache()
@@ -207,3 +210,25 @@ def early_stopping_should_stop(self, pl_module):
dist.barrier()
should_stop = stop == self.trainer.world_size
return should_stop

def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
if self.trainer.distributed_backend.lower() not in ['ddp_spawn', 'ddp_cpu', 'tpu']:
return

# track the best model path
best_model_path = None
if self.trainer.checkpoint_callback is not None:
best_model_path = self.trainer.checkpoint_callback.best_model_path

if self.trainer.global_rank == 0 and mp_queue is not None:
rank_zero_warn('cleaning up ddp environment...')
# todo, pass complete checkpoint as state dictionary
mp_queue.put(best_model_path)
mp_queue.put(results)

# save the last weights
last_path = None
if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
atomic_save(model.state_dict(), last_path)
mp_queue.put(last_path)
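
For context, a minimal sketch of the other side of this hand-off (hypothetical code, not part of the diff): transfer_distrib_spawn_state_on_fit_end puts three items on the multiprocessing queue in a fixed order, and whoever spawned the processes is expected to read them back in the same order. The helper name collect_spawn_results below is made up for illustration.

def collect_spawn_results(mp_queue):
    # Read back, in order, what rank 0 put on the queue in
    # transfer_distrib_spawn_state_on_fit_end.
    best_model_path = mp_queue.get()  # best checkpoint path from the checkpoint callback, or None
    results = mp_queue.get()          # results object returned by the run
    last_path = mp_queue.get()        # '*.tmp_end.ckpt' with the final weights, or None (e.g. when testing)
    return best_model_path, results, last_path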
29 changes: 27 additions & 2 deletions pytorch_lightning/accelerators/ddp_backend.py
@@ -13,6 +13,7 @@
# limitations under the License

import os
import re
import subprocess
import sys
from os.path import abspath
@@ -24,10 +25,12 @@

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import AMPType
- from pytorch_lightning.utilities.distributed import rank_zero_only, find_free_network_port
from pytorch_lightning.accelerators.base_backend import Accelerator
import torch.distributed as torch_distrib
import torch.distributed as dist
+ from pytorch_lightning.utilities.distributed import rank_zero_only, find_free_network_port
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.distributed import rank_zero_warn

try:
from hydra.utils import to_absolute_path, get_original_cwd
@@ -246,7 +249,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
model = self.trainer.get_model()

# persist info in ddp_spawn
- self.trainer.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)
+ self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)

# clean up memory
torch.cuda.empty_cache()
@@ -286,3 +289,25 @@ def early_stopping_should_stop(self, pl_module):
dist.barrier()
should_stop = stop == self.trainer.world_size
return should_stop

def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
if self.trainer.distributed_backend.lower() not in ['ddp_spawn', 'ddp_cpu', 'tpu']:
return

# track the best model path
best_model_path = None
if self.trainer.checkpoint_callback is not None:
best_model_path = self.trainer.checkpoint_callback.best_model_path

if self.trainer.global_rank == 0 and mp_queue is not None:
rank_zero_warn('cleaning up ddp environment...')
# todo, pass complete checkpoint as state dictionary
mp_queue.put(best_model_path)
mp_queue.put(results)

# save the last weights
last_path = None
if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
atomic_save(model.state_dict(), last_path)
mp_queue.put(last_path)
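
A small illustration of the new last-weights path logic (the path below is hypothetical): the temporary filename is derived from the best checkpoint path with re.sub, so epoch=4.ckpt becomes epoch=4.tmp_end.ckpt. Note that the '.' in the pattern is an unescaped regex metacharacter and matches any character, which is harmless for ordinary checkpoint names but worth keeping in mind.

import re

best_model_path = 'lightning_logs/version_0/checkpoints/epoch=4.ckpt'  # hypothetical path
last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
print(last_path)  # lightning_logs/version_0/checkpoints/epoch=4.tmp_end.ckpt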
27 changes: 26 additions & 1 deletion pytorch_lightning/accelerators/ddp_spawn_backend.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License
import os
import re

import torch
import torch.multiprocessing as mp
@@ -22,6 +23,8 @@
from pytorch_lightning.accelerators.base_backend import Accelerator
import torch.distributed as torch_distrib
import torch.distributed as dist
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.distributed import rank_zero_warn

try:
from apex import amp
@@ -173,7 +176,7 @@ def ddp_train(self, process_idx, mp_queue, model):
model = self.trainer.get_model()

# persist info in ddp_spawn
- self.trainer.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)
+ self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)

# clean up memory
torch.cuda.empty_cache()
@@ -203,3 +206,25 @@ def early_stopping_should_stop(self, pl_module):
dist.barrier()
should_stop = stop == self.trainer.world_size
return should_stop

def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
if self.trainer.distributed_backend.lower() not in ['ddp_spawn', 'ddp_cpu', 'tpu']:
return

# track the best model path
best_model_path = None
if self.trainer.checkpoint_callback is not None:
best_model_path = self.trainer.checkpoint_callback.best_model_path

if self.trainer.global_rank == 0 and mp_queue is not None:
rank_zero_warn('cleaning up ddp environment...')
# todo, pass complete checkpoint as state dictionary
mp_queue.put(best_model_path)
mp_queue.put(results)

# save the last weights
last_path = None
if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
atomic_save(model.state_dict(), last_path)
mp_queue.put(last_path)
38 changes: 36 additions & 2 deletions pytorch_lightning/accelerators/tpu_backend.py
@@ -98,7 +98,7 @@ def __load_weights_on_main_process(self):

# load weights if not interrupted
if self.trainer.on_colab_kaggle and not self.trainer.testing:
- self.trainer.load_spawn_weights(model)
+ self.load_spawn_weights(model)

self.trainer.model = model

@@ -180,7 +180,7 @@ def __save_end_of_training_weights(self, model: LightningModule, trainer):
# when training ends on these platforms dump weights to get out of the main process
if trainer.on_colab_kaggle:
rank_zero_warn('cleaning up... please do not interrupt')
- trainer.save_spawn_weights(model)
+ self.save_spawn_weights(model)

def __setup_tpu_training(self, model: LightningModule, trainer):
# use the default device from the process
@@ -260,3 +260,37 @@ def early_stopping_should_stop(self, pl_module):
torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
should_stop = int(stop.item()) == self.trainer.world_size
return should_stop

def save_spawn_weights(self, model):
"""
Dump a temporary checkpoint after ddp ends to get weights out of the process
:param model:
:return:
"""
if self.trainer.is_global_zero:
path = os.path.join(self.trainer.default_root_dir, '__temp_weight_distributed_end.ckpt')
self.trainer.save_checkpoint(path)
return path

def load_spawn_weights(self, original_model):
"""
Load the temp weights saved in the process
To recover the trained model from the ddp process we load the saved weights
:param original_model:
:return:
"""

loaded_model = original_model

if self.trainer.is_global_zero:
# load weights saved in ddp
path = os.path.join(self.trainer.default_root_dir, '__temp_weight_distributed_end.ckpt')
loaded_model = original_model.__class__.load_from_checkpoint(path)

# copy loaded weights to old model
original_model.load_state_dict(loaded_model.state_dict())

# remove ddp weights
os.remove(path)

return loaded_model
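
To see how the two TPU helpers pair up, here is a rough driver sketch; fit_on_tpu and train_fn are made-up names, the xmp import assumes a TPU runtime, and none of this is part of the diff. On Colab/Kaggle the spawned processes cannot hand their trained weights back to the parent directly, so rank 0 dumps a temporary checkpoint that the parent reloads after the spawn has joined.

import torch_xla.distributed.xla_multiprocessing as xmp  # assumes a TPU runtime

def fit_on_tpu(accelerator, model, nprocs=8):
    # Hypothetical driver, not part of this PR.
    def train_fn(rank):
        # ... the backend runs the actual TPU training loop here ...
        # before the spawned process exits, persist the trained weights:
        accelerator.save_spawn_weights(model)  # rank 0 writes __temp_weight_distributed_end.ckpt

    xmp.spawn(train_fn, nprocs=nprocs, start_method='fork')

    # back in the parent process: recover the weights saved by rank 0;
    # load_spawn_weights also deletes the temporary checkpoint file
    return accelerator.load_spawn_weights(model)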