diff --git a/pytorch_lightning/accelerators/__init__.py b/pytorch_lightning/accelerators/__init__.py
index 22fdb164e0a61..178e54fb897f4 100644
--- a/pytorch_lightning/accelerators/__init__.py
+++ b/pytorch_lightning/accelerators/__init__.py
@@ -2,6 +2,7 @@
 from pytorch_lightning.accelerators.ddp2_backend import DDP2Backend
 from pytorch_lightning.accelerators.ddp_backend import DDPBackend
 from pytorch_lightning.accelerators.ddp_spawn_backend import DDPSpawnBackend
+from pytorch_lightning.accelerators.ddp_cpu_spawn_backend import DDPCPUSpawnBackend
 from pytorch_lightning.accelerators.dp_backend import DataParallelBackend
 from pytorch_lightning.accelerators.gpu_backend import GPUBackend
 from pytorch_lightning.accelerators.tpu_backend import TPUBackend
diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py
index ec50788534622..2ff92ecd76c23 100644
--- a/pytorch_lightning/accelerators/accelerator_connector.py
+++ b/pytorch_lightning/accelerators/accelerator_connector.py
@@ -129,7 +129,8 @@ def select_accelerator(self):
         te_flags_passed = 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ)
         use_torchelastic_ddp = self.trainer.use_ddp and te_flags_passed

-        use_ddp_spawn = self.trainer.use_ddp and self.trainer.distributed_backend in ['ddp_cpu', 'ddp_spawn']
+        use_ddp_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == 'ddp_spawn'
+        use_ddp_cpu_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == 'ddp_cpu'

         # choose the appropriate accelerator backend
         if self.trainer.use_ddp2:
@@ -144,6 +145,9 @@ def select_accelerator(self):
         elif use_ddp_spawn:
             accelerator_backend = accelerators.DDPSpawnBackend(self.trainer, nprocs=self.trainer.num_processes)

+        elif use_ddp_cpu_spawn:
+            accelerator_backend = accelerators.DDPCPUSpawnBackend(self.trainer, nprocs=self.trainer.num_processes)
+
         elif self.trainer.distributed_backend == 'ddp':
             accelerator_backend = accelerators.DDPBackend(self.trainer, mode='ddp')
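
Note (not part of the patch): a minimal usage sketch of how the two spawn paths above are reached from user code. The Trainer arguments mirror the flags read by the selection logic (distributed_backend, num_processes, gpus); the concrete values are illustrative only.

from pytorch_lightning import Trainer

# 'ddp_cpu' now routes to the new DDPCPUSpawnBackend and spawns `num_processes` CPU workers
trainer_cpu = Trainer(distributed_backend='ddp_cpu', num_processes=2)

# 'ddp_spawn' keeps routing to DDPSpawnBackend, one spawned process per GPU
trainer_gpu = Trainer(distributed_backend='ddp_spawn', gpus=2)
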
diff --git a/pytorch_lightning/accelerators/ddp_base_backend.py b/pytorch_lightning/accelerators/ddp_base_backend.py
index ed00a722cc3ce..c54b93ae1d202 100644
--- a/pytorch_lightning/accelerators/ddp_base_backend.py
+++ b/pytorch_lightning/accelerators/ddp_base_backend.py
@@ -20,7 +20,8 @@
 import torch.distributed as torch_distrib
 import torch.distributed as dist
 from pytorch_lightning.utilities.cloud_io import atomic_save
-from pytorch_lightning.utilities.distributed import rank_zero_warn
+from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_only
+from pytorch_lightning import _logger as log

 try:
     from hydra.utils import to_absolute_path, get_original_cwd
@@ -88,3 +89,102 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
             last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
             atomic_save(model.state_dict(), last_path)
             mp_queue.put(last_path)
+
+    def ddp_train_tmp(self, process_idx, mp_queue, model, is_master=False, proc_offset=0):
+        """
+        Entry point for ddp
+
+        Args:
+            process_idx:
+            mp_queue: multiprocessing queue
+            model:
+
+        Returns:
+
+        """
+        # offset the process id if requested
+        process_idx = process_idx + proc_offset
+
+        # show progressbar only on progress_rank 0
+        if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
+            self.trainer.progress_bar_callback.disable()
+
+        # determine which process we are and world size
+        self.set_world_ranks(process_idx)
+
+        # set warning rank
+        rank_zero_only.rank = self.trainer.global_rank
+
+        # set up server using proc 0's ip address
+        # try to init for 20 times at max in case ports are taken
+        # where to store ip_table
+        model.trainer = self.trainer
+        model.init_ddp_connection(
+            self.trainer.global_rank,
+            self.trainer.world_size,
+            self.trainer.is_slurm_managing_tasks
+        )
+
+        # call setup after the ddp process has connected
+        self.trainer.call_setup_hook(model)
+
+        # on world_size=0 let everyone know training is starting
+        if self.trainer.is_global_zero:
+            log.info('-' * 100)
+            log.info(f'distributed_backend={self.trainer.distributed_backend}')
+            log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
+            log.info('-' * 100)
+
+        # call sync_bn before .cuda(), configure_apex and configure_ddp
+        if self.trainer.sync_batchnorm:
+            model = model.configure_sync_batchnorm(model)
+
+        # move the model to the correct device
+        self.model_to_device(model, process_idx)
+
+        # CHOOSE OPTIMIZER
+        # allow for lr schedulers as well
+        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
+        self.trainer.optimizers = optimizers
+        self.trainer.lr_schedulers = lr_schedulers
+        self.trainer.optimizer_frequencies = optimizer_frequencies
+
+        # set model properties before going into wrapper
+        self.trainer.model_connector.copy_trainer_model_properties(model)
+
+        # AMP -
+        # run through amp wrapper before going to distributed DP
+        if self.trainer.amp_backend == AMPType.APEX:
+            model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
+            self.trainer.optimizers = optimizers
+            self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
+
+        # device ids change depending on the DDP setup
+        device_ids = self.get_device_ids()
+
+        # allow user to configure ddp
+        model = model.configure_ddp(model, device_ids)
+
+        # set up training routine
+        self.trainer.train_loop.setup_training(model)
+
+        # train or test
+        results = self.train_or_test()
+
+        # get original model
+        model = self.trainer.get_model()
+
+        # persist info in ddp_spawn
+        self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)
+
+        # clean up memory
+        torch.cuda.empty_cache()
+
+    def set_world_ranks(self, process_idx):
+        raise NotImplementedError('to create a ddp backend, please implement set_world_ranks')
+
+    def model_to_device(self, model, process_idx):
+        raise NotImplementedError('to create a ddp backend, please implement model_to_device')
+
+    def get_device_ids(self):
+        raise NotImplementedError('to create a ddp backend, please implement get_device_ids')
diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py
new file mode 100644
index 0000000000000..c59401929a8e2
--- /dev/null
+++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py
@@ -0,0 +1,82 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+import os
+
+import torch
+import torch.multiprocessing as mp
+
+from pytorch_lightning.utilities.distributed import find_free_network_port
+from pytorch_lightning.accelerators.ddp_base_backend import DDPBase
+
+try:
+    from apex import amp
+except ImportError:
+    amp = None
+
+
+class DDPCPUSpawnBackend(DDPBase):
+
+    def __init__(self, trainer, nprocs):
+        super().__init__(trainer)
+        self.mp_queue = None
+        self.nprocs = nprocs
+
+    def setup(self, model):
+        os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(find_free_network_port()))
+
+        # pass in a state q
+        smp = mp.get_context('spawn')
+        self.mp_queue = smp.SimpleQueue()
+
+        self.trainer.model = model
+
+    def train(self):
+        model = self.trainer.model
+
+        # train in children process
+        mp.spawn(self.ddp_train_tmp, nprocs=self.nprocs, args=(self.mp_queue, model,))
+
+        # restore main state with best weights
+        best_path = self.mp_queue.get()
+        results = self.mp_queue.get()
+        last_path = self.mp_queue.get()
+
+        # recover the weights of the processes trained in the children
+        self.__recover_child_process_weights(model, best_path, last_path)
+        return results
+
+    def __recover_child_process_weights(self, model, best_path, last_path):
+        # transfer back the best path to the trainer
+        if self.trainer.checkpoint_callback:
+            self.trainer.checkpoint_callback.best_model_path = best_path
+        # todo, pass also best score
+
+        # load last weights
+        if last_path is not None and not self.trainer.testing:
+            ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
+            model.load_state_dict(ckpt)
+
+        self.trainer.model = model
+
+    def set_world_ranks(self, process_idx):
+        self.trainer.local_rank = process_idx
+        self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
+        self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes
+
+    def model_to_device(self, model, process_idx):
+        pass
+
+    def get_device_ids(self):
+        device_ids = None
+        return device_ids
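
Note (not part of the patch): the setup()/train() flow above relies on a spawn-and-queue handoff: the parent process creates a SimpleQueue, torch.multiprocessing.spawn launches the workers, rank 0 puts its results on the queue, and the parent reads them back once the workers join. A standalone sketch of that pattern, with all names illustrative:

import torch.multiprocessing as mp


def worker(process_idx, mp_queue, nprocs):
    # the real backends run ddp_train_tmp here; this stub just reports back from rank 0
    if process_idx == 0:
        mp_queue.put(f'results from {nprocs} spawned workers')


def main():
    nprocs = 2
    smp = mp.get_context('spawn')
    mp_queue = smp.SimpleQueue()

    # blocks until every worker exits (join=True is the default)
    mp.spawn(worker, nprocs=nprocs, args=(mp_queue, nprocs))
    print(mp_queue.get())


if __name__ == '__main__':
    main()
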
diff --git a/pytorch_lightning/accelerators/ddp_spawn_backend.py b/pytorch_lightning/accelerators/ddp_spawn_backend.py
index 27c02ca46c3b3..bb12909be567f 100644
--- a/pytorch_lightning/accelerators/ddp_spawn_backend.py
+++ b/pytorch_lightning/accelerators/ddp_spawn_backend.py
@@ -16,9 +16,7 @@
 import torch
 import torch.multiprocessing as mp

-from pytorch_lightning import _logger as log
-from pytorch_lightning.utilities import AMPType
-from pytorch_lightning.utilities.distributed import rank_zero_only, find_free_network_port
+from pytorch_lightning.utilities.distributed import find_free_network_port
 from pytorch_lightning.accelerators.ddp_base_backend import DDPBase

 try:
@@ -47,7 +45,7 @@ def train(self):
         model = self.trainer.model

         # train in children process
-        mp.spawn(self.ddp_train, nprocs=self.nprocs, args=(self.mp_queue, model,))
+        mp.spawn(self.ddp_train_tmp, nprocs=self.nprocs, args=(self.mp_queue, model,))

         # restore main state with best weights
         best_path = self.mp_queue.get()
@@ -71,107 +69,17 @@ def __recover_child_process_weights(self, model, best_path, last_path):

         self.trainer.model = model

-    def ddp_train(self, process_idx, mp_queue, model):
-        """
-        Entry point for ddp
-
-        Args:
-            process_idx:
-            mp_queue: multiprocessing queue
-            model:
-
-        Returns:
-
-        """
-        # show progressbar only on progress_rank 0
-        if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
-            self.trainer.progress_bar_callback.disable()
-
-        # determine which process we are and world size
-        if self.trainer.use_ddp:
-            self.trainer.local_rank = process_idx
-            self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
-            self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes
-
-        elif self.trainer.use_ddp2:
-            self.trainer.local_rank = self.trainer.node_rank
-            self.trainer.global_rank = self.trainer.node_rank
-            self.trainer.world_size = self.trainer.num_nodes
-
-        # set warning rank
-        rank_zero_only.rank = self.trainer.global_rank
-
-        # set up server using proc 0's ip address
-        # try to init for 20 times at max in case ports are taken
-        # where to store ip_table
-        model.trainer = self.trainer
-        model.init_ddp_connection(
-            self.trainer.global_rank,
-            self.trainer.world_size,
-            self.trainer.is_slurm_managing_tasks
-        )
-
-        # call setup after the ddp process has connected
-        self.trainer.call_setup_hook(model)
-
-        # on world_size=0 let everyone know training is starting
-        if self.trainer.is_global_zero:
-            log.info('-' * 100)
-            log.info(f'distributed_backend={self.trainer.distributed_backend}')
-            log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
-            log.info('-' * 100)
-
-        # call sync_bn before .cuda(), configure_apex and configure_ddp
-        if self.trainer.sync_batchnorm:
-            model = model.configure_sync_batchnorm(model)
-
-        # MODEL
-        # copy model to each gpu
-        if self.trainer.on_gpu:
-            gpu_idx = process_idx
-            self.trainer.root_gpu = gpu_idx
-            torch.cuda.set_device(self.trainer.root_gpu)
-            model.cuda(self.trainer.root_gpu)
-
-        # CHOOSE OPTIMIZER
-        # allow for lr schedulers as well
-        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
-        self.trainer.optimizers = optimizers
-        self.trainer.lr_schedulers = lr_schedulers
-        self.trainer.optimizer_frequencies = optimizer_frequencies
-
-        # set model properties before going into wrapper
-        self.trainer.model_connector.copy_trainer_model_properties(model)
-
-        # AMP -
-        # run through amp wrapper before going to distributed DP
-        if self.trainer.amp_backend == AMPType.APEX:
-            model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
-            self.trainer.optimizers = optimizers
-            self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
-
-        # DDP2 uses all GPUs on the machine
-        if self.trainer.distributed_backend == 'ddp' or self.trainer.distributed_backend == 'ddp_spawn':
-            device_ids = [self.trainer.root_gpu]
-        elif self.trainer.use_ddp2:
-            device_ids = self.trainer.data_parallel_device_ids
-        else:  # includes ddp_cpu
-            device_ids = None
-
-        # allow user to configure ddp
-        model = model.configure_ddp(model, device_ids)
-
-        # set up training routine
-        self.trainer.train_loop.setup_training(model)
-
-        # train or test
-        results = self.train_or_test()
-
-        # get original model
-        model = self.trainer.get_model()
-
-        # persist info in ddp_spawn
-        self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)
-
-        # clean up memory
-        torch.cuda.empty_cache()
+    def set_world_ranks(self, process_idx):
+        self.trainer.local_rank = process_idx
+        self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
+        self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes
+
+    def model_to_device(self, model, process_idx):
+        gpu_idx = process_idx
+        self.trainer.root_gpu = gpu_idx
+        torch.cuda.set_device(self.trainer.root_gpu)
+        model.cuda(self.trainer.root_gpu)
+
+    def get_device_ids(self):
+        device_ids = [self.trainer.root_gpu]
+        return device_ids
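
Note (not part of the patch): both spawn backends now share the same rank bookkeeping in set_world_ranks; a small worked example of that arithmetic, with illustrative values:

# second node (node_rank=1) of 2, third local process (process_idx=2) of 4
num_nodes, num_processes = 2, 4
node_rank, process_idx = 1, 2

local_rank = process_idx                                # 2
global_rank = node_rank * num_processes + process_idx   # 1 * 4 + 2 == 6
world_size = num_nodes * num_processes                  # 8
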