-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] ref: decoupled ddp, ddp spawn #3733
Changes from all commits
767b8ab
f746018
c3529ee
3497f0d
c4a9dc0
09bf2a6
81d7a0d
54a7402
2960aa2
417242c
b751f3a
e40a7c2
78bf07b
07efc8e
1276a51
d4b9f37
3041561
61ab801
7eeaa64
416a96d
b4454ee
f151c21
4278731
2e9c537
6f6f4fa
dab971d
95aaca6
b46874c
424a6db
35d01e4
f6e0bbe
a0542ae
64a486c
d124a94
3fa5ad2
2e49563
8acddd7
50a9c8b
5fc4912
2070075
f0c06bd
08b0cad
8a8a0bf
ed675ef
336bb47
c3f299a
e4cb76d
94ef3b9
357d640
e49c8a1
91736e2
15e5be0
b37d948
51370ce
23032ea
9f8705a
0f13e61
7ccabd8
9171464
1d4aeaa
b96d7c1
85050a3
506b037
63f5d50
01dd4c5
a0f52d7
650903a
cbd89f7
8ebd4ed
1f19c2f
ea448bb
fbeec9e
7663c6b
9421dbb
cf08480
f0c3cc5
459a0fa
64484a1
10bae5b
667c434
5b412e0
d9fc538
b2e941c
5ac3e59
3650f86
da582ab
471b576
545bf01
7b72cd6
1fbc1ca
c5c9faf
701f233
4a7368a
7169107
27e5870
455a488
6c3732c
73f0ef3
e36e20f
2f93660
1fb466c
202e82e
c8bd6ee
d4d8551
5acef3e
288fd23
0dcdd81
581e929
fe53c9a
c644f66
2a10f59
beacd6a
7e98763
661cfb0
69235e9
c958ec7
6088c48
f86ab63
2c2755c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -15,15 +15,15 @@ | |||||
import re | ||||||
|
||||||
import torch | ||||||
import torch.multiprocessing as mp | ||||||
import torch.distributed as torch_distrib | ||||||
import torch.multiprocessing as mp | ||||||
import torch.distributed as dist | ||||||
|
||||||
from pytorch_lightning import _logger as log | ||||||
from pytorch_lightning.accelerators.base_backend import Accelerator | ||||||
from pytorch_lightning.utilities import AMPType | ||||||
from pytorch_lightning.utilities.cloud_io import atomic_save, load as pl_load | ||||||
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn | ||||||
from pytorch_lightning.utilities.cloud_io import atomic_save, load as pl_load | ||||||
from pytorch_lightning.utilities.seed import seed_everything | ||||||
from pytorch_lightning.distributed.dist import LightningDistributed | ||||||
from pytorch_lightning.utilities.distributed import find_free_network_port | ||||||
|
@@ -157,12 +157,109 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0 | |||||
# clean up memory | ||||||
torch.cuda.empty_cache() | ||||||
|
||||||
def ddp_train(self, process_idx, mp_queue, model):
    """
    Entry point for a single ddp process.

    Connects the process, prepares the model (sync batchnorm, device
    placement, optimizers, precision, ddp wrapper), runs training or testing
    and hands the results to ``transfer_distrib_spawn_state_on_fit_end``.

    Args:
        process_idx: index of this process; used to derive the ranks and
            passed on to ``model_to_device``
        mp_queue: multiprocessing queue, forwarded together with the results
            to ``transfer_distrib_spawn_state_on_fit_end``
        model: the model to run; its ``init_ddp_connection``,
            ``configure_sync_batchnorm`` and ``configure_ddp`` hooks are called

    Returns:
        None
    """
    # only the process with node_rank 0 and process_idx 0 keeps its progress bar
    is_main_process = self.trainer.node_rank == 0 and process_idx == 0
    if not is_main_process and self.trainer.progress_bar_callback is not None:
        self.trainer.progress_bar_callback.disable()

    # work out local rank, global rank and world size for this process
    self.set_world_ranks(process_idx)

    # rank used by the rank-zero-only warning/logging helpers
    rank_zero_only.rank = self.trainer.global_rank

    # attach the trainer, then let the model open the ddp connection
    model.trainer = self.trainer
    model.init_ddp_connection(
        self.trainer.global_rank,
        self.trainer.world_size,
        self.trainer.is_slurm_managing_tasks,
    )

    # the setup hook runs only once the ddp process has connected
    self.trainer.call_setup_hook(model)

    # announce the start of training from the global-zero process
    # NOTE(review): this check runs after ``init_ddp_connection``; if that call
    # initializes the process group, the condition is always False and nothing
    # is logged -- confirm this is intended.
    if self.trainer.is_global_zero and not torch.distributed.is_initialized():
        separator = '-' * 100
        log.info(separator)
        log.info(f'distributed_backend={self.trainer.distributed_backend}')
        log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
        log.info(separator)

    # sync_bn has to be configured before .cuda(), configure_apex and configure_ddp
    if self.trainer.sync_batchnorm:
        model = model.configure_sync_batchnorm(model)

    # place the model on the device belonging to this process
    self.model_to_device(model, process_idx)

    # optimizers (and lr schedulers)
    self.setup_optimizers(model)

    # copy trainer properties onto the model before it gets wrapped
    self.trainer.model_connector.copy_trainer_model_properties(model)

    # 16-bit precision handling
    model = self.trainer.precision_connector.connect(model)

    # the device ids depend on the concrete DDP setup
    device_ids = self.get_device_ids()

    # user-configurable ddp wrapping
    model = model.configure_ddp(model, device_ids)

    # prepare the training routine
    self.trainer.train_loop.setup_training(model)

    # run train or test
    results = self.train_or_test()

    # fetch the model back from the trainer
    model = self.trainer.get_model()

    # persist info for ddp_spawn
    self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)
|
||||||
def training_step(self, args):
    """
    Run the trainer's model on ``args``.

    Args:
        args: positional arguments, unpacked into the ``self.trainer.model`` call

    Returns:
        whatever ``self.trainer.model(*args)`` returns
    """
    use_native_amp = self.trainer.amp_backend == AMPType.NATIVE
    if use_native_amp:
        # native AMP: run the forward call under torch's autocast context
        with torch.cuda.amp.autocast():
            return self.trainer.model(*args)
    return self.trainer.model(*args)
|
||||||
def validation_step(self, args):
    """
    Run a validation step by delegating to ``training_step``.

    Args:
        args: forwarded unchanged to ``self.training_step``

    Returns:
        the output of ``self.training_step(args)``
    """
    # NOTE(review): a reviewer asked whether this indirection can be avoided
    # (the remark is truncated in the source) -- left as is.
    return self.training_step(args)
|
||||||
def test_step(self, args): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
output = self.training_step(args) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
return output | ||||||
|
||||||
def barrier(self, name: str = None):
    """
    Call ``torch_distrib.barrier()`` if the process group is initialized.

    Args:
        name: currently unused by this implementation (as a reviewer pointed
            out); kept in the signature for callers that pass it
    """
    if not torch_distrib.is_initialized():
        # nothing to synchronize with
        return
    torch_distrib.barrier()
|
||||||
def set_world_ranks(self, process_idx):
    """
    Store this process's ranks and the world size on the trainer.

    Args:
        process_idx: index of this process; becomes ``local_rank``
    """
    trainer = self.trainer
    trainer.local_rank = process_idx
    # ranks are laid out node by node, ``num_processes`` per node
    first_rank_on_node = trainer.node_rank * trainer.num_processes
    trainer.global_rank = first_rank_on_node + process_idx
    trainer.world_size = trainer.num_nodes * trainer.num_processes
|
||||||
def model_to_device(self, model, process_idx, is_master): | ||||||
def model_to_device(self, model, process_idx): | ||||||
gpu_idx = process_idx | ||||||
self.trainer.root_gpu = gpu_idx | ||||||
torch.cuda.set_device(self.trainer.root_gpu) | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do the processes communicate on startup? I feel like a hardcoded sleep is not the optimal solution here