From e9324a8d44afff6774a06a990042923a60298212 Mon Sep 17 00:00:00 2001
From: Adeel Hassan
Date: Thu, 11 Jan 2024 14:19:56 -0500
Subject: [PATCH] ddp fixes and improvements

---
 .../rastervision/pytorch_learner/learner.py | 333 +++++++++---------
 .../pytorch_learner/utils/__init__.py       |   3 +
 .../pytorch_learner/utils/distributed.py    |  98 ++++++
 3 files changed, 273 insertions(+), 161 deletions(-)
 create mode 100644 rastervision_pytorch_learner/rastervision/pytorch_learner/utils/distributed.py

diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py
index 9555174174..b6d24ced7a 100644
--- a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py
+++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner.py
@@ -1,16 +1,16 @@
 from typing import (TYPE_CHECKING, Any, Callable, Dict, Iterator, List,
                     Literal, Optional, Tuple, Union, Type)
 from abc import ABC, abstractmethod
-import os
 from os.path import join, isfile, basename, isdir
 import warnings
-import time
+from time import perf_counter
 import datetime
 import shutil
 import logging
 from subprocess import Popen
 import numbers
 from pprint import pformat
+import gc

 import numpy as np
 from tqdm.auto import tqdm
@@ -36,7 +36,7 @@
     save_pipeline_config)
 from rastervision.pytorch_learner.utils import (
     get_hubconf_dir_from_cfg, aggregate_metrics, log_metrics_to_csv,
-    log_system_details, ONNXRuntimeAdapter)
+    log_system_details, ONNXRuntimeAdapter, DDPContextManager)
 from rastervision.pytorch_learner.dataset.visualizer import Visualizer

 if TYPE_CHECKING:
@@ -54,8 +54,6 @@
 TRANSFORMS_DIRNAME = 'custom_albumentations_transforms'
 BUNDLE_MODEL_WEIGHTS_FILENAME = 'model.pth'
 BUNDLE_MODEL_ONNX_FILENAME = 'model.onnx'
-DDP_BACKEND = rv_config.get_namespace_option('rastervision', 'DDP_BACKEND',
-                                             'nccl')

 log = logging.getLogger(__name__)

@@ -68,9 +66,9 @@ class Learner(ABC):
     This can be subclassed to handle different computer vision tasks.

     The datasets, model, optimizer, and schedulers will be generated from the
-    cfg if not specified in the constructor.
+    :class:`.LearnerConfig` if not specified in the constructor.

-    If instantiated with `training=False`, the training apparatus (loss,
+    If instantiated with ``training=False``, the training apparatus (loss,
     optimizer, scheduler, logging, etc.) will not be set up and the model will
     be put into eval mode.

@@ -390,61 +388,74 @@ def main(self):
         """Main training sequence.

+        This plots the dataset, runs a training and validation loop (which will
+        resume if interrupted), logs stats, plots predictions, and syncs
+        results to the cloud.
+        """
+        if self.distributed:
+            with self.ddp():
+                self._main()
+        else:
+            self._main()
+
+    def _main(self):
+        """Main training sequence.
+
         This plots the dataset, runs a training and validation loop (which will
         resume if interrupted), logs stats, plots predictions, and syncs
         results to the cloud.
""" cfg = self.cfg - if not self.avoid_activating_cuda_runtime: - log_system_details() - log.info(cfg) + if not self.is_ddp_process or self.is_ddp_local_master: + if not self.avoid_activating_cuda_runtime: + log_system_details() + log.info(cfg) log.info(f'Using device: {self.device}') - self.log_data_stats() - self.run_tensorboard() - if not self.avoid_activating_cuda_runtime: - self.plot_dataloaders(cfg.data.preview_batch_limit) + if not self.distributed: + self.run_tensorboard() + self.train() if cfg.save_model_bundle: - self.save_model_bundle() - + if not self.is_ddp_process or self.is_ddp_master: + self.save_model_bundle() self.stop_tensorboard() if cfg.eval_train: self.validate('train') self.validate('valid') - self.sync_to_cloud() + if not self.is_ddp_process or self.is_ddp_master: + self.sync_to_cloud() ########################### # Training and validation ########################### def train(self, epochs: Optional[int] = None): - """Run training loop.""" + """Run training loop, resuming training if appropriate""" + start_epoch, end_epoch = self.get_start_and_end_epochs(epochs) + + if start_epoch >= end_epoch: + log.info('Training already completed. Skipping.') + return + + if (start_epoch > 0 and start_epoch < end_epoch): + log.info('Resuming training from epoch %d', start_epoch) + if self.is_ddp_process: # pragma: no cover self._run_train_distributed(self.ddp_rank, self.ddp_world_size, - epochs) + start_epoch, end_epoch) elif self.distributed: # pragma: no cover log.info('Spawning %d DDP processes', self.ddp_world_size) mp.start_processes( self._run_train_distributed, - args=(self.ddp_world_size, epochs), + args=(self.ddp_world_size, start_epoch, end_epoch), nprocs=self.ddp_world_size, join=True, start_method=self.ddp_start_method) else: - self._train(epochs) - - def _train(self, epochs: Optional[int] = None): # pragma: no cover - """Training loop that will attempt to resume training if appropriate.""" - start_epoch = self.get_start_epoch() - - if epochs is None: - end_epoch = self.cfg.solver.num_epochs - else: - end_epoch = start_epoch + epochs - - if (start_epoch > 0 and start_epoch < end_epoch): - log.info(f'Resuming training from epoch {start_epoch}') + self._train(start_epoch, end_epoch) + def _train(self, start_epoch: int, end_epoch: int): # pragma: no cover + """Training loop.""" self.on_train_start() for epoch in range(start_epoch, end_epoch): log.info(f'epoch: {epoch}') @@ -462,20 +473,9 @@ def _train(self, epochs: Optional[int] = None): # pragma: no cover self.on_epoch_end(epoch, metrics) - def _train_distributed(self, - epochs: Optional[int] = None): # pragma: no cover - """Training loop that will attempt to resume training if appropriate. 
- """ - start_epoch = self.get_start_epoch() - - if epochs is None: - end_epoch = self.cfg.solver.num_epochs - else: - end_epoch = start_epoch + epochs - - if (start_epoch > 0 and start_epoch < end_epoch): - log.info(f'Resuming training from epoch {start_epoch}') - + def _train_distributed(self, start_epoch: int, + end_epoch: int): # pragma: no cover + """Distributed training loop.""" if self.is_ddp_master: self.on_train_start() @@ -506,40 +506,12 @@ def _train_distributed(self, def _run_train_distributed(self, rank: int, world_size: int, *args): # pragma: no cover """Method executed by each DDP worker.""" - - os.environ['MASTER_ADDR'] = os.environ.get('MASTER_ADDR', 'localhost') - os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '12355') - - dist.init_process_group(DDP_BACKEND, rank=rank, world_size=world_size) - - if self.ddp_rank is None: - self.ddp_rank = rank - if self.ddp_local_rank is None: - # Implies process was spawned by self.train(), and therefore, - # this is necessarily a single-node multi-GPU scenario. - # So global rank == local rank. - self.ddp_local_rank = rank - if self.ddp_world_size is None: - self.ddp_world_size = world_size - - log.info('DDP rank: %d, DDP local rank: %d', self.ddp_rank, - self.ddp_local_rank) - - self.is_ddp_process = True - self.is_ddp_master = self.ddp_rank == 0 - if self.device.index is None: - self.device = torch.device(self.device.type, self.ddp_local_rank) - torch.cuda.set_device(self.device) - - self.setup_model( - model_weights_path=self.init_model_weights_path, - model_def_path=self.init_model_def_path) - self.setup_training() - dist.barrier() - - self._train_distributed(*args) - - dist.destroy_process_group() + with self.ddp(rank, world_size): + self.setup_model( + model_weights_path=self.init_model_weights_path, + model_def_path=self.init_model_def_path) + self.setup_training(self.init_loss_def_path) + self._train_distributed(*args) def train_epoch( self, @@ -550,7 +522,7 @@ def train_epoch( self.model.train() if dataloader is None: dataloader = self.train_dl - start = time.time() + start = perf_counter() outputs = [] if self.ddp_rank is not None: desc = f'Training (GPU={self.ddp_rank})' @@ -574,7 +546,7 @@ def train_epoch( if len(outputs) == 0: raise ValueError('Training dataset did not return any batches') metrics = self.train_end(outputs) - end = time.time() + end = perf_counter() train_time = datetime.timedelta(seconds=end - start) metrics['train_time'] = str(train_time) return metrics @@ -593,6 +565,8 @@ def train_step(self, batch: Any, batch_ind: int) -> MetricDict: def on_train_start(self): """Hook that is called at start of train routine.""" + self.log_data_stats() + self.plot_dataloaders(self.cfg.data.preview_batch_limit) def train_end(self, outputs: List[Dict[str, Union[float, Tensor]]] ) -> MetricDict: @@ -647,41 +621,16 @@ def _validate(self, split: Literal['train', 'valid', 'test'] = 'valid' def _run_validate_distributed(self, rank: int, world_size: int, *args): # pragma: no cover """Method executed by each DDP worker.""" - - os.environ['MASTER_ADDR'] = os.environ.get('MASTER_ADDR', 'localhost') - os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '12355') - - dist.init_process_group(DDP_BACKEND, rank=rank, world_size=world_size) - - if self.ddp_rank is None: - self.ddp_rank = rank - if self.ddp_local_rank is None: - self.ddp_local_rank = rank - if self.ddp_world_size is None: - self.ddp_world_size = world_size - - log.info('DDP rank: %d, DDP local rank: %d', self.ddp_rank, - self.ddp_local_rank) - - 
-        self.is_ddp_process = True
-        self.is_ddp_master = self.ddp_rank == 0
-        if self.device.index is None:
-            self.device = torch.device(self.device.type, self.ddp_local_rank)
-        torch.cuda.set_device(self.device)
-
-        self.setup_model(
-            model_weights_path=self.init_model_weights_path,
-            model_def_path=self.init_model_def_path)
-        self.setup_loss(self.init_loss_def_path)
-        dist.barrier()
-
-        self._validate(*args)
-
-        dist.destroy_process_group()
+        with self.ddp(rank, world_size):
+            self.setup_model(
+                model_weights_path=self.init_model_weights_path,
+                model_def_path=self.init_model_def_path)
+            self.setup_training(self.init_loss_def_path)
+            self._validate(*args)

     def validate_epoch(self, dl: DataLoader) -> MetricDict:
         """Validate for a single epoch."""
-        start = time.time()
+        start = perf_counter()
         self.model.eval()
         outputs = []
         if self.ddp_rank is not None:
@@ -696,7 +645,7 @@ def validate_epoch(self, dl: DataLoader) -> MetricDict:
             batch = (x, y)
             output = self.validate_step(batch, batch_ind)
             outputs.append(output)
-        end = time.time()
+        end = perf_counter()
         validate_time = datetime.timedelta(seconds=end - start)

         metrics = self.validate_end(outputs)
@@ -994,27 +943,23 @@ def setup_ddp_params(self):

         self.is_ddp_process = False
         self.is_ddp_master = False
+        self.is_ddp_local_master = False
         self.avoid_activating_cuda_runtime = False

         self.ddp_world_size = get_env_var('WORLD_SIZE', None, int)
         self.ddp_rank = get_env_var('RANK', None, int)
         self.ddp_local_rank = get_env_var('LOCAL_RANK', None, int)
+        ddp_vars_set = all(
+            v is not None
+            for v in [self.ddp_world_size, self.ddp_rank, self.ddp_local_rank])

-        if dist.is_initialized():  # pragma: no cover
-            if not ddp_allowed:
-                log.info('Ignoring RASTERVISION_USE_DDP since DDP is already '
-                         'initialized.')
-            ddp_vars_set = all(
-                [self.ddp_world_size, self.ddp_rank, self.ddp_local_rank])
-            if not ddp_vars_set:
-                raise ValueError(
-                    'Is DDP process but WORLD_SIZE, RANK, and LOCAL_RANK '
-                    'env variables not set.')
+        if not ddp_allowed:
+            self.distributed = False
+        elif ddp_vars_set:  # pragma: no cover
             self.distributed = True
             self.is_ddp_process = True
             self.is_ddp_master = self.ddp_rank == 0
-        elif not ddp_allowed:
-            self.distributed = False
+            self.is_ddp_local_master = self.ddp_local_rank == 0
         elif self.ddp_start_method != 'spawn':
             # If ddp_start_method is "fork" or "forkserver", the CUDA runtime
             # must not be initialized before the fork; otherwise, a
@@ -1048,13 +993,16 @@ def setup_ddp_params(self):
         if not self.distributed:
             return

-        self.ddp_world_size: int
+        if not self.training:
+            raise NotImplementedError(
+                'DDP is currently only supported in training mode.')

         if self.model is not None:
             raise ValueError(
                 'In distributed mode, the model must be specified via '
                 'ModelConfig in LearnerConfig rather than be passed '
                 'as an instantiated object.')
+
         dses_passed = any([self.train_ds, self.valid_ds, self.test_ds])
         if dses_passed and self.ddp_start_method != 'fork':
             raise ValueError(
@@ -1062,6 +1010,7 @@ def setup_ddp_params(self):
                 'RASTERVISION_DDP_START_METHOD != "fork", datasets must be '
                 'specified via DataConfig in LearnerConfig rather than be '
                 'passed as instantiated objects.')
+
         if self.ddp_local_rank is not None:
             self.device = torch.device('cuda', self.ddp_local_rank)

@@ -1073,6 +1022,17 @@ def setup_ddp_params(self):
         log.info(f'DDP local rank: {self.ddp_local_rank}')

     def setup_training(self, loss_def_path: Optional[str] = None) -> None:
+        """Set up model, data, loss, optimizers and various paths.
+
+        The exact behavior differs based on whether this method is called in
+        a distributed scenario.
+
+        Args:
+            loss_def_path: A local path to a directory with a ``hubconf.py``. If
+                provided, the loss function definition is imported from here.
+                This is used when loading an external loss function from a
+                model-bundle. Defaults to ``None``.
+        """
         cfg = self.cfg

         self.config_path = join(self.output_dir, 'learner-config.json')
@@ -1081,12 +1041,11 @@ def setup_training(self, loss_def_path: Optional[str] = None) -> None:
         self.last_model_weights_path = join(self.output_dir_local,
                                             'last-model.pth')

-        if self.is_ddp_process:  # pragma: no cover
-            # model
-            if self.model is not None:
-                self.load_checkpoint()
+        if not self.distributed:
             # data
             self.setup_data()
+            # model
+            self.load_checkpoint()
             # optimization
             start_epoch = self.get_start_epoch()
             self.setup_loss(loss_def_path=loss_def_path)
@@ -1096,17 +1055,16 @@ def setup_training(self, loss_def_path: Optional[str] = None) -> None:
                 self.step_scheduler = self.build_step_scheduler(start_epoch)
             if self.epoch_scheduler is None:
                 self.epoch_scheduler = self.build_epoch_scheduler(start_epoch)
+            self.setup_tensorboard()
+            return

-            if self.is_ddp_master:
-                self.setup_tensorboard()
-        elif self.distributed:  # pragma: no cover
-            if self.ddp_start_method == 'fork':
-                self.setup_data()
-        else:
+        # DDP
+        if self.is_ddp_process and dist.is_initialized():  # pragma: no cover
+            # model
+            if self.model is not None:
+                self.load_checkpoint()
             # data
             self.setup_data()
-            # model
-            self.load_checkpoint()
             # optimization
             start_epoch = self.get_start_epoch()
             self.setup_loss(loss_def_path=loss_def_path)
@@ -1117,7 +1075,21 @@ def setup_training(self, loss_def_path: Optional[str] = None) -> None:
             if self.epoch_scheduler is None:
                 self.epoch_scheduler = self.build_epoch_scheduler(start_epoch)

-        self.setup_tensorboard()
+            if self.is_ddp_master:
+                self.setup_tensorboard()
+        else:  # pragma: no cover
+            if self.ddp_start_method == 'fork':
+                self.setup_data()
+
+    def get_start_and_end_epochs(
+            self, epochs: Optional[int] = None) -> Tuple[int, int]:
+        """Get start and end epochs given epochs."""
+        start_epoch = self.get_start_epoch()
+        if epochs is None:
+            end_epoch = self.cfg.solver.num_epochs
+        else:
+            end_epoch = start_epoch + epochs
+        return start_epoch, end_epoch

     def get_start_epoch(self) -> int:
         """Get start epoch.
@@ -1158,7 +1130,7 @@ def setup_model(self,
         self.model.to(device=self.device)
         if self.is_ddp_process:  # pragma: no cover
             self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
-            self.model = DDP(self.model, device_ids=[self.ddp_rank])
+            self.model = DDP(self.model, device_ids=[self.ddp_local_rank])
         self.load_init_weights(model_weights_path=model_weights_path)

     def build_model(self, model_def_path: Optional[str] = None) -> nn.Module:
@@ -1186,14 +1158,21 @@ def setup_data(self, distributed: Optional[bool] = None):

         if self.train_ds is None or self.valid_ds is None:
             if distributed:
-                if self.is_ddp_master:
+                if self.is_ddp_local_master:
                     train_ds, valid_ds, test_ds = self.build_datasets()
+                    log.info(f'{self.ddp_rank=} Done.')
+                else:
+                    log.info(f'{self.ddp_rank=} Waiting.')
                 dist.barrier()
-                if not self.is_ddp_master:
+                if not self.is_ddp_local_master:
                     train_ds, valid_ds, test_ds = self.build_datasets()
+                    log.info(f'{self.ddp_rank=} Done.')
+                else:
+                    log.info(f'{self.ddp_rank=} Waiting.')
                 dist.barrier()
             else:
                 train_ds, valid_ds, test_ds = self.build_datasets()
+
         if self.train_ds is None:
             self.train_ds = train_ds
         if self.valid_ds is None:
@@ -1201,6 +1180,7 @@ def setup_data(self, distributed: Optional[bool] = None):
         if self.test_ds is None:
             self.test_ds = test_ds

+        log.info('Building dataloaders')
         self.train_dl, self.valid_dl, self.test_dl = self.build_dataloaders(
             distributed=distributed)

@@ -1210,12 +1190,11 @@ def build_datasets(self) -> Tuple['Dataset', 'Dataset', 'Dataset']:
         train_ds, val_ds, test_ds = self.cfg.data.build(tmp_dir=self.tmp_dir)
         return train_ds, val_ds, test_ds

-    def build_dataset(self, split: Literal['train', 'valid', 'test']
-                      ) -> Tuple['Dataset', 'Dataset', 'Dataset']:
+    def build_dataset(self,
+                      split: Literal['train', 'valid', 'test']) -> 'Dataset':
         """Build Dataset for split."""
         log.info('Building %s dataset ...', split)
-        cfg = self.cfg
-        ds = cfg.data.build_dataset(split=split, tmp_dir=self.tmp_dir)
+        ds = self.cfg.data.build_dataset(split=split, tmp_dir=self.tmp_dir)
         return ds

     def build_dataloaders(self, distributed: Optional[bool] = None
@@ -1235,15 +1214,20 @@ def build_dataloaders(self, distributed: Optional[bool] = None

     def build_dataloader(self,
                          split: Literal['train', 'valid', 'test'],
-                         distributed: bool = False) -> DataLoader:
+                         distributed: Optional[bool] = None,
+                         **kwargs) -> DataLoader:
         """Build DataLoader for split."""
+        if distributed is None:
+            distributed = self.distributed
+
         ds = self.get_dataset(split)
         if ds is None:
             ds = self.build_dataset(split)
+
         batch_sz = self.cfg.solver.batch_sz
         num_workers = self.cfg.data.num_workers
         collate_fn = self.get_collate_fn()
-        sampler = self.build_sampler(split, distributed=distributed)
+        sampler = self.build_sampler(ds, split, distributed=distributed)

         if distributed:
             world_sz = self.ddp_world_size
@@ -1263,6 +1247,7 @@ def build_dataloader(self,
             pin_memory=True,
             multiprocessing_context='fork' if distributed else None,
         )
+        args.update(**kwargs)

         if sampler is not None:
             args['sampler'] = sampler
@@ -1287,22 +1272,23 @@ def get_collate_fn(self) -> Optional[callable]:
         return None

     def build_sampler(self,
+                      ds: 'Dataset',
                       split: Literal['train', 'valid', 'test'],
                       distributed: bool = False) -> Optional['Sampler']:
-        """Return an optional sampler for the split's dataloader."""
+        """Build an optional sampler for the split's dataloader."""
         split = split.lower()
         sampler = None
         if split == 'train':
             if distributed:
                 sampler = DistributedSampler(
-                    self.train_ds,
+                    ds,
                     shuffle=True,
                     num_replicas=self.ddp_world_size,
                     rank=self.ddp_rank)
         elif split == 'valid':
             if distributed:
                 sampler = DistributedSampler(
-                    self.valid_ds,
+                    ds,
                     shuffle=False,
                     num_replicas=self.ddp_world_size,
                     rank=self.ddp_rank)
@@ -1455,9 +1441,7 @@ def _bundle_model(self, model_bundle_dir: str,
         """Save model weights and copy them to bundle dir."""
         model_not_set = self.model is None
         if model_not_set:
-            self.setup_model(
-                model_weights_path=self.init_model_weights_path,
-                model_def_path=self.init_model_def_path)
+            self.model = self.build_model(self.init_model_def_path).cpu()
             self.load_checkpoint()

         path = join(model_bundle_dir, BUNDLE_MODEL_WEIGHTS_FILENAME)
@@ -1473,6 +1457,7 @@ def _bundle_model(self, model_bundle_dir: str,

         if model_not_set:
             self.model = None
+            gc.collect()

     def export_to_onnx(self,
                        path: str,
@@ -1512,11 +1497,14 @@ def export_to_onnx(self,
         if sample_input is None:
             dl = self.valid_dl
             if dl is None:
-                dl = self.build_dataloader('valid')
+                dl = self.build_dataloader(
+                    'valid', batch_size=1, num_workers=0, distributed=False)
             sample_input, _ = next(iter(dl))

-        torch.cuda.empty_cache()
-        sample_input = self.to_device(sample_input, self.device)
+        model_device = next(model.parameters()).device
+        if model_device.type == 'cuda':
+            torch.cuda.empty_cache()
+        sample_input = self.to_device(sample_input, model_device)

         args = dict(
             input_names=['x'],
@@ -1582,7 +1570,30 @@ def _bundle_transforms(self, model_bundle_dir: str) -> None:
     #########
     # Misc.
     #########
+    def ddp(self, rank: Optional[int] = None,
+            world_size: Optional[int] = None) -> DDPContextManager:
+        """Return a :class:`DDPContextManager`.
+
+        This should be used to wrap code that needs to be executed in parallel.
+        It is safe to call this recursively; recursive calls have no effect.
+
+        Note that :class:`DDPContextManager` does not start processes itself,
+        but merely initializes and destroys DDP process groups.
+
+        Usage:
+
+        .. code-block:: python
+
+            with learner.ddp([rank], [world_size]):
+                ...
+
+        """
+        if not self.distributed:
+            raise ValueError('self.distributed is False')
+        return DDPContextManager(self, rank, world_size)
+
     def reduce_distributed_metrics(self, metrics: dict):  # pragma: no cover
+        """Average numeric metrics across processes."""
         for k in metrics.keys():
             v = metrics[k]
             if isinstance(v, (float, int)):
@@ -1613,7 +1624,7 @@ def to_batch(self, x: Tensor) -> Tensor:
             x = x[None, ...]
         return x

-    def to_device(self, x: Any, device: str) -> Any:
+    def to_device(self, x: Any, device: Union[str, torch.device]) -> Any:
         """Load Tensors onto a device.

         Args:
diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/utils/__init__.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/utils/__init__.py
index e1b24f2bff..f5a1903918 100644
--- a/rastervision_pytorch_learner/rastervision/pytorch_learner/utils/__init__.py
+++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/utils/__init__.py
@@ -2,6 +2,7 @@

 from rastervision.pytorch_learner.utils.utils import *
 from rastervision.pytorch_learner.utils.torch_hub import *
+from rastervision.pytorch_learner.utils.distributed import *

 __all__ = [
     SplitTensor.__name__,
@@ -21,4 +22,6 @@
     torch_hub_load_github.__name__,
     torch_hub_load_uri.__name__,
     torch_hub_load_local.__name__,
+    DDPContextManager.__name__,
+    'DDP_BACKEND',
 ]
diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/utils/distributed.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/utils/distributed.py
new file mode 100644
index 0000000000..17c29eb1f8
--- /dev/null
+++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/utils/distributed.py
@@ -0,0 +1,98 @@
+from typing import TYPE_CHECKING, Any, Optional
+import os
+from contextlib import AbstractContextManager
+import gc
+import logging
+
+import torch
+import torch.distributed as dist
+
+from rastervision.pipeline import rv_config_ as rv_config
+
+if TYPE_CHECKING:
+    from rastervision.pytorch_learner import Learner
+
+log = logging.getLogger(__name__)
+
+DDP_BACKEND = rv_config.get_namespace_option('rastervision', 'DDP_BACKEND',
+                                             'nccl')
+
+
+class DDPContextManager(AbstractContextManager):  # pragma: no cover
+    """Context manager for initializing and destroying DDP process groups.
+
+    Note that this context manager does not start processes itself, but
+    merely calls :func:`torch.distributed.init_process_group` and
+    :func:`torch.distributed.destroy_process_group` and sets DDP-related fields
+    in the :class:`Learner` to appropriate values.
+
+    If a process group is already initialized, this context manager does
+    nothing on either entry or exit.
+    """
+
+    def __init__(self,
+                 learner: 'Learner',
+                 rank: Optional[int] = None,
+                 world_size: Optional[int] = None) -> None:
+        """Constructor.
+
+        Args:
+            learner: The :class:`Learner` on which to set DDP-related fields.
+            rank: The process rank. If ``None``, will be set to
+                ``Learner.ddp_rank``. Defaults to ``None``.
+            world_size: The world size. If ``None``, will be set to
+                ``Learner.ddp_world_size``. Defaults to ``None``.
+
+        Raises:
+            ValueError: If ``rank`` or ``world_size`` not provided and aren't
+                set on the :class:`Learner`.
+        """
+        self.learner = learner
+        self.rank = learner.ddp_rank if rank is None else rank
+        self.world_size = (learner.ddp_world_size
+                           if world_size is None else world_size)
+        if self.rank is None or self.world_size is None:
+            raise ValueError('Could not determine rank and world_size.')
+        self.noop = dist.is_initialized()
+
+    def __enter__(self) -> Any:
+        if self.noop:
+            return
+
+        learner = self.learner
+        rank = self.rank
+        world_size = self.world_size
+
+        os.environ['MASTER_ADDR'] = os.environ.get('MASTER_ADDR', 'localhost')
+        os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '12355')
+
+        log.debug('Calling init_process_group()')
+        dist.init_process_group(DDP_BACKEND, rank=rank, world_size=world_size)
+
+        if learner.ddp_rank is None:
+            learner.ddp_rank = rank
+        if learner.ddp_local_rank is None:
+            # Implies process was spawned by learner.train(), and therefore,
+            # this is necessarily a single-node multi-GPU scenario.
+            # So global rank == local rank.
+            learner.ddp_local_rank = rank
+        if learner.ddp_world_size is None:
+            learner.ddp_world_size = world_size
+
+        log.info('DDP rank: %d, DDP local rank: %d', learner.ddp_rank,
+                 learner.ddp_local_rank)
+
+        learner.is_ddp_process = True
+        learner.is_ddp_master = learner.ddp_rank == 0
+        learner.is_ddp_local_master = learner.ddp_local_rank == 0
+
+        learner.device = torch.device(learner.device.type,
+                                      learner.ddp_local_rank)
+        torch.cuda.set_device(learner.device)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        if self.noop:
+            return
+        dist.barrier()
+        dist.destroy_process_group()
+        gc.collect()
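
A minimal usage sketch of the DDP entry points introduced by this patch (illustration only, not part of the diff). It assumes an already-constructed ``Learner`` with ``training=True`` and DDP enabled via the ``RASTERVISION_USE_DDP``, ``WORLD_SIZE``, ``RANK``, and ``LOCAL_RANK`` environment variables; all attributes and methods referenced below (``main``, ``_main``, ``ddp``, ``distributed``) are defined in the patch itself.

.. code-block:: python

    def run_training(learner) -> None:
        # main() decides internally whether to wrap the run in self.ddp(),
        # so callers need no DDP-specific code.
        learner.main()

    def run_training_explicit(learner) -> None:
        # Lower-level equivalent of main(): wrap the parallel section in the
        # DDP context manager. On entry it calls dist.init_process_group()
        # and fills in ddp_rank / ddp_local_rank / device on the learner; on
        # exit it destroys the process group. Both steps are no-ops if a
        # process group is already initialized.
        if learner.distributed:
            with learner.ddp():
                learner._main()
        else:
            learner._main()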