ref: merge backends x/n #3478

Merged · 4 commits · Sep 12, 2020
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/__init__.py
@@ -2,6 +2,7 @@
from pytorch_lightning.accelerators.ddp2_backend import DDP2Backend
from pytorch_lightning.accelerators.ddp_backend import DDPBackend
from pytorch_lightning.accelerators.ddp_spawn_backend import DDPSpawnBackend
from pytorch_lightning.accelerators.ddp_cpu_spawn_backend import DDPCPUSpawnBackend
from pytorch_lightning.accelerators.dp_backend import DataParallelBackend
from pytorch_lightning.accelerators.gpu_backend import GPUBackend
from pytorch_lightning.accelerators.tpu_backend import TPUBackend
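
(Not part of the diff.) With the export above, the new backend can be imported next to the existing ones; a quick sanity check, assuming the package layout in this PR:

from pytorch_lightning import accelerators

# the new CPU spawn backend now sits alongside the existing DDP backends
assert hasattr(accelerators, 'DDPCPUSpawnBackend')
assert hasattr(accelerators, 'DDPSpawnBackend')
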
6 changes: 5 additions & 1 deletion pytorch_lightning/accelerators/accelerator_connector.py
@@ -129,7 +129,8 @@ def select_accelerator(self):
te_flags_passed = 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ)
use_torchelastic_ddp = self.trainer.use_ddp and te_flags_passed

use_ddp_spawn = self.trainer.use_ddp and self.trainer.distributed_backend in ['ddp_cpu', 'ddp_spawn']
use_ddp_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == 'ddp_spawn'
use_ddp_cpu_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == 'ddp_cpu'

# choose the appropriate accelerator backend
if self.trainer.use_ddp2:
@@ -144,6 +145,9 @@ def select_accelerator(self):
elif use_ddp_spawn:
accelerator_backend = accelerators.DDPSpawnBackend(self.trainer, nprocs=self.trainer.num_processes)

elif use_ddp_cpu_spawn:
accelerator_backend = accelerators.DDPCPUSpawnBackend(self.trainer, nprocs=self.trainer.num_processes)

elif self.trainer.distributed_backend == 'ddp':
accelerator_backend = accelerators.DDPBackend(self.trainer, mode='ddp')

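
(Not part of the diff.) After this change, 'ddp_spawn' and 'ddp_cpu' resolve to different backend classes. A minimal usage sketch, assuming the Trainer arguments available in this release; MyLightningModule stands in for any LightningModule defined elsewhere:

from pytorch_lightning import Trainer

model = MyLightningModule()  # placeholder LightningModule defined elsewhere

# 'ddp_spawn' -> DDPSpawnBackend: one spawned process per GPU
trainer = Trainer(distributed_backend='ddp_spawn', gpus=2)

# 'ddp_cpu' -> DDPCPUSpawnBackend (new in this PR): CPU-only spawned processes
trainer = Trainer(distributed_backend='ddp_cpu', num_processes=2)

trainer.fit(model)
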
102 changes: 101 additions & 1 deletion pytorch_lightning/accelerators/ddp_base_backend.py
@@ -20,7 +20,8 @@
import torch.distributed as torch_distrib
import torch.distributed as dist
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.distributed import rank_zero_warn
from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_only
from pytorch_lightning import _logger as log

try:
from hydra.utils import to_absolute_path, get_original_cwd
@@ -88,3 +89,102 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
atomic_save(model.state_dict(), last_path)
mp_queue.put(last_path)

def ddp_train_tmp(self, process_idx, mp_queue, model, is_master=False, proc_offset=0):
"""
Entry point for ddp

Args:
process_idx:
mp_queue: multiprocessing queue
model:

Returns:

"""
# offset the process id if requested
process_idx = process_idx + proc_offset

# show the progress bar only on process 0 of node 0
if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
self.trainer.progress_bar_callback.disable()

# determine which process we are and world size
self.set_world_ranks(process_idx)

# set warning rank
rank_zero_only.rank = self.trainer.global_rank

# set up the server using process 0's ip address
# try to init up to 20 times in case the chosen ports are already taken
model.trainer = self.trainer
model.init_ddp_connection(
self.trainer.global_rank,
self.trainer.world_size,
self.trainer.is_slurm_managing_tasks
)

# call setup after the ddp process has connected
self.trainer.call_setup_hook(model)

# on global rank 0, let everyone know training is starting
if self.trainer.is_global_zero:
log.info('-' * 100)
log.info(f'distributed_backend={self.trainer.distributed_backend}')
log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
log.info('-' * 100)

# call sync_bn before .cuda(), configure_apex and configure_ddp
if self.trainer.sync_batchnorm:
model = model.configure_sync_batchnorm(model)

# move the model to the correct device
self.model_to_device(model, process_idx)

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
self.trainer.optimizers = optimizers
self.trainer.lr_schedulers = lr_schedulers
self.trainer.optimizer_frequencies = optimizer_frequencies

# set model properties before going into wrapper
self.trainer.model_connector.copy_trainer_model_properties(model)

# AMP
# run through the amp wrapper before going to distributed data parallel
if self.trainer.amp_backend == AMPType.APEX:
model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
self.trainer.optimizers = optimizers
self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)

# device ids change depending on the DDP setup
device_ids = self.get_device_ids()

# allow user to configure ddp
model = model.configure_ddp(model, device_ids)

# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()

# get original model
model = self.trainer.get_model()

# persist info in ddp_spawn
self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)

# clean up memory
torch.cuda.empty_cache()

def set_world_ranks(self, process_idx):
raise NotImplementedError('to create a ddp backend, please implement set_world_ranks')

def model_to_device(self, model, process_idx):
raise NotImplementedError('to create a ddp backend, please implement model_to_device')

def get_device_ids(self):
raise NotImplementedError('to create a ddp backend, please implement get_device_ids')
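
(Not part of the diff.) ddp_train_tmp is a template method: the shared training flow lives in the base class and the three NotImplementedError stubs above are the only points where the spawn backends differ. A standalone sketch of that pattern, with illustrative names rather than Lightning APIs:

class SpawnBackendSketch:
    """Shared flow that delegates the backend-specific steps to three hooks."""

    def run(self, process_idx, model):
        self.set_world_ranks(process_idx)         # rank bookkeeping
        self.model_to_device(model, process_idx)  # device placement (no-op on CPU)
        return self.get_device_ids()              # None on CPU, [root_gpu] on GPU

    def set_world_ranks(self, process_idx):
        raise NotImplementedError

    def model_to_device(self, model, process_idx):
        raise NotImplementedError

    def get_device_ids(self):
        raise NotImplementedError


class CPUSpawnSketch(SpawnBackendSketch):
    def set_world_ranks(self, process_idx):
        self.local_rank = process_idx

    def model_to_device(self, model, process_idx):
        pass  # nothing to move for CPU training

    def get_device_ids(self):
        return None  # DistributedDataParallel(device_ids=None) keeps the replica on CPU
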
82 changes: 82 additions & 0 deletions pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py
@@ -0,0 +1,82 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import os

import torch
import torch.multiprocessing as mp

from pytorch_lightning.utilities.distributed import find_free_network_port
from pytorch_lightning.accelerators.ddp_base_backend import DDPBase

try:
from apex import amp
except ImportError:
amp = None


class DDPCPUSpawnBackend(DDPBase):

def __init__(self, trainer, nprocs):
super().__init__(trainer)
self.mp_queue = None
self.nprocs = nprocs

def setup(self, model):
os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(find_free_network_port()))

# pass in a state queue shared with the spawned processes
smp = mp.get_context('spawn')
self.mp_queue = smp.SimpleQueue()

self.trainer.model = model

def train(self):
model = self.trainer.model

# train in child processes
mp.spawn(self.ddp_train_tmp, nprocs=self.nprocs, args=(self.mp_queue, model,))

# restore main state with best weights
best_path = self.mp_queue.get()
results = self.mp_queue.get()
last_path = self.mp_queue.get()

# recover the weights trained in the child processes
self.__recover_child_process_weights(model, best_path, last_path)
return results

def __recover_child_process_weights(self, model, best_path, last_path):
# transfer back the best path to the trainer
if self.trainer.checkpoint_callback:
self.trainer.checkpoint_callback.best_model_path = best_path
# TODO: also pass the best score

# load last weights
if last_path is not None and not self.trainer.testing:
ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt)

self.trainer.model = model

def set_world_ranks(self, process_idx):
self.trainer.local_rank = process_idx
self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

def model_to_device(self, model, process_idx):
pass

def get_device_ids(self):
device_ids = None
return device_ids
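
(Not part of the diff.) train() depends on the spawned processes pushing best_path, results and last_path onto the shared queue (done in transfer_distrib_spawn_state_on_fit_end in the base class) and pops them back in that same order. A minimal standalone sketch of the spawn/queue handshake; the worker and its payloads are illustrative, not Lightning code:

import torch.multiprocessing as mp


def worker(process_idx, mp_queue):
    # stands in for ddp_train_tmp(); only rank 0 fills the queue
    if process_idx == 0:
        mp_queue.put('best.ckpt')           # best_model_path
        mp_queue.put({'val_loss': 0.1})     # results
        mp_queue.put('last.tmp_end.ckpt')   # last_path


if __name__ == '__main__':
    smp = mp.get_context('spawn')
    queue = smp.SimpleQueue()
    mp.spawn(worker, nprocs=2, args=(queue,))  # joins the children before returning
    best_path = queue.get()
    results = queue.get()
    last_path = queue.get()
    print(best_path, results, last_path)
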
124 changes: 16 additions & 108 deletions pytorch_lightning/accelerators/ddp_spawn_backend.py
@@ -16,9 +16,7 @@
import torch
import torch.multiprocessing as mp

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only, find_free_network_port
from pytorch_lightning.utilities.distributed import find_free_network_port
from pytorch_lightning.accelerators.ddp_base_backend import DDPBase

try:
@@ -47,7 +45,7 @@ def train(self):
model = self.trainer.model

# train in child processes
mp.spawn(self.ddp_train, nprocs=self.nprocs, args=(self.mp_queue, model,))
mp.spawn(self.ddp_train_tmp, nprocs=self.nprocs, args=(self.mp_queue, model,))

# restore main state with best weights
best_path = self.mp_queue.get()
@@ -71,107 +69,17 @@ def __recover_child_process_weights(self, model, best_path, last_path):

self.trainer.model = model

def ddp_train(self, process_idx, mp_queue, model):
"""
Entry point for ddp

Args:
process_idx:
mp_queue: multiprocessing queue
model:

Returns:

"""
# show progressbar only on progress_rank 0
if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
self.trainer.progress_bar_callback.disable()

# determine which process we are and world size
if self.trainer.use_ddp:
self.trainer.local_rank = process_idx
self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

elif self.trainer.use_ddp2:
self.trainer.local_rank = self.trainer.node_rank
self.trainer.global_rank = self.trainer.node_rank
self.trainer.world_size = self.trainer.num_nodes

# set warning rank
rank_zero_only.rank = self.trainer.global_rank

# set up server using proc 0's ip address
# try to init for 20 times at max in case ports are taken
# where to store ip_table
model.trainer = self.trainer
model.init_ddp_connection(
self.trainer.global_rank,
self.trainer.world_size,
self.trainer.is_slurm_managing_tasks
)

# call setup after the ddp process has connected
self.trainer.call_setup_hook(model)

# on world_size=0 let everyone know training is starting
if self.trainer.is_global_zero:
log.info('-' * 100)
log.info(f'distributed_backend={self.trainer.distributed_backend}')
log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
log.info('-' * 100)

# call sync_bn before .cuda(), configure_apex and configure_ddp
if self.trainer.sync_batchnorm:
model = model.configure_sync_batchnorm(model)

# MODEL
# copy model to each gpu
if self.trainer.on_gpu:
gpu_idx = process_idx
self.trainer.root_gpu = gpu_idx
torch.cuda.set_device(self.trainer.root_gpu)
model.cuda(self.trainer.root_gpu)

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
self.trainer.optimizers = optimizers
self.trainer.lr_schedulers = lr_schedulers
self.trainer.optimizer_frequencies = optimizer_frequencies

# set model properties before going into wrapper
self.trainer.model_connector.copy_trainer_model_properties(model)

# AMP -
# run through amp wrapper before going to distributed DP
if self.trainer.amp_backend == AMPType.APEX:
model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
self.trainer.optimizers = optimizers
self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)

# DDP2 uses all GPUs on the machine
if self.trainer.distributed_backend == 'ddp' or self.trainer.distributed_backend == 'ddp_spawn':
device_ids = [self.trainer.root_gpu]
elif self.trainer.use_ddp2:
device_ids = self.trainer.data_parallel_device_ids
else: # includes ddp_cpu
device_ids = None

# allow user to configure ddp
model = model.configure_ddp(model, device_ids)

# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()

# get original model
model = self.trainer.get_model()

# persist info in ddp_spawn
self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)

# clean up memory
torch.cuda.empty_cache()
def set_world_ranks(self, process_idx):
self.trainer.local_rank = process_idx
self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

def model_to_device(self, model, process_idx):
gpu_idx = process_idx
self.trainer.root_gpu = gpu_idx
torch.cuda.set_device(self.trainer.root_gpu)
model.cuda(self.trainer.root_gpu)

def get_device_ids(self):
device_ids = [self.trainer.root_gpu]
return device_ids
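
(Not part of the diff.) The device_ids value produced by the two get_device_ids() implementations is what eventually reaches torch.nn.parallel.DistributedDataParallel inside configure_ddp. A standalone sketch of the difference, using a throwaway single-process group so it runs on its own:

import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

# throwaway single-process group purely to make the sketch runnable
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group('gloo', rank=0, world_size=1)

# CPU path (DDPCPUSpawnBackend): device_ids=None leaves the replica on the CPU
cpu_model = DistributedDataParallel(torch.nn.Linear(4, 2), device_ids=None)

# GPU path (DDPSpawnBackend): the replica is pinned to a single root GPU
if torch.cuda.is_available():
    root_gpu = 0
    gpu_model = DistributedDataParallel(
        torch.nn.Linear(4, 2).cuda(root_gpu), device_ids=[root_gpu]
    )

dist.destroy_process_group()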