
Commit

ref: accelerator connector methods x/n (#3470)
williamFalcon authored Sep 12, 2020

1 parent de99222 commit dd324e4
Showing 6 changed files with 115 additions and 270 deletions.
27 changes: 26 additions & 1 deletion pytorch_lightning/accelerators/ddp2_backend.py
@@ -13,6 +13,7 @@
 # limitations under the License

 import os
+import re

 import torch

@@ -24,6 +25,8 @@
 from pytorch_lightning.accelerators.base_backend import Accelerator
 import torch.distributed as torch_distrib
 import torch.distributed as dist
+from pytorch_lightning.utilities.cloud_io import atomic_save
+from pytorch_lightning.utilities.distributed import rank_zero_warn

 try:
     from hydra.utils import to_absolute_path, get_original_cwd
@@ -162,7 +165,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0):
         model = self.trainer.get_model()

         # persist info in ddp_spawn
-        self.trainer.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)
+        self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)

         # clean up memory
         torch.cuda.empty_cache()
@@ -207,3 +210,25 @@ def early_stopping_should_stop(self, pl_module):
         dist.barrier()
         should_stop = stop == self.trainer.world_size
         return should_stop
+
+    def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
+        if self.trainer.distributed_backend.lower() not in ['ddp_spawn', 'ddp_cpu', 'tpu']:
+            return
+
+        # track the best model path
+        best_model_path = None
+        if self.trainer.checkpoint_callback is not None:
+            best_model_path = self.trainer.checkpoint_callback.best_model_path
+
+        if self.trainer.global_rank == 0 and mp_queue is not None:
+            rank_zero_warn('cleaning up ddp environment...')
+            # todo, pass complete checkpoint as state dictionary
+            mp_queue.put(best_model_path)
+            mp_queue.put(results)
+
+            # save the last weights
+            last_path = None
+            if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
+                last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
+                atomic_save(model.state_dict(), last_path)
+            mp_queue.put(last_path)
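
Note on the hand-off above: the new transfer_distrib_spawn_state_on_fit_end pushes three items into mp_queue in a fixed order (best checkpoint path, fit results, last-weights path), and the parent process is expected to drain them in the same order once the spawned workers exit. A minimal sketch of that producer/consumer pattern outside of Lightning, where run_fit_in_child is a hypothetical stand-in for ddp_train:

import torch.multiprocessing as mp


def run_fit_in_child(rank, mp_queue):
    # Hypothetical stand-in for ddp_train: only rank 0 reports back,
    # mirroring the global_rank == 0 guard in the backend above.
    if rank == 0:
        best_model_path = 'epoch=2.ckpt'
        results = {'val_loss': 0.123}
        last_path = best_model_path.replace('.ckpt', '.tmp_end.ckpt')
        mp_queue.put(best_model_path)  # 1) best checkpoint path
        mp_queue.put(results)          # 2) fit results
        mp_queue.put(last_path)        # 3) path of the last-saved weights


if __name__ == '__main__':
    smp = mp.get_context('spawn')
    queue = smp.SimpleQueue()
    mp.spawn(run_fit_in_child, args=(queue,), nprocs=2)  # blocks until children exit
    # Parent drains the queue in the same order it was filled.
    best_model_path = queue.get()
    results = queue.get()
    last_path = queue.get()
    print(best_model_path, results, last_path)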
29 changes: 27 additions & 2 deletions pytorch_lightning/accelerators/ddp_backend.py
@@ -13,6 +13,7 @@
 # limitations under the License

 import os
+import re
 import subprocess
 import sys
 from os.path import abspath
@@ -24,10 +25,12 @@

 from pytorch_lightning import _logger as log
 from pytorch_lightning.utilities import AMPType
-from pytorch_lightning.utilities.distributed import rank_zero_only, find_free_network_port
 from pytorch_lightning.accelerators.base_backend import Accelerator
 import torch.distributed as torch_distrib
 import torch.distributed as dist
+from pytorch_lightning.utilities.distributed import rank_zero_only, find_free_network_port
+from pytorch_lightning.utilities.cloud_io import atomic_save
+from pytorch_lightning.utilities.distributed import rank_zero_warn

 try:
     from hydra.utils import to_absolute_path, get_original_cwd
@@ -246,7 +249,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0):
         model = self.trainer.get_model()

         # persist info in ddp_spawn
-        self.trainer.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)
+        self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)

         # clean up memory
         torch.cuda.empty_cache()
@@ -286,3 +289,25 @@ def early_stopping_should_stop(self, pl_module):
         dist.barrier()
         should_stop = stop == self.trainer.world_size
         return should_stop
+
+    def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
+        if self.trainer.distributed_backend.lower() not in ['ddp_spawn', 'ddp_cpu', 'tpu']:
+            return
+
+        # track the best model path
+        best_model_path = None
+        if self.trainer.checkpoint_callback is not None:
+            best_model_path = self.trainer.checkpoint_callback.best_model_path
+
+        if self.trainer.global_rank == 0 and mp_queue is not None:
+            rank_zero_warn('cleaning up ddp environment...')
+            # todo, pass complete checkpoint as state dictionary
+            mp_queue.put(best_model_path)
+            mp_queue.put(results)
+
+            # save the last weights
+            last_path = None
+            if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
+                last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
+                atomic_save(model.state_dict(), last_path)
+            mp_queue.put(last_path)
27 changes: 26 additions & 1 deletion pytorch_lightning/accelerators/ddp_spawn_backend.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 import os
+import re

 import torch
 import torch.multiprocessing as mp
@@ -22,6 +23,8 @@
 from pytorch_lightning.accelerators.base_backend import Accelerator
 import torch.distributed as torch_distrib
 import torch.distributed as dist
+from pytorch_lightning.utilities.cloud_io import atomic_save
+from pytorch_lightning.utilities.distributed import rank_zero_warn

 try:
     from apex import amp
@@ -173,7 +176,7 @@ def ddp_train(self, process_idx, mp_queue, model):
         model = self.trainer.get_model()

         # persist info in ddp_spawn
-        self.trainer.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)
+        self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)

         # clean up memory
         torch.cuda.empty_cache()
@@ -203,3 +206,25 @@ def early_stopping_should_stop(self, pl_module):
         dist.barrier()
         should_stop = stop == self.trainer.world_size
         return should_stop
+
+    def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
+        if self.trainer.distributed_backend.lower() not in ['ddp_spawn', 'ddp_cpu', 'tpu']:
+            return
+
+        # track the best model path
+        best_model_path = None
+        if self.trainer.checkpoint_callback is not None:
+            best_model_path = self.trainer.checkpoint_callback.best_model_path
+
+        if self.trainer.global_rank == 0 and mp_queue is not None:
+            rank_zero_warn('cleaning up ddp environment...')
+            # todo, pass complete checkpoint as state dictionary
+            mp_queue.put(best_model_path)
+            mp_queue.put(results)
+
+            # save the last weights
+            last_path = None
+            if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
+                last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
+                atomic_save(model.state_dict(), last_path)
+            mp_queue.put(last_path)
38 changes: 36 additions & 2 deletions pytorch_lightning/accelerators/tpu_backend.py
@@ -98,7 +98,7 @@ def __load_weights_on_main_process(self):

         # load weights if not interrupted
         if self.trainer.on_colab_kaggle and not self.trainer.testing:
-            self.trainer.load_spawn_weights(model)
+            self.load_spawn_weights(model)

         self.trainer.model = model

@@ -180,7 +180,7 @@ def __save_end_of_training_weights(self, model: LightningModule, trainer):
         # when training ends on these platforms dump weights to get out of the main process
         if trainer.on_colab_kaggle:
             rank_zero_warn('cleaning up... please do not interrupt')
-            trainer.save_spawn_weights(model)
+            self.save_spawn_weights(model)

     def __setup_tpu_training(self, model: LightningModule, trainer):
         # use the default device from the process
@@ -260,3 +260,37 @@ def early_stopping_should_stop(self, pl_module):
         torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
         should_stop = int(stop.item()) == self.trainer.world_size
         return should_stop
+
+    def save_spawn_weights(self, model):
+        """
+        Dump a temporary checkpoint after ddp ends to get weights out of the process
+        :param model:
+        :return:
+        """
+        if self.trainer.is_global_zero:
+            path = os.path.join(self.trainer.default_root_dir, '__temp_weight_distributed_end.ckpt')
+            self.trainer.save_checkpoint(path)
+            return path
+
+    def load_spawn_weights(self, original_model):
+        """
+        Load the temp weights saved in the process
+        To recover the trained model from the ddp process we load the saved weights
+        :param model:
+        :return:
+        """
+
+        loaded_model = original_model
+
+        if self.trainer.is_global_zero:
+            # load weights saved in ddp
+            path = os.path.join(self.trainer.default_root_dir, '__temp_weight_distributed_end.ckpt')
+            loaded_model = original_model.__class__.load_from_checkpoint(path)
+
+            # copy loaded weights to old model
+            original_model.load_state_dict(loaded_model.state_dict())
+
+            # remove ddp weights
+            os.remove(path)
+
+        return loaded_model
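
For reference, the pair added above round-trips weights through a temporary __temp_weight_distributed_end.ckpt under default_root_dir and deletes it afterwards. A minimal sketch of that round trip, with plain torch.save/torch.load standing in for Trainer.save_checkpoint and LightningModule.load_from_checkpoint (the helper save_then_restore is hypothetical, not part of this commit):

import os

import torch
from torch import nn


def save_then_restore(model: nn.Module, root_dir: str) -> nn.Module:
    # Simplified analogue of save_spawn_weights + load_spawn_weights:
    # dump the weights to a temp checkpoint, load them back into the
    # original module, then remove the temp file.
    path = os.path.join(root_dir, '__temp_weight_distributed_end.ckpt')
    torch.save(model.state_dict(), path)          # save_spawn_weights
    state = torch.load(path, map_location='cpu')  # load_spawn_weights
    model.load_state_dict(state)                  # copy weights back into the model
    os.remove(path)                               # remove the temporary checkpoint
    return model


if __name__ == '__main__':
    save_then_restore(nn.Linear(4, 2), root_dir='.')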
262 changes: 0 additions & 262 deletions pytorch_lightning/trainer/distrib_data_parallel.py

This file was deleted.

2 changes: 0 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -31,7 +31,6 @@
 from pytorch_lightning.trainer.configuration_validator import ConfigValidator
 from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
 from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_10
-from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
 from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
@@ -76,7 +75,6 @@ class Trainer(
     TrainerCallbackHookMixin,
     TrainerModelHooksMixin,
     TrainerOptimizersMixin,
-    TrainerDDPMixin,
     TrainerLoggingMixin,
     TrainerTrainingTricksMixin,
     TrainerDataLoadingMixin,
