Replaces ddp .spawn with subprocess #2029
First changed file:
@@ -117,6 +117,11 @@ def train_fx(trial_hparams, cluster_manager, _):
 import re
 from abc import ABC, abstractmethod
 from typing import Union
+import subprocess
+import sys
+from time import sleep
+import numpy as np
+from os.path import abspath

 import torch
 from pytorch_lightning import _logger as log
@@ -311,7 +316,7 @@ def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):
         os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

         # when slurm is managing the task it sets the visible devices
-        if not is_slurm_managing_tasks:
+        if not is_slurm_managing_tasks and 'CUDA_VISIBLE_DEVICES' not in os.environ:
             if isinstance(data_parallel_device_ids, int):
                 id_str = ','.join(str(x) for x in list(range(data_parallel_device_ids)))
                 os.environ["CUDA_VISIBLE_DEVICES"] = id_str
@@ -322,7 +327,74 @@ def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):
         # don't make this debug... this is good UX
         log.info(f'CUDA_VISIBLE_DEVICES: [{os.environ["CUDA_VISIBLE_DEVICES"]}]')

-    def ddp_train(self, process_idx, model):
+    def __set_random_port(self):
+        """
+        When running DDP NOT managed by SLURM, the ports might collide
+        :return:
+        """
+        try:
+            default_port = os.environ['MASTER_PORT']
+        except Exception:
+            import random
+            default_port = random.randint(10000, 19000)
+            os.environ['MASTER_PORT'] = str(default_port)
+
+    def spawn_ddp_children(self, model):
+        self.__set_random_port()
+        port = os.environ['MASTER_PORT']
+
+        master_address = '127.0.0.1' if 'MASTER_ADDR' not in os.environ else os.environ['MASTER_ADDR']
+        os.environ['MASTER_PORT'] = f'{port}'
+        os.environ['MASTER_ADDR'] = f'{master_address}'
+
+        # allow the user to pass the node rank
+        node_rank = '0'
+        if 'NODE_RANK' in os.environ:
+            node_rank = os.environ['NODE_RANK']
+        if 'GROUP_RANK' in os.environ:
+            node_rank = os.environ['GROUP_RANK']
+
+        os.environ['NODE_RANK'] = node_rank
+        os.environ['LOCAL_RANK'] = '0'
+
+        # pull out the commands used to run the script and resolve the abs file path
+        command = sys.argv
+        full_path = abspath(command[0])
+        command[0] = full_path
+        command = ['python'] + command
+
+        # since this script sets the visible devices we replace the gpus flag with a number
+        num_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',').__len__()
+
+        # if script called without a flag, pass in a flag anyhow
+        if '--gpus' not in command:
+            arg_gpus = len(self.gpus) if isinstance(self.gpus, list) else self.gpus
+            command += ['--gpus', arg_gpus]
+
+        gpu_flag_idx = command.index('--gpus')
+        command[gpu_flag_idx + 1] = f'{num_gpus}'
+
+        os.environ['WORLD_SIZE'] = f'{num_gpus * self.num_nodes}'
+
+        self.interactive_ddp_procs = []
+        for local_rank in range(1, self.num_processes):
+            env_copy = os.environ.copy()
+            env_copy['LOCAL_RANK'] = f'{local_rank}'
+
+            # import pdb; pdb.set_trace()
+            # start process
+            proc = subprocess.Popen(command, env=env_copy)
+            self.interactive_ddp_procs.append(proc)
+
+            # starting all processes at once can cause issues
+            # with dataloaders delay between 1-10 seconds
+            delay = np.random.uniform(1, 5, 1)[0]
+            sleep(delay)
+
+        local_rank = 0
+        self.ddp_train(local_rank, model, is_master=True)
+
+    def ddp_train(self, process_idx, model, is_master=False):
         """
         Entry point into a DP thread
         :param gpu_idx:

Review comment (inline, on the `subprocess.Popen` call): starts the process using env flags and python command.
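To make the launch scheme easier to follow, here is a minimal standalone sketch of the same idea (the function name `launch_ddp_children` and the port value are illustrative, not from this PR): the parent re-invokes its own command line once per extra process, and only `LOCAL_RANK` differs between the children.

```python
import os
import subprocess
import sys
from os.path import abspath


def launch_ddp_children(num_processes: int):
    """Sketch of a subprocess-based DDP launch; assumes the script was started
    as e.g. `python train.py --gpus 2` and CUDA_VISIBLE_DEVICES is already set."""
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '12910')  # any free port
    os.environ['LOCAL_RANK'] = '0'                 # the parent acts as rank 0

    # re-run the exact same command line, but with an absolute script path
    command = ['python', abspath(sys.argv[0])] + sys.argv[1:]

    children = []
    for local_rank in range(1, num_processes):
        env = os.environ.copy()
        env['LOCAL_RANK'] = str(local_rank)        # only the rank differs per child
        children.append(subprocess.Popen(command, env=env))
    return children
```

Each child re-enters the script from the top and reaches `ddp_train` with its own local rank, while the parent continues in-process as the master (rank 0).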
@@ -359,7 +431,14 @@ def ddp_train(self, process_idx, model):
         # MODEL
         # copy model to each gpu
         if self.on_gpu:
-            self.root_gpu = process_idx
+            gpu_idx = process_idx
+            if is_master:
+                # source of truth is cuda for gpu idx
+                gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
+                local_rank = int(os.environ['LOCAL_RANK'])
+                gpu_idx = int(gpus[local_rank])
+
+            self.root_gpu = gpu_idx
             torch.cuda.set_device(self.root_gpu)
             model.cuda(self.root_gpu)
@@ -388,9 +467,6 @@ def ddp_train(self, process_idx, model):
         # continue training routine
         self.run_pretrain_routine(model)

-        # when ddp ends, we save the model
-        self.save_spawn_weights(model)
-
     def save_spawn_weights(self, model):
         """
         Dump a temporary checkpoint after ddp ends to get weights out of the process
Second changed file:
@@ -35,7 +35,6 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities import rank_zero_warn, parsing
-

 try:
     from apex import amp
 except ImportError:
@@ -119,7 +118,7 @@ def __init__(
         distributed_backend: Optional[str] = None,
         precision: int = 32,
         print_nan_grads: bool = False,  # backward compatible, todo: remove in v0.9.0
-        weights_summary: Optional[str] = 'full',
+        weights_summary: Optional[str] = 'top',
         weights_save_path: Optional[str] = None,
         num_sanity_val_steps: int = 2,
         truncated_bptt_steps: Optional[int] = None,

Review comments on lines -122 to +121:
This change is not reflected in the docs.
Mind shooting a fix PR?
Yes, unless it was not intentional. @williamFalcon, yes or no?
Sorry, forgot to update the docs on this.
Added it here: #2021, hope it's fine including it there.
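Since the thread above is about the changed default, here is a minimal usage sketch for keeping the previous behaviour (assuming only that `weights_summary` accepts the two values shown in the diff):

```python
from pytorch_lightning import Trainer

# The default changed from 'full' to 'top' in this PR; pass the old value
# explicitly to keep the full per-module summary.
trainer = Trainer(weights_summary='full')
```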
@@ -494,6 +493,7 @@ def __init__(
         # init flags for SLURM+ddp to work
         self.proc_rank = 0
         self.world_size = 1
+        self.interactive_ddp_procs = []
         self.configure_slurm_ddp(self.num_nodes)
         self.node_rank = self.determine_ddp_node_rank()
@@ -871,16 +871,12 @@ def fit(
                 task = int(os.environ['LOCAL_RANK'])
                 self.ddp_train(task, model)

-            else:
-                self.__set_random_port()
-                # track for predict
+            elif self.distributed_backend == 'cpu_ddp':
                 self.model = model
-                # train
                 mp.spawn(self.ddp_train, nprocs=self.num_processes, args=(model,))
-                # load weights if not interrupted
-                if self.on_colab_kaggle:
-                    self.load_spawn_weights(model)
-                self.model = model
+
+            elif self.distributed_backend == 'ddp':
+                self.spawn_ddp_children(model)

         # 1 gpu or dp option triggers training using DP module
         # easier to avoid NCCL issues
@@ -928,18 +924,6 @@ def fit(
             # used for testing or when we need to know that training succeeded
             return 1

-    def __set_random_port(self):
-        """
-        When running DDP NOT managed by SLURM, the ports might collide
-        :return:
-        """
-        try:
-            default_port = os.environ['MASTER_PORT']
-        except Exception:
-            import random
-            default_port = random.randint(10000, 19000)
-            os.environ['MASTER_PORT'] = str(default_port)
-
     def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None):
         # when dataloader is passed via fit, patch the train_dataloader
         # functions to overwrite with these implementations
@@ -1046,7 +1030,10 @@ def run_pretrain_routine(self, model: LightningModule):

         # clear cache before training
         if self.on_gpu:
-            torch.cuda.empty_cache()
+            # use context because of:
+            # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
+            with torch.cuda.device(f'cuda:{self.root_gpu}'):
+                torch.cuda.empty_cache()

         # CORE TRAINING LOOP
         self.train()
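As a standalone illustration of the cache-clearing pattern above (a sketch; `root_gpu = 0` is an assumption for this example, whereas the trainer uses `self.root_gpu`):

```python
import torch

if torch.cuda.is_available():
    root_gpu = 0  # assumed device index for this sketch
    # Binding empty_cache() to an explicit device context keeps the call on this
    # process's own GPU, per the forum thread linked in the diff above.
    with torch.cuda.device(f'cuda:{root_gpu}'):
        torch.cuda.empty_cache()
```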
@@ -1096,7 +1083,10 @@ def test(
         if model is not None:
             self.model = model
             self.fit(model)
-        elif self.use_ddp or self.use_tpu:  # pragma: no-cover
+
+        # on tpu, .spawn means we don't have a trained model
+        # TODO: remove TPU spawn
+        elif self.use_tpu:  # pragma: no-cover
             # attempt to load weights from a spawn
             path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt')
             test_model = self.model
Third changed file:
@@ -158,6 +158,7 @@ def training_step(self, batch, batch_idx):
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+import subprocess

 try:
     from apex import amp
@@ -305,13 +306,13 @@ def has_arg(self, *args):

     def train(self):
         # add signal handlers for process kills
-        def _signal_kill_handler(*args):
-            return TrainerTrainLoopMixin.run_training_teardown(self)
-
-        orig_signal_handlers = {}
-        for sig_name in SIGNAL_TERMINATE:
-            orig_signal_handlers[sig_name] = signal.signal(getattr(signal, sig_name),
-                                                           _signal_kill_handler)
+        # def _signal_kill_handler(*args):
+        #     return TrainerTrainLoopMixin.run_training_teardown(self)
+        #
+        # orig_signal_handlers = {}
+        # for sig_name in SIGNAL_TERMINATE:
+        #     orig_signal_handlers[sig_name] = signal.signal(getattr(signal, sig_name),
+        #                                                    _signal_kill_handler)

         # get model
         model = self.get_model()

Review comments on the disabled signal handlers:
cc: @justusschock
As mentioned @justusschock, I disabled this since it was a blocker.
I'm fine with temporarily disabling this. In our CI these exceptions did not appear.
@@ -384,15 +385,17 @@ def _signal_kill_handler(*args):

                 self.run_training_teardown()

-            # reset signal handlers
-            for sig_name in SIGNAL_TERMINATE:
-                signal.signal(getattr(signal, sig_name), orig_signal_handlers[sig_name])
-
         except KeyboardInterrupt:
-            if self.proc_rank == 0:
-                log.info('Detected KeyboardInterrupt, attempting graceful shutdown...')
-                self.interrupted = True
-                self.run_training_teardown()
+            rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')
+
+            # user could press ctrl+c many times... only shutdown once
+            if not self.interrupted:
+                self.interrupted = True
+
+                for proc in self.interactive_ddp_procs:
+                    subprocess.Popen.kill(proc)
+
+                self.run_training_teardown()

     def run_training_epoch(self):
@@ -678,7 +681,7 @@ def _get_optimizers_iterable(self):
         opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop)
         return [(opt_idx, self.optimizers[opt_idx])]

-    @atexit.register
+    # @atexit.register
     def run_training_teardown(self):
         if hasattr(self, '_teardown_already_run') and self._teardown_already_run:
             return
Review comments:
This assumes that it can be called only from the command line with Trainer-like arguments... but what about just a script with params loaded from a file? cc @PyTorchLightning/core-contributors
Can you give an example? These tests were done with CLI flags and also Trainer args; this local rank is coming from the gpus flag in the Trainer.
This also assumes that it is called as `(python) path/to/some/script.py` and not e.g. from a console entry point.
In fact, this part of the code is not tested: https://codecov.io/gh/PyTorchLightning/pytorch-lightning/commit/82a20296e308a67c8d9202e4cbdf92a44b90b077
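A minimal sketch of the entry-point concern (the `mytrainer` name and the paths are hypothetical, not from this PR): the relaunch command is rebuilt from `sys.argv`, which only reproduces the original invocation when the program was started as a plain Python script.

```python
import sys
from os.path import abspath

# Started as `python path/to/train.py --gpus 2`:
#   sys.argv[0] == 'path/to/train.py'
#   -> ['python', '/abs/path/to/train.py', '--gpus', '2'] relaunches the same program.
#
# Started via an installed console entry point, e.g. `mytrainer --gpus 2` (hypothetical):
#   sys.argv[0] == '/usr/local/bin/mytrainer'
#   -> prefixing it with 'python' may select a different interpreter, and on Windows the
#      generated wrapper is an .exe that `python` cannot execute at all.
relaunch_command = ['python', abspath(sys.argv[0])] + sys.argv[1:]
print(relaunch_command)
```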