ref: callback system and init ddp (1/n) (#3836)
* refactored callback system and init ddp
williamFalcon authored Oct 4, 2020
1 parent b8a6408 commit 1f8ff7c
Showing 11 changed files with 112 additions and 99 deletions.
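
At a high level, this commit moves init_ddp_connection off the LightningModule: the accelerator base class in base_backend.py now owns it, the SLURM-specific environment setup moves to SLURMConnector.connect_ddp, and every DDP backend calls self.init_ddp_connection(...) instead of model.init_ddp_connection(...). The condensed sketch below mirrors the hunks that follow; the wrapper class itself is illustrative only and not part of the commit.

class RefactoredCallFlowSketch:
    """Condensed view of the new DDP init flow (illustrative, not the real backend class)."""

    def __init__(self, trainer):
        self.trainer = trainer

    def ddp_train(self, process_idx, model):
        model.trainer = self.trainer
        # Before this commit the backends called model.init_ddp_connection(...);
        # now the accelerator owns the hook.
        self.init_ddp_connection(
            self.trainer.global_rank,
            self.trainer.world_size,
            self.trainer.is_slurm_managing_tasks,
        )

    def init_ddp_connection(self, global_rank, world_size, is_slurm_managing_tasks=True):
        # SLURM-managed clusters: env setup is delegated to the trainer's SLURMConnector.
        if is_slurm_managing_tasks:
            self.trainer.slurm_connector.connect_ddp(global_rank, world_size)
        # Everything else (torchelastic / manual env:// setup) stays on the accelerator.
        else:
            self.connect_torchelastic(global_rank, world_size)
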
2 changes: 1 addition & 1 deletion benchmarks/test_parity.py
@@ -11,7 +11,7 @@

@pytest.mark.parametrize('cls_model,max_diff', [
(ParityModuleRNN, 0.05),
- (ParityModuleMNIST, 0.55)
+ (ParityModuleMNIST, 0.57)
])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_pytorch_parity(tmpdir, cls_model, max_diff):
6 changes: 0 additions & 6 deletions docs/source/lightning_module.rst
@@ -1048,12 +1048,6 @@ get_progress_bar_dict
.. autofunction:: pytorch_lightning.core.lightning.LightningModule.get_progress_bar_dict
:noindex:

init_ddp_connection
~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.core.lightning.LightningModule.init_ddp_connection
:noindex:

tbptt_split_batch
~~~~~~~~~~~~~~~~~

55 changes: 55 additions & 0 deletions pytorch_lightning/accelerators/base_backend.py
@@ -1,3 +1,4 @@
import os
import math
from enum import Enum
from typing import Any
@@ -8,6 +9,8 @@
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict
import torch.distributed as torch_distrib
from pytorch_lightning import _logger as log

try:
from apex import amp
@@ -185,6 +188,58 @@ def setup_optimizers(self, model):
self.trainer.lr_schedulers = lr_schedulers
self.trainer.optimizer_frequencies = optimizer_frequencies

def init_ddp_connection(
self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True
) -> None:
if is_slurm_managing_tasks:
self.trainer.slurm_connector.connect_ddp(global_rank, world_size)
else:
self.connect_torchelastic(global_rank, world_size)

def connect_torchelastic(
self, global_rank: int, world_size: int
) -> None:
"""
Override to define your custom way of setting up a distributed environment.
Lightning's implementation uses env:// init by default and sets the first node as root
for SLURM-managed clusters.
Args:
global_rank: The global process idx.
world_size: Number of GPUs being used across all nodes (num_nodes * num_gpus).
"""

if "MASTER_ADDR" not in os.environ:
rank_zero_warn(
"MASTER_ADDR environment variable is not defined. Set as localhost"
)
os.environ["MASTER_ADDR"] = "127.0.0.1"
log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")

if "MASTER_PORT" not in os.environ:
rank_zero_warn(
"MASTER_PORT environment variable is not defined. Set as 12910"
)
os.environ["MASTER_PORT"] = "12910"
log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

if "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) != world_size:
rank_zero_warn(
f"WORLD_SIZE environment variable ({os.environ['WORLD_SIZE']}) "
f"is not equal to the computed world size ({world_size}). Ignored."
)

torch_backend = "nccl" if self.trainer.on_gpu else "gloo"

if not torch.distributed.is_initialized():
log.info(
f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}"
)
torch_distrib.init_process_group(
torch_backend, rank=global_rank, world_size=world_size
)


# TODO: allow users to compare with strings; internally we shall use these Enums to prevent typos...
class BackendType(Enum):
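
For context on the fallbacks in connect_torchelastic above: a missing MASTER_ADDR falls back to 127.0.0.1, a missing MASTER_PORT falls back to 12910, and a WORLD_SIZE that disagrees with the computed world size only triggers a warning. The snippet below is a hypothetical launch-side setup under those assumptions; the concrete address, port, and world size are example values, not defaults from this commit.

import os

# Hypothetical pre-launch environment setup, run on every node/process.
# If these variables are left unset, connect_torchelastic warns and falls back to
# MASTER_ADDR=127.0.0.1 and MASTER_PORT=12910 (see the hunk above).
os.environ.setdefault("MASTER_ADDR", "10.0.0.1")  # address of the rank-0 node (example value)
os.environ.setdefault("MASTER_PORT", "29500")     # free port shared by all ranks (example value)

# A WORLD_SIZE that does not match num_nodes * num_gpus is only warned about;
# the computed world size is what gets passed to init_process_group.
os.environ.setdefault("WORLD_SIZE", "8")          # example: 2 nodes x 4 GPUs
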
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/ddp2_backend.py
@@ -137,7 +137,7 @@ def ddp_train(self, process_idx, mp_queue, model):
# try to init for 20 times at max in case ports are taken
# where to store ip_table
model.trainer = self.trainer
- model.init_ddp_connection(
+ self.init_ddp_connection(
self.trainer.global_rank,
self.trainer.world_size,
self.trainer.is_slurm_managing_tasks
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/ddp_backend.py
@@ -222,7 +222,7 @@ def ddp_train(self, process_idx, model):
# try to init for 20 times at max in case ports are taken
# where to store ip_table
model.trainer = self.trainer
- model.init_ddp_connection(
+ self.init_ddp_connection(
self.trainer.global_rank,
self.trainer.world_size,
self.trainer.is_slurm_managing_tasks
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py
@@ -90,7 +90,7 @@ def ddp_train(self, process_idx, mp_queue, model):
# try to init for 20 times at max in case ports are taken
# where to store ip_table
model.trainer = self.trainer
- model.init_ddp_connection(
+ self.init_ddp_connection(
self.trainer.global_rank,
self.trainer.world_size,
self.trainer.is_slurm_managing_tasks
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/ddp_slurm_backend.py
@@ -128,7 +128,7 @@ def ddp_train(self, process_idx, model):
# try to init for 20 times at max in case ports are taken
# where to store ip_table
model.trainer = self.trainer
- model.init_ddp_connection(
+ self.init_ddp_connection(
self.trainer.global_rank,
self.trainer.world_size,
self.trainer.is_slurm_managing_tasks
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/ddp_spawn_backend.py
@@ -103,7 +103,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
# try to init for 20 times at max in case ports are taken
# where to store ip_table
model.trainer = self.trainer
- model.init_ddp_connection(
+ self.init_ddp_connection(
self.trainer.global_rank,
self.trainer.world_size,
self.trainer.is_slurm_managing_tasks
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/ddp_torchelastic_backend.py
@@ -124,7 +124,7 @@ def ddp_train(self, process_idx, model):
# try to init for 20 times at max in case ports are taken
# where to store ip_table
model.trainer = self.trainer
- model.init_ddp_connection(
+ self.init_ddp_connection(
self.trainer.global_rank,
self.trainer.world_size,
self.trainer.is_slurm_managing_tasks
88 changes: 2 additions & 86 deletions pytorch_lightning/core/lightning.py
@@ -22,7 +22,6 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch
- import torch.distributed as torch_distrib
from pytorch_lightning import _logger as log
from pytorch_lightning.core.grads import GradInformation
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
@@ -956,87 +955,6 @@ def configure_ddp(self, model, device_ids):
)
return model

def _init_slurm_connection(self) -> None:
""""""
"""
Sets up environment variables necessary for pytorch distributed communications
based on slurm environment.
"""
# use slurm job id for the port number
# guarantees unique ports across jobs from same grid search
try:
# use the last 4 numbers in the job id as the id
default_port = os.environ["SLURM_JOB_ID"]
default_port = default_port[-4:]

# all ports should be in the 10k+ range
default_port = int(default_port) + 15000

except Exception:
default_port = 12910

# if user gave a port number, use that one instead
try:
default_port = os.environ["MASTER_PORT"]
except Exception:
os.environ["MASTER_PORT"] = str(default_port)

# figure out the root node addr
try:
root_node = os.environ["SLURM_NODELIST"].split(" ")[0]
except Exception:
root_node = "127.0.0.1"

root_node = self.trainer.slurm_connector.resolve_root_node_address(root_node)
os.environ["MASTER_ADDR"] = root_node

def init_ddp_connection(
self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True
) -> None:
"""
Override to define your custom way of setting up a distributed environment.
Lightning's implementation uses env:// init by default and sets the first node as root
for SLURM managed cluster.
Args:
global_rank: The global process idx.
world_size: Number of GPUs being use across all nodes. (num_nodes * num_gpus).
is_slurm_managing_tasks: is cluster managed by SLURM.
"""
if is_slurm_managing_tasks:
self._init_slurm_connection()

if "MASTER_ADDR" not in os.environ:
rank_zero_warn(
"MASTER_ADDR environment variable is not defined. Set as localhost"
)
os.environ["MASTER_ADDR"] = "127.0.0.1"
log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")

if "MASTER_PORT" not in os.environ:
rank_zero_warn(
"MASTER_PORT environment variable is not defined. Set as 12910"
)
os.environ["MASTER_PORT"] = "12910"
log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

if "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) != world_size:
rank_zero_warn(
f"WORLD_SIZE environment variable ({os.environ['WORLD_SIZE']}) "
f"is not equal to the computed world size ({world_size}). Ignored."
)

torch_backend = "nccl" if self.trainer.on_gpu else "gloo"

if not torch.distributed.is_initialized():
log.info(
f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}"
)
torch_distrib.init_process_group(
torch_backend, rank=global_rank, world_size=world_size
)

def configure_sync_batchnorm(self, model: "LightningModule") -> "LightningModule":
"""
Add global batchnorm for a model spread across multiple GPUs and nodes.
@@ -1089,10 +1007,8 @@ def configure_apex(self, amp, model, optimizers, amp_level):
return model, optimizers

def configure_optimizers(
- self,
- ) -> Optional[
- Union[Optimizer, Sequence[Optimizer], Dict, Sequence[Dict], Tuple[List, List]]
- ]:
+ self,
+ ):
r"""
Choose what optimizers and learning-rate schedulers to use in your optimization.
Normally you'd need one. But in the case of GANs or similar you might have multiple.
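
The hunk above also drops the explicit return annotation from configure_optimizers; the removed Optional[Union[...]] still documents the accepted return shapes. Below is a minimal sketch of the common ones, shown on a plain torch.nn.Module so it stays self-contained; the layers and hyperparameters are illustrative only.

import torch
from torch import nn


class TwoOptimizerSketch(nn.Module):
    """Illustrative module showing return shapes matching the removed annotation:
    a single Optimizer, a Sequence[Optimizer], or a Tuple[List, List] of
    optimizers and schedulers."""

    def __init__(self):
        super().__init__()
        self.generator = nn.Linear(10, 10)
        self.discriminator = nn.Linear(10, 1)

    def configure_optimizers(self):
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
        sched_g = torch.optim.lr_scheduler.StepLR(opt_g, step_size=10)
        # Also valid: return opt_g, or return [opt_g, opt_d].
        return [opt_g, opt_d], [sched_g]
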
48 changes: 48 additions & 0 deletions pytorch_lightning/trainer/connectors/slurm_connector.py
@@ -4,6 +4,8 @@
from subprocess import call
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities.distributed import rank_zero_info
import torch.distributed as torch_distrib
import torch


class SLURMConnector:
@@ -101,3 +103,49 @@ def sig_handler(self, signum, frame): # pragma: no-cover
def term_handler(self, signum, frame):
# save
log.info("bypassing sigterm")

def connect_ddp(self, global_rank: int, world_size: int) -> None:
""""""
"""
Sets up the environment variables necessary for PyTorch distributed communications
based on the SLURM environment.
"""
# use slurm job id for the port number
# guarantees unique ports across jobs from same grid search
try:
# use the last 4 numbers in the job id as the id
default_port = os.environ["SLURM_JOB_ID"]
default_port = default_port[-4:]

# all ports should be in the 10k+ range
default_port = int(default_port) + 15000

except Exception:
default_port = 12910

# if user gave a port number, use that one instead
try:
default_port = os.environ["MASTER_PORT"]
except Exception:
os.environ["MASTER_PORT"] = str(default_port)
log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

# figure out the root node addr
try:
root_node = os.environ["SLURM_NODELIST"].split(" ")[0]
except Exception:
root_node = "127.0.0.1"

root_node = self.trainer.slurm_connector.resolve_root_node_address(root_node)
os.environ["MASTER_ADDR"] = root_node
log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")

torch_backend = "nccl" if self.trainer.on_gpu else "gloo"

if not torch.distributed.is_initialized():
log.info(
f"initializing ddp (SLURM): GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}"
)
torch_distrib.init_process_group(
torch_backend, rank=global_rank, world_size=world_size
)
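
A small worked example of the port derivation in connect_ddp above: for SLURM_JOB_ID=1234567 the last four digits are 4567, so the default port becomes 4567 + 15000 = 19567, unless MASTER_PORT is already set, in which case that value wins. The helper below just mirrors that logic for illustration; it is not part of the commit.

import os


def slurm_default_port(env=None):
    # Mirrors the derivation in connect_ddp above: take the last 4 digits of
    # SLURM_JOB_ID and add 15000, fall back to 12910, and let an explicit
    # MASTER_PORT override everything. Illustrative helper only.
    env = os.environ if env is None else env
    try:
        port = int(env["SLURM_JOB_ID"][-4:]) + 15000
    except Exception:
        port = 12910
    return int(env.get("MASTER_PORT", port))


print(slurm_default_port({"SLURM_JOB_ID": "1234567"}))                          # 19567
print(slurm_default_port({"SLURM_JOB_ID": "1234567", "MASTER_PORT": "29500"}))  # 29500
print(slurm_default_port({}))                                                   # 12910
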
