[WIP] ref: decoupled ddp, ddp spawn (finish 3733) (#3819)
* ref: finish #3733

* remove deprecated test (×13)

* Update pytorch_lightning/accelerators/ddp_backend.py

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>

* remove deprecated test (×3)

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
williamFalcon and ananthsub authored Oct 3, 2020
1 parent 3903cf6 commit 35d1111
Showing 3 changed files with 71 additions and 71 deletions.
5 changes: 5 additions & 0 deletions pytorch_lightning/accelerators/accelerator_connector.py
@@ -137,6 +137,11 @@ def select_accelerator(self):
use_ddp_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == "ddp_spawn"
use_ddp_cpu_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == "ddp_cpu"

# ddp script mode uses the same flags as TE
# TODO: decouple from TE
if os.environ.get('PL_DDP_PID', False):
use_torchelastic_ddp = False

# choose the appropriate accelerator backend
if self.trainer.use_ddp2:
accelerator_backend = accelerators.DDP2Backend(self.trainer)
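The guard added above exists because ddp script mode exports the same rendezvous flags as torchelastic (TE), so a child launched by `_call_children_scripts` would otherwise be misdetected as a torchelastic worker. A minimal sketch of just that decision, with the surrounding `select_accelerator` reduced to this one choice (the helper name and return values are illustrative, not from the diff):

```python
import os

def resolve_ddp_flavor(use_ddp: bool, use_torchelastic_ddp: bool) -> str:
    # children launched in ddp script mode carry PL_DDP_PID, which is the
    # only reliable way to tell them apart from a real torchelastic launch
    if os.environ.get('PL_DDP_PID', False):
        use_torchelastic_ddp = False
    if use_torchelastic_ddp:
        return 'torchelastic_ddp'
    return 'ddp' if use_ddp else 'other'
```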
101 changes: 48 additions & 53 deletions pytorch_lightning/accelerators/ddp_backend.py
@@ -12,22 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License
import os
import torch
import torch.distributed as torch_distrib
import subprocess
import sys
from os.path import abspath
from time import sleep
from typing import Optional

import numpy as np
import torch
import torch.distributed as torch_distrib
import torch.distributed as dist


from pytorch_lightning import _logger as log
from pytorch_lightning.utilities.distributed import find_free_network_port
from pytorch_lightning.accelerators.base_backend import Accelerator
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.distributed.dist import LightningDistributed

@@ -47,17 +46,21 @@ def __init__(self, trainer):
super().__init__(trainer)
self.task_idx = None
self._has_spawned_children = False
self.interactive_ddp_procs = []
self.dist = LightningDistributed()

def setup(self, model):
# first track model
self.trainer.model = model

# start the other scripts
self._call_children_scripts()
if os.environ.get('PL_IN_DDP_SUBPROCESS', '0') != '1':
self._call_children_scripts()

def _call_children_scripts(self):
# set the task idx
self.task_idx = int(os.environ['PL_DDP_PID'])

def _call_children_scripts(self):
assert self.trainer.global_rank == 0
self._check_can_spawn_children()
self._has_spawned_children = True
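The reworked `setup` above is the heart of the decoupling: only the original process, which has no `PL_IN_DDP_SUBPROCESS` marker in its environment, launches the sibling scripts, and every process (parent included, since `_call_children_scripts` ends by exporting `PL_DDP_PID=0`) then reads its task index from `PL_DDP_PID`. A hedged sketch of that control flow, with the launcher passed in as a stand-in for `_call_children_scripts`:

```python
import os

def resolve_task_idx(launch_children) -> int:
    # parent: not yet marked as a subprocess, so spawn the other ranks;
    # launch_children stands in for _call_children_scripts, which exports
    # PL_DDP_PID=0 for the parent before returning
    if os.environ.get('PL_IN_DDP_SUBPROCESS', '0') != '1':
        launch_children()
    # parent and children alike recover their device slot from PL_DDP_PID
    return int(os.environ['PL_DDP_PID'])
```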
@@ -104,11 +107,12 @@ def _call_children_scripts(self):

os.environ['WORLD_SIZE'] = f'{num_gpus * self.trainer.num_nodes}'

self.trainer.interactive_ddp_procs = []
self.interactive_ddp_procs = []
for local_rank in range(1, self.trainer.num_processes):
env_copy = os.environ.copy()
env_copy['LOCAL_RANK'] = f'{local_rank}'
env_copy['PL_DDP_PID'] = str(self.trainer.data_parallel_device_ids[local_rank])
env_copy['PL_GLOBAL_SEED'] = os.environ.get('PL_GLOBAL_SEED')

# start process
# if hydra is available and initialized, make sure to set the cwd correctly
@@ -117,53 +121,22 @@ def _call_children_scripts(self):
if HydraConfig.initialized():
cwd = get_original_cwd()
proc = subprocess.Popen(command, env=env_copy, cwd=cwd)
self.trainer.interactive_ddp_procs.append(proc)
self.interactive_ddp_procs.append(proc)

# starting all processes at once can cause issues
# with dataloaders delay between 1-10 seconds
delay = np.random.uniform(1, 5, 1)[0]
sleep(delay)

self.task_idx = 0
os.environ['PL_DDP_PID'] = str(0)

def train(self):
model = self.trainer.model
results = self.ddp_train(process_idx=self.task_idx, model=model, is_master=True)
del os.environ['WORLD_SIZE']
return results

def _check_can_spawn_children(self):
if self._has_spawned_children:
raise RuntimeError(
"You tried to run `.fit` or `.test` multiple times in the same script."
" This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead."
)

def set_world_ranks(self, process_idx):
self.trainer.local_rank = process_idx
self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

def model_to_device(self, model, process_idx, is_master):
gpu_idx = process_idx

# when using ddp, the master process (proc 0) continues running as the main one
# this means that the local rank will always be 0
# (even if cuda visible devices has other visible gpus)
# this means that the master process needs to pull the 0th visible index as the device number
if is_master:
available_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
gpu_idx = int(available_gpus[self.trainer.local_rank])

gpu_idx = int(os.environ.get('PL_DDP_PID', gpu_idx))

self.trainer.root_gpu = gpu_idx
torch.cuda.set_device(self.trainer.root_gpu)
model.cuda(self.trainer.root_gpu)

def get_device_ids(self):
device_ids = [self.trainer.root_gpu]
return device_ids
results = self.ddp_train(process_idx=self.task_idx, model=model)
if 'WORLD_SIZE' in os.environ:
del os.environ['WORLD_SIZE']
return results

def training_step(self, args):
if self.trainer.amp_backend == AMPType.NATIVE:
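The launch loop in `_call_children_scripts` above re-invokes the training script once per remaining local rank, with each child's identity carried entirely in its environment. A self-contained sketch of that pattern (the `PL_IN_DDP_SUBPROCESS` export is assumed to happen in a collapsed part of the hunk, and the staggered start mirrors the 1–5 s random delay above):

```python
import os
import subprocess
import sys
from time import sleep

import numpy as np

def launch_child_ranks(num_processes: int, device_ids: list) -> list:
    command = [sys.executable] + sys.argv  # assumption: re-run this same script
    procs = []
    for local_rank in range(1, num_processes):
        env_copy = os.environ.copy()
        env_copy['LOCAL_RANK'] = f'{local_rank}'
        env_copy['PL_DDP_PID'] = str(device_ids[local_rank])
        # routes the child down the non-spawning path in setup()
        env_copy['PL_IN_DDP_SUBPROCESS'] = '1'
        procs.append(subprocess.Popen(command, env=env_copy))
        # starting all ranks at once can race on dataloaders, so stagger them
        sleep(np.random.uniform(1, 5))
    return procs
```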
@@ -185,17 +158,41 @@ def barrier(self, name: str = None):
if torch_distrib.is_initialized():
torch_distrib.barrier()

def _check_can_spawn_children(self):
if self._has_spawned_children:
raise RuntimeError(
"You tried to run `.fit` or `.test` multiple times in the same script."
" This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead."
)

def set_world_ranks(self, process_idx):
self.trainer.local_rank = process_idx
self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

def model_to_device(self, model, process_idx):
self.trainer.root_gpu = process_idx
torch.cuda.set_device(self.trainer.root_gpu)
model.cuda(self.trainer.root_gpu)

def get_device_ids(self):
device_ids = [self.trainer.root_gpu]
return device_ids

def on_train_end(self):
pass

def early_stopping_should_stop(self, pl_module):
stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
dist.all_reduce(stop, op=dist.reduce_op.SUM)
dist.barrier()
torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM)
torch_distrib.barrier()
should_stop = stop == self.trainer.world_size
return should_stop

def broadcast(self, obj, src=0):
return self.dist.broadcast(obj)
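`early_stopping_should_stop` above reaches consensus with a plain `all_reduce`: each rank contributes 0 or 1 and training stops only when the sum equals the world size, i.e. every rank voted to stop. A standalone sketch of that vote, assuming an already-initialized process group (`reduce_op` matches the API spelling used in the diff; newer PyTorch spells it `ReduceOp`):

```python
import torch
import torch.distributed as torch_distrib

def all_ranks_agree_to_stop(should_stop: bool, world_size: int, device) -> bool:
    stop = torch.tensor(int(should_stop), device=device)
    # SUM of the per-rank 0/1 votes; equals world_size only on unanimity
    torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM)
    torch_distrib.barrier()
    return int(stop.item()) == world_size
```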

def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):
def ddp_train(self, process_idx, model):
"""
Entry point for ddp
@@ -211,9 +208,6 @@ def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):
if seed is not None:
seed_everything(int(seed))

# offset the process id if requested
process_idx = process_idx + proc_offset

# show progressbar only on progress_rank 0
if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
self.trainer.progress_bar_callback.disable()
@@ -249,7 +243,7 @@ def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):
model = model.configure_sync_batchnorm(model)

# move the model to the correct device
self.model_to_device(model, process_idx, is_master)
self.model_to_device(model, process_idx)

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
@@ -268,6 +262,7 @@ def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):
model = model.configure_ddp(model, device_ids)

# set up training routine
self.barrier('ddp_setup')
self.trainer.train_loop.setup_training(model)

# train or test
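One detail of the hunks above worth isolating is seed propagation: the parent forwards `PL_GLOBAL_SEED` into each child's environment, and `ddp_train` re-seeds every rank from it, so all processes start from identical RNG state. A hedged sketch of both halves (the `None` guard on the parent side is a defensive assumption, since `subprocess` requires string-valued environments, whereas the diff assigns `os.environ.get(...)` unconditionally):

```python
import os

from pytorch_lightning.utilities.seed import seed_everything

def forward_seed(env_copy: dict) -> None:
    # parent side: copy the seed into the child env only when it is set
    seed = os.environ.get('PL_GLOBAL_SEED')
    if seed is not None:
        env_copy['PL_GLOBAL_SEED'] = seed

def reseed_rank() -> None:
    # child side: mirrors the re-seeding at the top of ddp_train
    seed = os.environ.get('PL_GLOBAL_SEED')
    if seed is not None:
        seed_everything(int(seed))
```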
36 changes: 18 additions & 18 deletions tests/backends/test_ddp.py
@@ -37,21 +37,21 @@ def test_multi_gpu_model_ddp_test_only(tmpdir, cli_args):
assert result['status'] == 'complete'


# @pytest.mark.parametrize('cli_args', [
# pytest.param('--max_epochs 1 --gpus 2 --distributed_backend ddp'),
# ])
# @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
# def test_multi_gpu_model_ddp_fit_test(tmpdir, cli_args):
# # call the script
# call_training_script(ddp_model, cli_args, 'fit_test', tmpdir, timeout=20)
#
# # load the results of the script
# result_path = os.path.join(tmpdir, 'ddp.result')
# result = torch.load(result_path)
#
# # verify the file wrote the expected outputs
# assert result['status'] == 'complete'
#
# model_outs = result['result']
# for out in model_outs:
# assert out['test_acc'] > 0.90
@pytest.mark.parametrize('cli_args', [
pytest.param('--max_epochs 1 --gpus 2 --distributed_backend ddp'),
])
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_multi_gpu_model_ddp_fit_test(tmpdir, cli_args):
# call the script
call_training_script(ddp_model, cli_args, 'fit_test', tmpdir, timeout=20)

# load the results of the script
result_path = os.path.join(tmpdir, 'ddp.result')
result = torch.load(result_path)

# verify the file wrote the expected outputs
assert result['status'] == 'complete'

model_outs = result['result']
for out in model_outs:
assert out['test_acc'] > 0.90
