Skip to content

Commit

Permalink
Finish Ananthsub patch 1 (enable prepare_data from correct processes)…
Browse files Browse the repository at this point in the history
…. clarify local vs global rank (#2166)

* [trainer] Call prepare_data once per node in DDP/DDP2 training

* refactored DDP routes

* renamed proc_rank to local_rank

* renamed proc_rank to local_rank

* renamed proc_rank to local_rank

* renamed proc_rank to local_rank

* renamed proc_rank to local_rank

* renamed proc_rank to local_rank

* renamed proc_rank to local_rank

* renamed proc_rank to local_rank

* renamed proc_rank to local_rank

* renamed proc_rank to local_rank

* renamed proc_rank to local_rank

* renamed proc_rank to local_rank

* spawn message

* spawn message

* spawn message

* fixes

* fixes

* fixes

* fixes

* fixes

* Update trainer.py

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
  • Loading branch information
williamFalcon and ananthsub authored Jun 13, 2020
1 parent 10c643f commit 5fd01b0
Show file tree
Hide file tree
Showing 11 changed files with 91 additions and 46 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def format_checkpoint_name(self, epoch, metrics, ver=None):
@rank_zero_only
def on_validation_end(self, trainer, pl_module):
# only run on main process
if trainer.proc_rank != 0:
if trainer.global_rank != 0:
return

metrics = trainer.callback_metrics
Expand Down
10 changes: 5 additions & 5 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def forward(self, x):
self.print(x, 'in forward')
"""
if self.trainer.proc_rank == 0:
if self.trainer.is_global_zero:
print(*args, **kwargs)

@abstractmethod
Expand Down Expand Up @@ -922,7 +922,7 @@ def _init_slurm_connection(self) -> None:

def init_ddp_connection(
self,
proc_rank: int,
global_rank: int,
world_size: int,
is_slurm_managing_tasks: bool = True
) -> None:
Expand All @@ -933,7 +933,7 @@ def init_ddp_connection(
for SLURM managed cluster.
Args:
proc_rank: The current process rank within the node.
global_rank: The global process idx.
world_size: Number of GPUs being use across all nodes. (num_nodes * num_gpus).
is_slurm_managing_tasks: is cluster managed by SLURM.
Expand All @@ -956,8 +956,8 @@ def init_ddp_connection(
f"is not equal to the computed world size ({world_size}). Ignored.")

torch_backend = "nccl" if self.trainer.on_gpu else "gloo"
log.info(f"initializing ddp: LOCAL_RANK: {proc_rank}/{world_size - 1} WORLD_SIZE:{world_size}")
torch_distrib.init_process_group(torch_backend, rank=proc_rank, world_size=world_size)
log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank+1}/{world_size}")
torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)

def configure_apex(
self,
Expand Down
13 changes: 13 additions & 0 deletions pytorch_lightning/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,19 @@ def on_train_end(self, trainer, pl_module):
--env=XLA_USE_BF16=1
-- python your_trainer_file.py
prepare_data_per_node
^^^^^^^^^^^^^^^^^^^^^
If True will call `prepare_data()` on LOCAL_RANK=0 for every node.
If False will only call from NODE_RANK=0, LOCAL_RANK=0
Example::
# default
Trainer(prepare_data_per_node=True)
# use only NODE_RANK=0, LOCAL_RANK=0
Trainer(prepare_data_per_node=False)
tpu_cores
^^^^^^^^^
- How many TPU cores to train on (1 or 8).
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class TrainerDataLoadingMixin(ABC):

# this is just a summary on variables used in this abstract class,
# the proper values/initialisation should be done in child class
proc_rank: int
global_rank: int
use_ddp: bool
use_ddp2: bool
use_horovod: bool
Expand Down Expand Up @@ -160,7 +160,7 @@ def _get_distributed_sampler(self, dataloader):
'ddp_cpu': self.num_processes * self.num_nodes
}
assert self.distributed_backend is not None
kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.proc_rank)
kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank)
sampler = DistributedSampler(dataloader.dataset, **kwargs)
return sampler

Expand Down
35 changes: 27 additions & 8 deletions pytorch_lightning/trainer/distrib_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,10 @@ class TrainerDDPMixin(ABC):
num_nodes: int
node_rank: int

@property
def is_global_zero(self) -> int:
"""Warning: this is just empty shell for code implemented in other class."""

@property
@abstractmethod
def num_gpus(self) -> int:
Expand Down Expand Up @@ -300,6 +304,13 @@ def configure_slurm_ddp(self, num_gpu_nodes):
if self.is_slurm_managing_tasks:
rank_zero_info('Multi-processing is handled by Slurm.')

def determine_local_rank(self):
if self.is_slurm_managing_tasks:
return int(os.environ['SLURM_LOCALID'])

else:
return int(os.environ.get('LOCAL_RANK', 0))

def determine_ddp_node_rank(self):
if self.is_slurm_managing_tasks:
return int(os.environ['SLURM_NODEID'])
Expand Down Expand Up @@ -423,21 +434,30 @@ def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):

# determine which process we are and world size
if self.use_ddp:
self.proc_rank = self.node_rank * self.num_processes + process_idx
self.local_rank = process_idx
self.global_rank = self.node_rank * self.num_processes + process_idx
self.world_size = self.num_nodes * self.num_processes

elif self.use_ddp2:
self.proc_rank = self.node_rank
self.local_rank = self.node_rank
self.global_rank = self.node_rank
self.world_size = self.num_nodes

# set warning rank
rank_zero_only.rank = self.proc_rank
rank_zero_only.rank = self.global_rank

# set up server using proc 0's ip address
# try to init for 20 times at max in case ports are taken
# where to store ip_table
model.trainer = self
model.init_ddp_connection(self.proc_rank, self.world_size, self.is_slurm_managing_tasks)
model.init_ddp_connection(self.global_rank, self.world_size, self.is_slurm_managing_tasks)

# on world_size=0 let everyone know training is starting
if self.is_global_zero:
log.info('-' * 100)
log.info(f'distributed_backend={self.distributed_backend}')
log.info(f'All DDP processes registered. Starting ddp with {self.world_size} processes')
log.info('-' * 100)

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
Expand All @@ -450,8 +470,7 @@ def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):
if is_master:
# source of truth is cuda for gpu idx
gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
local_rank = int(os.environ['LOCAL_RANK'])
gpu_idx = int(gpus[local_rank])
gpu_idx = int(gpus[self.local_rank])

self.root_gpu = gpu_idx
torch.cuda.set_device(self.root_gpu)
Expand Down Expand Up @@ -488,7 +507,7 @@ def save_spawn_weights(self, model):
:param model:
:return:
"""
if self.proc_rank == 0:
if self.is_global_zero:
path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt')
self.save_checkpoint(path)

Expand All @@ -502,7 +521,7 @@ def load_spawn_weights(self, original_model):

loaded_model = original_model

if self.proc_rank == 0:
if self.is_global_zero:
# load weights saved in ddp
path = os.path.join(self.default_root_dir, '__temp_weight_ddp_end.ckpt')
loaded_model = original_model.__class__.load_from_checkpoint(path)
Expand Down
10 changes: 5 additions & 5 deletions pytorch_lightning/trainer/distrib_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class TrainerDPMixin(ABC):
root_gpu: ...
amp_level: str
precision: ...
proc_rank: int
global_rank: int
tpu_local_core_rank: int
tpu_global_core_rank: int
use_tpu: bool
Expand Down Expand Up @@ -183,8 +183,8 @@ def tpu_train(self, tpu_core_idx, model):
if self.tpu_global_core_rank != 0 and self.progress_bar_callback is not None:
self.progress_bar_callback.disable()

self.proc_rank = self.tpu_local_core_rank
rank_zero_only.rank = self.proc_rank
self.global_rank = self.tpu_local_core_rank
rank_zero_only.rank = self.global_rank

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
Expand Down Expand Up @@ -289,8 +289,8 @@ def filter_named_parameters(model, optimizer):

# Update logger rank info from Horovod to avoid race conditions from different ranks
# creating directories / writing files in the same locations.
self.proc_rank = hvd.rank()
rank_zero_only.rank = self.proc_rank
self.global_rank = hvd.rank()
rank_zero_only.rank = self.global_rank

with ExitStack() as stack:
for optimizer in self.optimizers:
Expand Down
10 changes: 3 additions & 7 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@
from torch.utils.data import DataLoader

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.profiler.profilers import BaseProfiler
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import rank_zero_warn
Expand Down Expand Up @@ -160,8 +159,6 @@ class TrainerEvaluationLoopMixin(ABC):
use_dp: bool
use_ddp2: bool
use_horovod: bool
use_amp: bool
use_native_amp: bool
single_gpu: bool
data_parallel_device_ids: ...
model: LightningModule
Expand All @@ -170,15 +167,14 @@ class TrainerEvaluationLoopMixin(ABC):
fast_dev_run: ...
process_output: ...
progress_bar_dict: ...
proc_rank: int
global_rank: int
current_epoch: int
callback_metrics: ...
test_dataloaders: DataLoader
val_dataloaders: DataLoader
use_tpu: bool
reload_dataloaders_every_epoch: ...
tpu_id: Optional[int]
profiler: BaseProfiler
tpu_id: int

# Callback system
on_validation_batch_start: Callable
Expand Down Expand Up @@ -379,7 +375,7 @@ def run_evaluation(self, test_mode: bool = False):
self.add_progress_bar_metrics(prog_bar_metrics)

# log results of test
if test_mode and self.proc_rank == 0:
if test_mode and self.is_global_zero:
print('-' * 80)
print('TEST RESULTS')
pprint(callback_metrics)
Expand Down
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/logging.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC
from typing import Iterable, Optional
from typing import Union, Iterable

import torch

Expand All @@ -15,10 +15,10 @@ class TrainerLoggingMixin(ABC):
current_epoch: int
on_gpu: bool
log_gpu_memory: ...
logger: Optional[LightningLoggerBase]
logger: Union[LightningLoggerBase, bool]
progress_bar_metrics: ...
global_step: int
proc_rank: int
global_rank: int
use_dp: bool
use_ddp2: bool
default_root_dir: str
Expand Down Expand Up @@ -69,7 +69,7 @@ def log_metrics(self, metrics, grad_norm_dic, step=None):
scalar_metrics['epoch'] = self.current_epoch
step = step if step is not None else self.global_step
# log actual metrics
if self.proc_rank == 0 and self.logger is not None:
if self.is_global_zero and self.logger is not None:
self.logger.agg_and_log_metrics(scalar_metrics, step=step)
self.logger.save()

Expand Down
30 changes: 24 additions & 6 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def __init__(
replace_sampler_ddp: bool = True,
terminate_on_nan: bool = False,
auto_scale_batch_size: Union[str, bool] = False,
prepare_data_per_node: bool = True,
amp_level: str = 'O1', # backward compatible, todo: remove in v1.0.0
num_tpu_cores: Optional[int] = None, # backward compatible, todo: remove in v0.9.0
use_amp=None, # backward compatible, todo: remove in v0.9.0
Expand Down Expand Up @@ -282,6 +283,9 @@ def __init__(
The result will be stored in self.batch_size in the LightningModule.
Additionally, can be set to either `power` that estimates the batch size through
a power search or `binsearch` that estimates the batch size through a binary search.
prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data
"""
super().__init__()

Expand All @@ -293,6 +297,7 @@ def __init__(
os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0)

# Init callbacks
self.prepare_data_per_node = prepare_data_per_node
self.callbacks = callbacks or []
self.on_init_start()

Expand Down Expand Up @@ -439,11 +444,12 @@ def __init__(
self.init_tpu()

# init flags for SLURM+ddp to work
self.proc_rank = 0
self.world_size = 1
self.interactive_ddp_procs = []
self.configure_slurm_ddp(self.num_nodes)
self.node_rank = self.determine_ddp_node_rank()
self.local_rank = self.determine_local_rank()
self.global_rank = 0

# nvidia setup
self.set_nvidia_flags(self.is_slurm_managing_tasks, self.data_parallel_device_ids)
Expand Down Expand Up @@ -481,6 +487,10 @@ def __init__(
# Callback system
self.on_init_end()

@property
def is_global_zero(self):
return self.global_rank == 0

@property
def slurm_job_id(self) -> Optional[int]:
try:
Expand Down Expand Up @@ -532,6 +542,7 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
('max_epochs', (<class 'int'>,), 1000),
...
('precision', (<class 'int'>,), 32),
('prepare_data_per_node', (<class 'bool'>,), True),
('print_nan_grads', (<class 'bool'>,), False),
('process_position', (<class 'int'>,), 0),
('profiler',
Expand Down Expand Up @@ -773,10 +784,9 @@ def fit(
# check that model is configured correctly
self.check_model_configuration(model)

# download the data and do whatever transforms we need
# do before any spawn calls so that the model can assign properties
# only on proc 0 because no spawn has happened yet
if not self._is_data_prepared:
# on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
# or in the case where each node needs to do its own manipulation in which case just local_rank=0
if self.can_prepare_data():
model.prepare_data()
self._is_data_prepared = True

Expand All @@ -801,6 +811,7 @@ def fit(
# torchelastic or general non_slurm ddp2
elif 'WORLD_SIZE' in os.environ and ('GROUP_RANK' in os.environ or 'NODE_RANK' in os.environ):
task = int(os.environ['LOCAL_RANK'])

self.ddp_train(task, model)
elif self.use_ddp:
if self.is_slurm_managing_tasks:
Expand Down Expand Up @@ -872,6 +883,13 @@ def fit(
# used for testing or when we need to know that training succeeded
return 1

def can_prepare_data(self):
if self.prepare_data_per_node:
return self.local_rank == 0

else:
return self.node_rank == 0 and self.local_rank == 0

def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None):
# when dataloader is passed via fit, patch the train_dataloader
# functions to overwrite with these implementations
Expand Down Expand Up @@ -928,7 +946,7 @@ def run_pretrain_routine(self, model: LightningModule):

# print model summary
# TODO: remove self.testing condition because model.summarize() is wiping out the weights
if self.proc_rank == 0 and self.weights_summary is not None and not self.testing:
if self.is_global_zero and self.weights_summary is not None and not self.testing:
if self.weights_summary in ['full', 'top']:
ref_model.summarize(mode=self.weights_summary)
else:
Expand Down
Loading

0 comments on commit 5fd01b0

Please sign in to comment.