diff --git a/CHANGELOG.md b/CHANGELOG.md
index ff333c0fcabb8..7313946d7b313 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,10 +9,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Added
 
 - Added `auto_select_gpus` flag to trainer that enables automatic selection of available GPUs on exclusive mode systems.
+
+- Added learning rate finder ([#1347](https://github.com/PyTorchLightning/pytorch-lightning/pull/1347))
+
+- Added support for ddp mode in clusters without SLURM ([#1345](https://github.com/PyTorchLightning/pytorch-lightning/issues/1345))
+
 - Added learning rate finder ([#1347](https://github.com/PyTorchLightning/pytorch-lightning/pull/1347))
 
 - Added `terminate_on_nan` flag to trainer that performs a NaN check with each training iteration when set to `True`. ([#1475](https://github.com/PyTorchLightning/pytorch-lightning/pull/1475))
+
 ### Changed
 
 - Changed the default behaviour to no longer include a NaN check with each training iteration. ([#1475](https://github.com/PyTorchLightning/pytorch-lightning/pull/1475))
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 71ea357ea3944..f0c6a63d93dac 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -873,53 +873,10 @@ def configure_ddp(self, model, device_ids):
         )
         return model
 
-    def init_ddp_connection(self, proc_rank: int, world_size: int) -> None:
-        r"""
-        Override to define your custom way of setting up a distributed environment.
-
-        Lightning's implementation uses ``env://`` init by default and sets the first node as root.
-
-        Args:
-            proc_rank: The current process rank within the node.
-            world_size: Number of GPUs being use across all nodes (num_nodes * num_gpus).
-
-        Examples:
-            .. code-block:: python
-
-                def init_ddp_connection(self):
-                    # use slurm job id for the port number
-                    # guarantees unique ports across jobs from same grid search
-                    try:
-                        # use the last 4 numbers in the job id as the id
-                        default_port = os.environ['SLURM_JOB_ID']
-                        default_port = default_port[-4:]
-
-                        # all ports should be in the 10k+ range
-                        default_port = int(default_port) + 15000
-
-                    except Exception as e:
-                        default_port = 12910
-
-                    # if user gave a port number, use that one instead
-                    try:
-                        default_port = os.environ['MASTER_PORT']
-                    except Exception:
-                        os.environ['MASTER_PORT'] = str(default_port)
-
-                    # figure out the root node addr
-                    try:
-                        root_node = os.environ['SLURM_NODELIST'].split(' ')[0]
-                    except Exception:
-                        root_node = '127.0.0.2'
-
-                    root_node = self.trainer.resolve_root_node_address(root_node)
-                    os.environ['MASTER_ADDR'] = root_node
-                    dist.init_process_group(
-                        'nccl',
-                        rank=self.proc_rank,
-                        world_size=self.world_size
-                    )
-
+    def _init_slurm_connection(self) -> None:
+        """
+        Sets up environment variables necessary for PyTorch distributed communications
+        based on the SLURM environment.
         """
         # use slurm job id for the port number
         # guarantees unique ports across jobs from same grid search
@@ -948,6 +905,40 @@ def init_ddp_connection(self):
         root_node = self.trainer.resolve_root_node_address(root_node)
         os.environ['MASTER_ADDR'] = root_node
+
+    def init_ddp_connection(
+            self,
+            proc_rank: int,
+            world_size: int,
+            is_slurm_managing_tasks: bool = True
+    ) -> None:
+        """
+        Override to define your custom way of setting up a distributed environment.
+
+        Lightning's implementation uses ``env://`` init by default and sets the first node as root
+        for SLURM managed clusters.
+
+        Args:
+            proc_rank: The current process rank within the node.
+            world_size: Number of GPUs being used across all nodes (num_nodes * num_gpus).
+            is_slurm_managing_tasks: Whether the cluster is managed by SLURM.
+
+        """
+        if is_slurm_managing_tasks:
+            self._init_slurm_connection()
+
+        if 'MASTER_ADDR' not in os.environ:
+            log.warning("MASTER_ADDR environment variable is not defined. Set as localhost")
+            os.environ['MASTER_ADDR'] = '127.0.0.1'
+
+        if 'MASTER_PORT' not in os.environ:
+            log.warning("MASTER_PORT environment variable is not defined. Set as 12910")
+            os.environ['MASTER_PORT'] = '12910'
+
+        if 'WORLD_SIZE' in os.environ and os.environ['WORLD_SIZE'] != str(world_size):
+            log.warning("WORLD_SIZE environment variable is not equal to the computed "
+                        "world size. Ignored.")
+
         torch_backend = "nccl" if self.trainer.on_gpu else "gloo"
         torch_distrib.init_process_group(torch_backend, rank=proc_rank, world_size=world_size)
diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py
index fef107ab58533..d67ae262294b7 100644
--- a/pytorch_lightning/trainer/distrib_data_parallel.py
+++ b/pytorch_lightning/trainer/distrib_data_parallel.py
@@ -285,12 +285,13 @@ def ddp_train(self, process_idx, model):
         :param cluster_obj:
         :return:
         """
-        # node rank using relative slurm id
-        # otherwise default to node rank 0
+        # node rank using relative slurm id if under slurm management
+        # otherwise use given node rank or default to node rank 0
         try:
-            node_id = os.environ['SLURM_NODEID']
+            node_id = os.environ['SLURM_NODEID'] if self.is_slurm_managing_tasks else os.environ['NODE_RANK']
             self.node_rank = int(node_id)
-        except Exception:
+        except KeyError:
+            log.warning("SLURM_NODEID or NODE_RANK environment variable is not defined. Set as 0.")
             self.node_rank = 0
 
         # show progressbar only on progress_rank 0
@@ -317,7 +318,7 @@ def ddp_train(self, process_idx, model):
         # try to init for 20 times at max in case ports are taken
         # where to store ip_table
         model.trainer = self
-        model.init_ddp_connection(self.proc_rank, self.world_size)
+        model.init_ddp_connection(self.proc_rank, self.world_size, self.is_slurm_managing_tasks)
 
         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
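
With this patch, a cluster that is not managed by SLURM can drive ddp through plain environment variables: each node exports `MASTER_ADDR`/`MASTER_PORT` for the rank-0 node plus its own `NODE_RANK`, and `init_ddp_connection` falls back to `127.0.0.1:12910` with a warning when the address or port is missing. Below is a minimal sketch of such a launch; the toy module, the random tensors, and the address `192.168.1.10` are illustrative only, and the Trainer arguments assume the 0.7.x-era API (`gpus`, `num_nodes`, `distributed_backend`).

```python
import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class ToyModel(pl.LightningModule):
    """Hypothetical minimal module, only here to exercise the ddp launch path."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {'loss': F.cross_entropy(self(x), y)}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def train_dataloader(self):
        data = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        return DataLoader(data, batch_size=16)


if __name__ == '__main__':
    # Rendezvous settings shared by every node; without SLURM these must be
    # exported manually (or set here) before the trainer spawns ddp processes.
    os.environ['MASTER_ADDR'] = '192.168.1.10'  # address of the rank-0 node (example value)
    os.environ['MASTER_PORT'] = '12910'
    os.environ['NODE_RANK'] = '0'               # set to '1' when launching on the second node

    trainer = pl.Trainer(gpus=2, num_nodes=2, distributed_backend='ddp')
    trainer.fit(ToyModel())
```

If `NODE_RANK` is left unset, the new `KeyError` branch in `ddp_train` logs a warning and falls back to node rank 0, which is only correct for single-node runs.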
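Because `init_ddp_connection` is a user-overridable hook and its signature gained `is_slurm_managing_tasks`, custom overrides should accept the extra argument as well. The following is a hedged sketch of such an override that ignores SLURM and forces the `gloo` backend; the class name, the `MY_MASTER_ADDR` variable, and the port are made up for illustration, and the usual LightningModule methods are omitted.

```python
import os

import torch.distributed as torch_distrib

import pytorch_lightning as pl


class CustomInitModel(pl.LightningModule):
    """Illustrative skeleton; training_step, configure_optimizers, etc. are omitted."""

    def init_ddp_connection(self,
                            proc_rank: int,
                            world_size: int,
                            is_slurm_managing_tasks: bool = True) -> None:
        # Ignore SLURM entirely and read the rendezvous point from our own
        # (made-up) variable, defaulting to localhost for single-node debugging.
        os.environ.setdefault('MASTER_ADDR', os.environ.get('MY_MASTER_ADDR', '127.0.0.1'))
        os.environ.setdefault('MASTER_PORT', '29500')
        torch_distrib.init_process_group('gloo', rank=proc_rank, world_size=world_size)
```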