Skip to content

Commit

Permalink
ref: decoupled ddp2 (#3816)
Browse files Browse the repository at this point in the history
  • Loading branch information
williamFalcon authored Oct 3, 2020
1 parent 6232063 commit 0838c6b
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 209 deletions.
124 changes: 108 additions & 16 deletions pytorch_lightning/accelerators/ddp2_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@
import os

import torch

import torch.distributed as torch_distrib
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.accelerators.ddp_base_backend import DDPBase
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.base_backend import Accelerator
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only

try:
from hydra.utils import to_absolute_path, get_original_cwd
Expand All @@ -29,15 +33,15 @@
HYDRA_AVAILABLE = True


class DDP2Backend(DDPBase):
class DDP2Backend(Accelerator):

def __init__(self, trainer):
super().__init__(trainer)
self.task_idx = None
self.dist = LightningDistributed()

def setup(self, model):
self._resolve_task_idx()

self.trainer.model = model

def _resolve_task_idx(self):
Expand All @@ -53,7 +57,27 @@ def _resolve_task_idx(self):

def train(self):
model = self.trainer.model
self.ddp_train_tmp(process_idx=self.task_idx, mp_queue=None, model=model)
return self.ddp_train(process_idx=self.task_idx, mp_queue=None, model=model)

def training_step(self, args):
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
output = self.trainer.model(*args)
else:
output = self.trainer.model(*args)
return output

def validation_step(self, args):
output = self.training_step(args)
return output

def test_step(self, args):
output = self.training_step(args)
return output

def barrier(self, name: str = None):
if torch_distrib.is_initialized():
torch_distrib.barrier()

def training_step_end(self, output):
if isinstance(output, Result):
Expand All @@ -75,21 +99,89 @@ def set_world_ranks(self, process_idx):
self.trainer.global_rank = self.trainer.node_rank
self.trainer.world_size = self.trainer.num_nodes

def model_to_device(self, model, process_idx, is_master):
gpu_idx = process_idx
def broadcast(self, obj, src=0):
return self.dist.broadcast(obj)

# when using ddp, the master process (proc 0) continues running as the main one
# this means that the local rank will always be 0
# (even if cuda visible devices has other visible gpus)
# this means that the master process needs to pull the 0th visible index as the device number
if is_master:
available_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
gpu_idx = int(available_gpus[self.trainer.local_rank])

self.trainer.root_gpu = gpu_idx
def model_to_device(self, model, process_idx):
self.trainer.root_gpu = process_idx
torch.cuda.set_device(self.trainer.root_gpu)
model.cuda(self.trainer.root_gpu)

def get_device_ids(self):
device_ids = self.trainer.data_parallel_device_ids
return device_ids

def ddp_train(self, process_idx, mp_queue, model):
"""
Entry point for ddp
Args:
process_idx:
mp_queue: multiprocessing queue
model:
Returns:
"""
# show progressbar only on progress_rank 0
if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
self.trainer.progress_bar_callback.disable()

# determine which process we are and world size
self.set_world_ranks(process_idx)

# set warning rank
rank_zero_only.rank = self.trainer.global_rank

# set up server using proc 0's ip address
# try to init for 20 times at max in case ports are taken
# where to store ip_table
model.trainer = self.trainer
model.init_ddp_connection(
self.trainer.global_rank,
self.trainer.world_size,
self.trainer.is_slurm_managing_tasks
)

# call setup after the ddp process has connected
self.trainer.call_setup_hook(model)

# on world_size=0 let everyone know training is starting
if self.trainer.is_global_zero and not torch.distributed.is_initialized():
log.info('-' * 100)
log.info(f'distributed_backend={self.trainer.distributed_backend}')
log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
log.info('-' * 100)

# call sync_bn before .cuda(), configure_apex and configure_ddp
if self.trainer.sync_batchnorm:
model = model.configure_sync_batchnorm(model)

# move the model to the correct device
self.model_to_device(model, process_idx)

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.setup_optimizers(model)

# set model properties before going into wrapper
self.trainer.model_connector.copy_trainer_model_properties(model)

# 16-bit
model = self.trainer.precision_connector.connect(model)

# device ids change depending on the DDP setup
device_ids = self.get_device_ids()

# allow user to configure ddp
model = model.configure_ddp(model, device_ids)

# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()

# clean up memory
torch.cuda.empty_cache()
return results
193 changes: 0 additions & 193 deletions pytorch_lightning/accelerators/ddp_base_backend.py

This file was deleted.

0 comments on commit 0838c6b

Please sign in to comment.