From 35d111199437406a5ee9b6fa4be0707b50f283aa Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sat, 3 Oct 2020 14:05:31 -0400
Subject: [PATCH] [WIP] ref: decoupled ddp, ddp spawn (finish 3733) (#3819)

* ref: finish #3733

* remove deprecated test

* remove deprecated test

* remove deprecated test

* remove deprecated test

* remove deprecated test

* remove deprecated test

* remove deprecated test

* remove deprecated test

* remove deprecated test

* remove deprecated test

* remove deprecated test

* remove deprecated test

* remove deprecated test

* Update pytorch_lightning/accelerators/ddp_backend.py

Co-authored-by: ananthsub

* remove deprecated test

* remove deprecated test

* remove deprecated test

Co-authored-by: ananthsub
---
 .../accelerators/accelerator_connector.py     |   5 +
 pytorch_lightning/accelerators/ddp_backend.py | 101 +++++++++---------
 tests/backends/test_ddp.py                    |  36 +++----
 3 files changed, 71 insertions(+), 71 deletions(-)

diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py
index ca0cf3fb91bb1..cd6d16e30aa61 100644
--- a/pytorch_lightning/accelerators/accelerator_connector.py
+++ b/pytorch_lightning/accelerators/accelerator_connector.py
@@ -137,6 +137,11 @@ def select_accelerator(self):
         use_ddp_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == "ddp_spawn"
         use_ddp_cpu_spawn = self.trainer.use_ddp and self.trainer.distributed_backend == "ddp_cpu"
 
+        # ddp script mode uses the same flags as TE
+        # TODO: decouple from TE
+        if os.environ.get('PL_DDP_PID', False):
+            use_torchelastic_ddp = False
+
         # choose the appropriate accelerator backend
         if self.trainer.use_ddp2:
             accelerator_backend = accelerators.DDP2Backend(self.trainer)
diff --git a/pytorch_lightning/accelerators/ddp_backend.py b/pytorch_lightning/accelerators/ddp_backend.py
index bedea6f058d18..f71b183d71a00 100644
--- a/pytorch_lightning/accelerators/ddp_backend.py
+++ b/pytorch_lightning/accelerators/ddp_backend.py
@@ -12,22 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 import os
+import torch
+import torch.distributed as torch_distrib
 import subprocess
 import sys
 from os.path import abspath
 from time import sleep
 from typing import Optional
-
 import numpy as np
-import torch
-import torch.distributed as torch_distrib
-import torch.distributed as dist
+
+from pytorch_lightning import _logger as log
 from pytorch_lightning.utilities.distributed import find_free_network_port
 from pytorch_lightning.accelerators.base_backend import Accelerator
-from pytorch_lightning import _logger as log
-from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.distributed import rank_zero_only
+from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.seed import seed_everything
 from pytorch_lightning.distributed.dist import LightningDistributed
 
@@ -47,6 +46,7 @@ def __init__(self, trainer):
         super().__init__(trainer)
         self.task_idx = None
         self._has_spawned_children = False
+        self.interactive_ddp_procs = []
         self.dist = LightningDistributed()
 
     def setup(self, model):
@@ -54,10 +54,13 @@ def setup(self, model):
         self.trainer.model = model
 
         # start the other scripts
-        self._call_children_scripts()
+        if os.environ.get('PL_IN_DDP_SUBPROCESS', '0') != '1':
+            self._call_children_scripts()
 
-    def _call_children_scripts(self):
+        # set the task idx
+        self.task_idx = int(os.environ['PL_DDP_PID'])
 
+    def _call_children_scripts(self):
         assert self.trainer.global_rank == 0
         self._check_can_spawn_children()
         self._has_spawned_children = True
@@ -104,11 +107,12 @@ def _call_children_scripts(self):
         os.environ['WORLD_SIZE'] = f'{num_gpus * self.trainer.num_nodes}'
 
-        self.trainer.interactive_ddp_procs = []
+        self.interactive_ddp_procs = []
         for local_rank in range(1, self.trainer.num_processes):
             env_copy = os.environ.copy()
             env_copy['LOCAL_RANK'] = f'{local_rank}'
             env_copy['PL_DDP_PID'] = str(self.trainer.data_parallel_device_ids[local_rank])
+            env_copy['PL_GLOBAL_SEED'] = os.environ.get('PL_GLOBAL_SEED')
 
             # start process
             # if hydra is available and initialized, make sure to set the cwd correctly
@@ -117,53 +121,22 @@ def _call_children_scripts(self):
             if HydraConfig.initialized():
                 cwd = get_original_cwd()
             proc = subprocess.Popen(command, env=env_copy, cwd=cwd)
-            self.trainer.interactive_ddp_procs.append(proc)
+            self.interactive_ddp_procs.append(proc)
 
             # starting all processes at once can cause issues
             # with dataloaders delay between 1-10 seconds
             delay = np.random.uniform(1, 5, 1)[0]
             sleep(delay)
 
-        self.task_idx = 0
+        os.environ['PL_DDP_PID'] = str(0)
 
     def train(self):
         model = self.trainer.model
-        results = self.ddp_train(process_idx=self.task_idx, model=model, is_master=True)
-        del os.environ['WORLD_SIZE']
-        return results
-
-    def _check_can_spawn_children(self):
-        if self._has_spawned_children:
-            raise RuntimeError(
-                "You tried to run `.fit` or `.test` multiple times in the same script."
-                " This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead."
-            )
-
-    def set_world_ranks(self, process_idx):
-        self.trainer.local_rank = process_idx
-        self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
-        self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes
-
-    def model_to_device(self, model, process_idx, is_master):
-        gpu_idx = process_idx
-
-        # when using ddp, the master process (proc 0) continues running as the main one
-        # this means that the local rank will always be 0
-        # (even if cuda visible devices has other visible gpus)
-        # this means that the master process needs to pull the 0th visible index as the device number
-        if is_master:
-            available_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
-            gpu_idx = int(available_gpus[self.trainer.local_rank])
-
-        gpu_idx = int(os.environ.get('PL_DDP_PID', gpu_idx))
-        self.trainer.root_gpu = gpu_idx
-        torch.cuda.set_device(self.trainer.root_gpu)
-        model.cuda(self.trainer.root_gpu)
-
-    def get_device_ids(self):
-        device_ids = [self.trainer.root_gpu]
-        return device_ids
+        results = self.ddp_train(process_idx=self.task_idx, model=model)
+        if 'WORLD_SIZE' in os.environ:
+            del os.environ['WORLD_SIZE']
+        return results
 
     def training_step(self, args):
         if self.trainer.amp_backend == AMPType.NATIVE:
@@ -185,17 +158,41 @@ def barrier(self, name: str = None):
         if torch_distrib.is_initialized():
             torch_distrib.barrier()
 
+    def _check_can_spawn_children(self):
+        if self._has_spawned_children:
+            raise RuntimeError(
+                "You tried to run `.fit` or `.test` multiple times in the same script."
+                " This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead."
+            )
+
+    def set_world_ranks(self, process_idx):
+        self.trainer.local_rank = process_idx
+        self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
+        self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes
+
+    def model_to_device(self, model, process_idx):
+        self.trainer.root_gpu = process_idx
+        torch.cuda.set_device(self.trainer.root_gpu)
+        model.cuda(self.trainer.root_gpu)
+
+    def get_device_ids(self):
+        device_ids = [self.trainer.root_gpu]
+        return device_ids
+
+    def on_train_end(self):
+        pass
+
     def early_stopping_should_stop(self, pl_module):
         stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
-        dist.all_reduce(stop, op=dist.reduce_op.SUM)
-        dist.barrier()
+        torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM)
+        torch_distrib.barrier()
         should_stop = stop == self.trainer.world_size
         return should_stop
 
     def broadcast(self, obj, src=0):
         return self.dist.broadcast(obj)
 
-    def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):
+    def ddp_train(self, process_idx, model):
         """
         Entry point for ddp
 
@@ -211,9 +208,6 @@ def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):
         if seed is not None:
             seed_everything(int(seed))
 
-        # offset the process id if requested
-        process_idx = process_idx + proc_offset
-
         # show progressbar only on progress_rank 0
         if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
             self.trainer.progress_bar_callback.disable()
@@ -249,7 +243,7 @@ def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):
             model = model.configure_sync_batchnorm(model)
 
         # move the model to the correct device
-        self.model_to_device(model, process_idx, is_master)
+        self.model_to_device(model, process_idx)
 
         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
@@ -268,6 +262,7 @@ def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):
         model = model.configure_ddp(model, device_ids)
 
         # set up training routine
+        self.barrier('ddp_setup')
         self.trainer.train_loop.setup_training(model)
 
         # train or test
diff --git a/tests/backends/test_ddp.py b/tests/backends/test_ddp.py
index 91f22c4d7c59d..08f9a7e109808 100644
--- a/tests/backends/test_ddp.py
+++ b/tests/backends/test_ddp.py
@@ -37,21 +37,21 @@ def test_multi_gpu_model_ddp_test_only(tmpdir, cli_args):
     assert result['status'] == 'complete'
 
 
-# @pytest.mark.parametrize('cli_args', [
-#     pytest.param('--max_epochs 1 --gpus 2 --distributed_backend ddp'),
-# ])
-# @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-# def test_multi_gpu_model_ddp_fit_test(tmpdir, cli_args):
-#     # call the script
-#     call_training_script(ddp_model, cli_args, 'fit_test', tmpdir, timeout=20)
-#
-#     # load the results of the script
-#     result_path = os.path.join(tmpdir, 'ddp.result')
-#     result = torch.load(result_path)
-#
-#     # verify the file wrote the expected outputs
-#     assert result['status'] == 'complete'
-#
-#     model_outs = result['result']
-#     for out in model_outs:
-#         assert out['test_acc'] > 0.90
+@pytest.mark.parametrize('cli_args', [
+    pytest.param('--max_epochs 1 --gpus 2 --distributed_backend ddp'),
+])
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+def test_multi_gpu_model_ddp_fit_test(tmpdir, cli_args):
+    # call the script
+    call_training_script(ddp_model, cli_args, 'fit_test', tmpdir, timeout=20)
+
+    # load the results of the script
+    result_path = os.path.join(tmpdir, 'ddp.result')
+    result = torch.load(result_path)
+
+    # verify the file wrote the expected outputs
+    assert result['status'] == 'complete'
+
+    model_outs = result['result']
+    for out in model_outs:
+        assert out['test_acc'] > 0.90