From 7024177f7d4163e2af57c7dfa5550585748cbc0d Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Wed, 22 Apr 2020 14:39:08 -0700 Subject: [PATCH] Added Horovod distributed backend (#1529) * Initial commit of Horovod distributed backend implementation * Update distrib_data_parallel.py * Update distrib_data_parallel.py * Update tests/models/test_horovod.py Co-Authored-By: Jirka Borovec * Update tests/models/test_horovod.py Co-Authored-By: Jirka Borovec * Fixed tests * Added six * tests * Install tox for GitHub CI * Retry tests * Catch all exceptions * Skip cache * Remove tox * Restore pip cache * Remove the cache * Restore pip cache * Remove AMP Co-authored-by: William Falcon Co-authored-by: Jirka Borovec Co-authored-by: J. Borovec --- .circleci/config.yml | 1 + .drone.yml | 13 +- .github/workflows/ci-testing.yml | 11 +- CHANGELOG.md | 1 + docs/source/multi_gpu.rst | 38 ++++ pytorch_lightning/trainer/data_loading.py | 18 +- .../trainer/distrib_data_parallel.py | 35 +++- pytorch_lightning/trainer/distrib_parts.py | 69 ++++++++ pytorch_lightning/trainer/evaluation_loop.py | 14 ++ pytorch_lightning/trainer/trainer.py | 3 + pytorch_lightning/trainer/training_loop.py | 24 ++- requirements-extra.txt | 3 +- tests/README.md | 1 + tests/base/models.py | 138 +++++++++++++++ tests/base/utils.py | 36 ++-- .../data/horovod/train_default_model.py | 55 ++++++ tests/models/test_horovod.py | 163 ++++++++++++++++++ tests/requirements.txt | 1 - 18 files changed, 597 insertions(+), 27 deletions(-) create mode 100644 tests/models/data/horovod/train_default_model.py create mode 100644 tests/models/test_horovod.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 1be2cf2e2717c..c005d5bb2165f 100755 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -10,6 +10,7 @@ references: run: name: Install Dependences command: | + sudo apt-get update && sudo apt-get install -y cmake pip install "$TORCH_VERSION" pip install -r requirements.txt -q sudo pip install pytest pytest-cov pytest-flake8 -q diff --git a/.drone.yml b/.drone.yml index 90c8873b77945..7c30e8052533c 100644 --- a/.drone.yml +++ b/.drone.yml @@ -6,12 +6,19 @@ name: torch-GPU steps: - name: testing - image: pytorch/pytorch:1.4-cuda10.1-cudnn7-runtime + image: pytorch/pytorch:1.4-cuda10.1-cudnn7-devel environment: SLURM_LOCALID: 0 CODECOV_TOKEN: from_secret: codecov_token + HOROVOD_GPU_ALLREDUCE: NCCL + HOROVOD_GPU_BROADCAST: NCCL + HOROVOD_WITH_PYTORCH: 1 + HOROVOD_WITHOUT_TENSORFLOW: 1 + HOROVOD_WITHOUT_MXNET: 1 + HOROVOD_WITH_GLOO: 1 + HOROVOD_WITHOUT_MPI: 1 #volumes: # # Mount pip cache from host @@ -19,11 +26,13 @@ steps: # path: /opt/conda/lib/python3.7/site-packages commands: + - export PATH="$PATH:/root/.local/bin" - python --version - pip install pip -U - pip --version - nvidia-smi - - bash ./tests/install_AMP.sh +# - bash ./tests/install_AMP.sh + - apt-get update && apt-get install -y cmake - pip install -r requirements.txt --user -q - pip install coverage pytest pytest-cov pytest-flake8 codecov -q - pip install -r ./tests/requirements.txt --user -q diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index ba438fda6152a..defc84e885df1 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -41,9 +41,15 @@ jobs: if: runner.os == 'macOS' run: | brew install libomp # https://github.com/pytorch/pytorch/issues/20030 + brew install openmpi # Horovod on macOS requires OpenMPI, Gloo not currently supported - # TODO: remove after https://github.com/pytorch/pytorch/issues/32186 is resolved - name: Setup Windows + if: runner.os == 'windows' + run: | + python -c "lines = [line for line in open('requirements-extra.txt').readlines() if not line.startswith('horovod')] ; open('requirements-extra.txt', 'w').writelines(lines)" + + # TODO: remove after https://github.com/pytorch/pytorch/issues/32186 is resolved + - name: Setup Windows on Latest if: runner.os == 'windows' && matrix.requires == 'latest' run: | python -c "req = open('requirements.txt').read().replace('torch>=1.1', 'torch<1.5') ; open('requirements.txt', 'w').write(req)" @@ -75,11 +81,12 @@ jobs: run: | # python -m pip install --upgrade --user pip pip install -r requirements.txt -U -f https://download.pytorch.org/whl/torch_stable.html -q - pip install -r ./tests/requirements.txt -q + HOROVOD_BUILD_ARCH_FLAGS="-mfma" pip install -r ./tests/requirements.txt -q # pip install tox coverage python --version pip --version pip list + shell: bash - name: Tests # env: diff --git a/CHANGELOG.md b/CHANGELOG.md index 7529a35c3ef31..0dced0c148cdd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `ddp_cpu` backend for testing ddp without GPUs ([#1158](https://github.com/PyTorchLightning/pytorch-lightning/pull/1158)) +- Added [Horovod](http://horovod.ai) support as a distributed backend `Trainer(distributed_backend='horovod')` ([#1529](https://github.com/PyTorchLightning/pytorch-lightning/pull/1529)) ### Changed diff --git a/docs/source/multi_gpu.rst b/docs/source/multi_gpu.rst index 6b8f15b736443..ab6669e62f9f7 100644 --- a/docs/source/multi_gpu.rst +++ b/docs/source/multi_gpu.rst @@ -76,6 +76,7 @@ Lightning allows multiple ways of training - Data Parallel (`distributed_backend='dp'`) (multiple-gpus, 1 machine) - DistributedDataParallel (`distributed_backend='ddp'`) (multiple-gpus across many machines). - DistributedDataParallel2 (`distributed_backend='ddp2'`) (dp in a machine, ddp across machines). +- Horovod (`distributed_backend='horovod'`) (multi-machine, multi-gpu, configured at runtime) - TPUs (`num_tpu_cores=8|x`) (tpu or TPU pod) Data Parallel (dp) @@ -136,6 +137,43 @@ In this case, we can use ddp2 which behaves like dp in a machine and ddp across # train on 32 GPUs (4 nodes) trainer = pl.Trainer(gpus=8, distributed_backend='ddp2', num_nodes=4) +Horovod +^^^^^^^ +`Horovod `_ allows the same training script to be used for single-GPU, +multi-GPU, and multi-node training. + +Like Distributed Data Parallel, every process in Horovod operates on a single GPU with a fixed +subset of the data. Gradients are averaged across all GPUs in parallel during the backward pass, +then synchronously applied before beginning the next step. + +The number of worker processes is configured by a driver application (`horovodrun` or `mpirun`). In +the training script, Horovod will detect the number of workers from the environment, and automatically +scale the learning rate to compensate for the increased total batch size. + +Horovod can be configured in the training script to run with any number of GPUs / processes as follows: + +.. code-block:: python + + # train Horovod on GPU (number of GPUs / machines provided on command-line) + trainer = pl.Trainer(distributed_backend='horovod', gpus=1) + + # train Horovod on CPU (number of processes / machines provided on command-line) + trainer = pl.Trainer(distributed_backend='horovod') + +When starting the training job, the driver application will then be used to specify the total +number of worker processes: + +.. code-block:: bash + + # run training with 4 GPUs on a single machine + horovodrun -np 4 python train.py + + # run training with 8 GPUs on two machines (4 GPUs each) + horovodrun -np 8 -H hostname1:4,hostname2:4 python train.py + +See the official `Horovod documentation `_ for details +on installation and performance tuning. + DP/DDP2 caveats ^^^^^^^^^^^^^^^ In DP and DDP2 each GPU within a machine sees a portion of a batch. diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 5b07d26d4dd62..2dde01375c371 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -26,6 +26,13 @@ else: XLA_AVAILABLE = True +try: + import horovod.torch as hvd +except ImportError: + HOROVOD_AVAILABLE = False +else: + HOROVOD_AVAILABLE = True + def _has_len(dataloader: DataLoader) -> bool: """ Checks if a given Dataloader has __len__ method implemented i.e. if @@ -47,6 +54,7 @@ class TrainerDataLoadingMixin(ABC): proc_rank: int use_ddp: bool use_ddp2: bool + use_horovod: bool shown_warnings: ... val_check_interval: float use_tpu: bool @@ -89,7 +97,7 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader: # don't do anything if it's not a dataloader if not isinstance(dataloader, DataLoader): return dataloader - need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_tpu) + need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu) if self.replace_sampler_ddp and need_dist_sampler: skip_keys = ['sampler', 'batch_sampler', 'dataset_kind'] @@ -104,6 +112,10 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader: num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), ) + elif self.use_horovod: + sampler = DistributedSampler(dataloader.dataset, + num_replicas=hvd.size(), + rank=hvd.rank()) else: world_size = { 'ddp': self.num_nodes * self.num_processes, @@ -254,6 +266,10 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader: # all processes wait until data download has happened torch_xla.core.xla_model.rendezvous('pl.TrainerDataLoadingMixin.get_dataloaders') + elif self.use_horovod: + # all processes wait until data download has happened + hvd.join() + return dataloader def determine_data_use_amount(self, train_percent_check: float, val_percent_check: float, diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py index dadff04a1154d..bfc85ee883f6e 100644 --- a/pytorch_lightning/trainer/distrib_data_parallel.py +++ b/pytorch_lightning/trainer/distrib_data_parallel.py @@ -131,6 +131,13 @@ def train_fx(trial_hparams, cluster_manager, _): else: APEX_AVAILABLE = True +try: + import horovod.torch as hvd +except ImportError: + HOROVOD_AVAILABLE = False +else: + HOROVOD_AVAILABLE = True + class TrainerDDPMixin(ABC): @@ -178,10 +185,14 @@ def set_distributed_mode(self, distributed_backend): self.use_dp = False self.use_ddp = False self.use_ddp2 = False + self.use_horovod = False self.single_gpu = False if distributed_backend is None: - if self.num_gpus == 0: + if self.has_horovodrun(): + self.check_horovod() + self.use_horovod = True + elif self.num_gpus == 0: if self.num_nodes > 1 or self.num_processes > 1: self.use_ddp = True # ddp_cpu elif self.num_gpus == 1: @@ -219,6 +230,9 @@ def set_distributed_mode(self, distributed_backend): self.use_ddp = True self.data_parallel_device_ids = None self.on_gpu = False + elif distributed_backend == 'horovod': + self.check_horovod() + self.use_horovod = True # throw error to force user ddp or ddp2 choice if self.num_nodes > 1 and not (self.use_ddp2 or self.use_ddp): @@ -402,3 +416,22 @@ def resolve_root_node_address(self, root_node): root_node = name + number return root_node + + def check_horovod(self): + """Raises a `MisconfigurationException` if the Trainer is not configured correctly for Horovod.""" + if not HOROVOD_AVAILABLE: + raise MisconfigurationException( + 'Requested `distributed_backend="horovod"`, but Horovod is not installed.' + 'Install with \n $HOROVOD_WITH_PYTORCH=1 pip install horovod[pytorch]' + ) + + if self.num_gpus > 1 or self.num_nodes > 1: + raise MisconfigurationException( + 'Horovod does not support setting num_nodes / num_gpus explicitly. Use ' + 'horovodrun / mpirun to configure the number of processes.' + ) + + @staticmethod + def has_horovodrun(): + """Returns True if running with `horovodrun` using Gloo or OpenMPI.""" + return 'OMPI_COMM_WORLD_RANK' in os.environ or 'HOROVOD_RANK' in os.environ diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 5e26d49902eeb..7ce61bbfb77e6 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -337,13 +337,16 @@ """ +from contextlib import ExitStack import os from abc import ABC, abstractmethod import time import random import torch +from typing import Union from pytorch_lightning import _logger as log +from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.overrides.data_parallel import ( LightningDistributedDataParallel, LightningDataParallel, @@ -365,6 +368,13 @@ else: XLA_AVAILABLE = True +try: + import horovod.torch as hvd +except ImportError: + HOROVOD_AVAILABLE = False +else: + HOROVOD_AVAILABLE = True + class TrainerDPMixin(ABC): @@ -385,6 +395,7 @@ class TrainerDPMixin(ABC): tpu_global_core_rank: int use_tpu: bool data_parallel_device_ids: ... + logger: Union[LightningLoggerBase, bool] @property @abstractmethod @@ -540,6 +551,64 @@ def dp_train(self, model): self.run_pretrain_routine(model) + def horovod_train(self, model): + # Horovod: initialize library + hvd.init() + + if torch.cuda.is_available() and self.on_gpu: + # Horovod: pin GPU to local rank + torch.cuda.set_device(hvd.local_rank()) + model.cuda(hvd.local_rank()) + + # Only show progress bar from the first worker + self.progress_bar_refresh_rate = self.progress_bar_refresh_rate if hvd.rank() == 0 else 0 + + # CHOOSE OPTIMIZER + # allow for lr schedulers as well + self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model) + + # Horovod: scale the learning rate by the number of workers to account for + # increased total batch size + for optimizer in self.optimizers: + for param_group in optimizer.param_groups: + param_group['lr'] *= hvd.size() + + if self.use_amp: + # An example + model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level) + self.optimizers = optimizers + + # Horovod: broadcast parameters & optimizer state to ensure consistent initialization + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + for optimizer in self.optimizers: + hvd.broadcast_optimizer_state(optimizer, root_rank=0) + + def filter_named_parameters(model, optimizer): + opt_params = set([p for group in optimizer.param_groups for p in group.get('params', [])]) + return [(name, p) for name, p in model.named_parameters() if p in opt_params] + + # Horovod: wrap optimizers to perform gradient aggregation via allreduce + self.optimizers = [ + hvd.DistributedOptimizer(optimizer, named_parameters=filter_named_parameters(model, optimizer)) + for optimizer in self.optimizers + ] + + # Update logger rank info from Horovod to avoid race conditions from different ranks + # creating directories / writing files in the same locations. + self.proc_rank = hvd.rank() + set_proc_rank(self.proc_rank) + if self.logger: + self.logger.rank = self.proc_rank + if model.logger: + model.logger.rank = self.proc_rank + + with ExitStack() as stack: + for optimizer in self.optimizers: + # Synchronization will be performed explicitly following backward() + stack.enter_context(optimizer.skip_synchronize()) + + self.run_pretrain_routine(model) + def normalize_parse_gpu_string_input(s): if isinstance(s, str): diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 862fc5948e806..a996bd7a60d70 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -145,6 +145,13 @@ else: XLA_AVAILABLE = True +try: + import horovod.torch as hvd +except ImportError: + HOROVOD_AVAILABLE = False +else: + HOROVOD_AVAILABLE = True + class TrainerEvaluationLoopMixin(ABC): @@ -153,9 +160,11 @@ class TrainerEvaluationLoopMixin(ABC): test_progress_bar: ... val_progress_bar: ... main_progress_bar: ... + on_gpu: bool use_ddp: bool use_dp: bool use_ddp2: bool + use_horovod: bool single_gpu: bool data_parallel_device_ids: ... model: LightningModule @@ -429,6 +438,11 @@ def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: output = model(*args) return output + # Horovod + if self.use_horovod and self.on_gpu: + batch = self.transfer_batch_to_gpu(batch, hvd.local_rank()) + args[0] = batch + # single GPU data transfer if self.single_gpu: # for single GPU put inputs on gpu manually diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4a27069d10761..892ee7af57331 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -745,6 +745,9 @@ def fit( elif self.use_dp: self.dp_train(model) + elif self.use_horovod: + self.horovod_train(model) + elif self.single_gpu: self.single_gpu_train(model) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index ed0604c74054d..5b3d13c72b5f1 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -174,6 +174,13 @@ def training_step(self, batch, batch_idx): else: XLA_AVAILABLE = True +try: + import horovod.torch as hvd +except ImportError: + HOROVOD_AVAILABLE = False +else: + HOROVOD_AVAILABLE = True + class TrainerTrainLoopMixin(ABC): @@ -181,9 +188,11 @@ class TrainerTrainLoopMixin(ABC): # the proper values/initialisation should be done in child class max_epochs: int min_epochs: int + on_gpu: bool use_ddp: bool use_dp: bool use_ddp2: bool + use_horovod: bool single_gpu: bool use_tpu: bool data_parallel_device_ids: ... @@ -324,7 +333,7 @@ def train(self): if self.reload_dataloaders_every_epoch: self.reset_train_dataloader(model) # set seed for distributed sampler (enables shuffling for each epoch) - if self.use_ddp \ + if self.use_ddp or self.use_horovod \ and hasattr(self.train_dataloader.sampler, 'set_epoch'): self.train_dataloader.sampler.set_epoch(epoch) @@ -506,6 +515,9 @@ def run_training_epoch(self): if early_stop_epoch or self.fast_dev_run: break + if self.use_horovod: + hvd.join(hvd.local_rank() if self.on_gpu else -1) + # process epoch outputs model = self.get_model() if self.is_overriden('training_epoch_end', model=model): @@ -600,6 +612,10 @@ def optimizer_closure(): self.add_tqdm_metrics(progress_bar_metrics) all_log_metrics.append(log_metrics) + if self.use_horovod: + # Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid + optimizer.synchronize() + # insert after step hook if self.is_function_implemented('on_after_backward'): model_ref = self.get_model() @@ -727,6 +743,12 @@ def training_forward(self, batch, batch_idx, opt_idx, hiddens): if self.use_ddp or self.use_ddp2 or self.use_dp: output = self.model(*args) + # Horovod + elif self.use_horovod and self.on_gpu: + batch = self.transfer_batch_to_gpu(copy.copy(batch), hvd.local_rank()) + args[0] = batch + output = self.model.training_step(*args) + # single GPU forward elif self.single_gpu: gpu_id = 0 diff --git a/requirements-extra.txt b/requirements-extra.txt index aadce75abea5f..c47e130db0fd8 100644 --- a/requirements-extra.txt +++ b/requirements-extra.txt @@ -6,4 +6,5 @@ mlflow>=1.0.0 test_tube>=0.7.5 wandb>=0.8.21 trains>=0.14.1 -matplotlib>=3.1.1 \ No newline at end of file +matplotlib>=3.1.1 +horovod[pytorch]>=0.19.1 \ No newline at end of file diff --git a/tests/README.md b/tests/README.md index 0773c717d97eb..dd35709d8aeff 100644 --- a/tests/README.md +++ b/tests/README.md @@ -27,6 +27,7 @@ To test models that require GPU make sure to run the above command on a GPU mach The GPU machine must have: 1. At least 2 GPUs. 2. [NVIDIA-apex](https://github.com/NVIDIA/apex#linux) installed. +3. [Horovod with NCCL](https://horovod.readthedocs.io/en/stable/gpus_include.html) support: `HOROVOD_GPU_ALLREDUCE=NCCL HOROVOD_GPU_BROADCAST=NCCL pip install horovod` ## Running Coverage diff --git a/tests/base/models.py b/tests/base/models.py index c130325c2fde5..66cf7ca2beda9 100644 --- a/tests/base/models.py +++ b/tests/base/models.py @@ -1,6 +1,7 @@ from collections import OrderedDict from typing import Dict +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -152,3 +153,140 @@ def _dataloader(self, train): ) return loader + + +class Generator(nn.Module): + def __init__(self, latent_dim, img_shape): + super().__init__() + self.img_shape = img_shape + + def block(in_feat, out_feat, normalize=True): + layers = [nn.Linear(in_feat, out_feat)] + if normalize: + layers.append(nn.BatchNorm1d(out_feat, 0.8)) + layers.append(nn.LeakyReLU(0.2, inplace=True)) + return layers + + self.model = nn.Sequential( + *block(latent_dim, 128, normalize=False), + *block(128, 256), + *block(256, 512), + *block(512, 1024), + nn.Linear(1024, int(np.prod(img_shape))), + nn.Tanh() + ) + + def forward(self, z): + img = self.model(z) + img = img.view(img.size(0), *self.img_shape) + return img + + +class Discriminator(nn.Module): + def __init__(self, img_shape): + super().__init__() + + self.model = nn.Sequential( + nn.Linear(int(np.prod(img_shape)), 512), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(512, 256), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(256, 1), + nn.Sigmoid(), + ) + + def forward(self, img): + img_flat = img.view(img.size(0), -1) + validity = self.model(img_flat) + + return validity + + +class TestGAN(LightningModule): + """Implements a basic GAN for the purpose of illustrating multiple optimizers.""" + + def __init__(self, hparams): + super().__init__() + self.hparams = hparams + + # networks + mnist_shape = (1, 28, 28) + self.generator = Generator(latent_dim=hparams.hidden_dim, img_shape=mnist_shape) + self.discriminator = Discriminator(img_shape=mnist_shape) + + # cache for generated images + self.generated_imgs = None + self.last_imgs = None + + def forward(self, z): + return self.generator(z) + + def adversarial_loss(self, y_hat, y): + return F.binary_cross_entropy(y_hat, y) + + def training_step(self, batch, batch_idx, optimizer_idx=None): + imgs, _ = batch + self.last_imgs = imgs + + # train generator + if optimizer_idx == 0: + # sample noise + z = torch.randn(imgs.shape[0], self.hparams.hidden_dim) + z = z.type_as(imgs) + + # generate images + self.generated_imgs = self(z) + + # ground truth result (ie: all fake) + # put on GPU because we created this tensor inside training_loop + valid = torch.ones(imgs.size(0), 1) + valid = valid.type_as(imgs) + + # adversarial loss is binary cross-entropy + g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid) + tqdm_dict = {'g_loss': g_loss} + output = OrderedDict({ + 'loss': g_loss, + 'progress_bar': tqdm_dict, + 'log': tqdm_dict + }) + return output + + # train discriminator + if optimizer_idx == 1: + # Measure discriminator's ability to classify real from generated samples + + # how well can it label as real? + valid = torch.ones(imgs.size(0), 1) + valid = valid.type_as(imgs) + + real_loss = self.adversarial_loss(self.discriminator(imgs), valid) + + # how well can it label as fake? + fake = torch.zeros(imgs.size(0), 1) + fake = fake.type_as(fake) + + fake_loss = self.adversarial_loss( + self.discriminator(self.generated_imgs.detach()), fake) + + # discriminator loss is the average of these + d_loss = (real_loss + fake_loss) / 2 + tqdm_dict = {'d_loss': d_loss} + output = OrderedDict({ + 'loss': d_loss, + 'progress_bar': tqdm_dict, + 'log': tqdm_dict + }) + return output + + def configure_optimizers(self): + lr = self.hparams.learning_rate + b1 = self.hparams.b1 + b2 = self.hparams.b2 + + opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2)) + opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2)) + return [opt_g, opt_d], [] + + def train_dataloader(self): + return DataLoader(TrialMNIST(train=True, download=True), batch_size=16) diff --git a/tests/base/utils.py b/tests/base/utils.py index 6d3be0de3ff66..1bb485f270ada 100644 --- a/tests/base/utils.py +++ b/tests/base/utils.py @@ -55,20 +55,17 @@ def run_model_test_without_loggers(trainer_options, model, min_acc=0.50): trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers() -def run_model_test(trainer_options, model, on_gpu=True): +def run_model_test(trainer_options, model, on_gpu=True, version=None, with_hpc=True): save_dir = trainer_options['default_root_dir'] # logger file to get meta - logger = get_default_logger(save_dir) + logger = get_default_logger(save_dir, version=version) + trainer_options.update(logger=logger) - # logger file to get weights - checkpoint = init_checkpoint_callback(logger) - - # add these to the trainer options - trainer_options.update( - checkpoint_callback=checkpoint, - logger=logger, - ) + if 'checkpoint_callback' not in trainer_options: + # logger file to get weights + checkpoint = init_checkpoint_callback(logger) + trainer_options.update(checkpoint_callback=checkpoint) # fit model trainer = Trainer(**trainer_options) @@ -87,15 +84,16 @@ def run_model_test(trainer_options, model, on_gpu=True): [run_prediction(dataloader, pretrained_model) for dataloader in test_loaders] - if trainer.use_ddp or trainer.use_ddp2: - # on hpc this would work fine... but need to hack it for the purpose of the test - trainer.model = pretrained_model - trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = \ - trainer.init_optimizers(pretrained_model) + if with_hpc: + if trainer.use_ddp or trainer.use_ddp2: + # on hpc this would work fine... but need to hack it for the purpose of the test + trainer.model = pretrained_model + trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = \ + trainer.init_optimizers(pretrained_model) - # test HPC loading / saving - trainer.hpc_save(save_dir, logger) - trainer.hpc_load(save_dir, on_gpu=on_gpu) + # test HPC loading / saving + trainer.hpc_save(save_dir, logger) + trainer.hpc_load(save_dir, on_gpu=on_gpu) def get_default_hparams(continue_training=False, hpc_exp_number=0): @@ -110,6 +108,8 @@ def get_default_hparams(continue_training=False, hpc_exp_number=0): 'data_root': PATH_DATASETS, 'out_features': 10, 'hidden_dim': 1000, + 'b1': 0.5, + 'b2': 0.999, } if continue_training: diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py new file mode 100644 index 0000000000000..6c11e2ca5e755 --- /dev/null +++ b/tests/models/data/horovod/train_default_model.py @@ -0,0 +1,55 @@ +""" +This script is meant to be executed from `../../test_horovod.py`. + +Because Horovod uses a parallel programming model similar to MPI, unit tests for collective +ops like allreduce need to be run in parallel. The most common approach for running parallel +Horovod workers is to launch multiple replicas of the training script via the `horovodrun` +command-line tool: + +.. code-block:: bash + + horovodrun -np 2 python train_default_model.py ... + +Individual test parameters are configured by the serialized `--trainer-options` JSON object. + +An non-zero exit code from this script on any rank will indicate failure, while a zero exit code +across all ranks indicates success. +""" + +import argparse +import json +import os +import sys + +import horovod.torch as hvd + +PATH_HERE = os.path.abspath(os.path.dirname(__file__)) +PATH_ROOT = os.path.join(PATH_HERE, '..', '..', '..', '..') +sys.path.insert(0, os.path.abspath(PATH_ROOT)) + +from pytorch_lightning.callbacks import ModelCheckpoint # noqa: E402 +import tests.base.utils as tutils # noqa: E402 + + +parser = argparse.ArgumentParser() +parser.add_argument('--trainer-options', required=True) + + +def run_test_from_config(trainer_options): + """Trains the default model with the given config.""" + tutils.reset_seed() + tutils.set_random_master_port() + + ckpt_path = trainer_options['default_root_dir'] + trainer_options['checkpoint_callback'] = ModelCheckpoint(ckpt_path) + + model, hparams = tutils.get_default_model() + tutils.run_model_test(trainer_options, model, version=0, with_hpc=False) + + # Horovod should be initialized following training. If not, this will raise an exception. + assert hvd.size() == 2 + + +if __name__ == "__main__": + args = parser.parse_args() + run_test_from_config(json.loads(args.trainer_options)) diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py new file mode 100644 index 0000000000000..c4bcb4b81b995 --- /dev/null +++ b/tests/models/test_horovod.py @@ -0,0 +1,163 @@ +import json +import os +import platform +import shlex +import subprocess +import sys + +import pytest +import torch + +from pytorch_lightning import Trainer + +import tests.base.utils as tutils +from tests.base import LightningTestModel +from tests.base.models import TestGAN + +try: + from horovod.common.util import nccl_built +except ImportError: + HOROVOD_AVAILABLE = False +else: + HOROVOD_AVAILABLE = True + + +# This script will run the actual test model training in parallel +TEST_SCRIPT = os.path.join(os.path.dirname(__file__), 'data', 'horovod', 'train_default_model.py') + + +def _nccl_available(): + if not HOROVOD_AVAILABLE: + return False + + try: + return nccl_built() + except AttributeError: + # Horovod 0.19.1 nccl_built() does not yet work with Python 3.8: + # See: https://github.com/horovod/horovod/issues/1891 + return False + + +def _run_horovod(trainer_options): + """Execute the training script across multiple workers in parallel.""" + cmdline = ['horovodrun', '-np', '2', sys.executable, TEST_SCRIPT, + '--trainer-options', shlex.quote(json.dumps(trainer_options))] + exit_code = subprocess.call(' '.join(cmdline), shell=True, env=os.environ.copy()) + assert exit_code == 0 + + +@pytest.mark.skipif(sys.version_info >= (3, 8), reason="Horovod not yet supported in Python 3.8") +@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") +def test_horovod_cpu(tmpdir): + """Test Horovod running multi-process on CPU.""" + trainer_options = dict( + default_root_dir=str(tmpdir), + gradient_clip_val=1.0, + progress_bar_refresh_rate=0, + max_epochs=1, + train_percent_check=0.4, + val_percent_check=0.2, + distributed_backend='horovod' + ) + _run_horovod(trainer_options) + + +@pytest.mark.skipif(sys.version_info >= (3, 8), reason="Horovod not yet supported in Python 3.8") +@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") +def test_horovod_cpu_implicit(tmpdir): + """Test Horovod without specifying a backend, inferring from env set by `horovodrun`.""" + trainer_options = dict( + default_root_dir=str(tmpdir), + gradient_clip_val=1.0, + progress_bar_refresh_rate=0, + max_epochs=1, + train_percent_check=0.4, + val_percent_check=0.2, + ) + _run_horovod(trainer_options) + + +@pytest.mark.skipif(sys.version_info >= (3, 8), reason="Horovod not yet supported in Python 3.8") +@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") +@pytest.mark.skipif(not _nccl_available(), reason="test requires Horovod with NCCL support") +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_horovod_multi_gpu(tmpdir): + """Test Horovod with multi-GPU support.""" + trainer_options = dict( + default_root_dir=str(tmpdir), + gradient_clip_val=1.0, + progress_bar_refresh_rate=0, + max_epochs=1, + train_percent_check=0.4, + val_percent_check=0.2, + gpus=1, + distributed_backend='horovod' + ) + _run_horovod(trainer_options) + + +@pytest.mark.skipif(sys.version_info >= (3, 8), reason="Horovod not yet supported in Python 3.8") +@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") +@pytest.mark.skipif(not _nccl_available(), reason="test requires Horovod with NCCL support") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") +def test_horovod_transfer_batch_to_gpu(tmpdir): + class TestTrainingStepModel(LightningTestModel): + def training_step(self, batch, *args, **kwargs): + x, y = batch + assert str(x.device) != 'cpu' + assert str(y.device) != 'cpu' + return super(TestTrainingStepModel, self).training_step(batch, *args, **kwargs) + + def validation_step(self, batch, *args, **kwargs): + x, y = batch + assert str(x.device) != 'cpu' + assert str(y.device) != 'cpu' + return super(TestTrainingStepModel, self).validation_step(batch, *args, **kwargs) + + hparams = tutils.get_default_hparams() + model = TestTrainingStepModel(hparams) + + trainer_options = dict( + default_root_dir=str(tmpdir), + progress_bar_refresh_rate=0, + max_epochs=1, + train_percent_check=0.4, + val_percent_check=0.2, + gpus=1, + distributed_backend='horovod' + ) + tutils.run_model_test_without_loggers(trainer_options, model) + + +@pytest.mark.skipif(sys.version_info >= (3, 8), reason="Horovod not yet supported in Python 3.8") +@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") +def test_horovod_multi_optimizer(tmpdir): + hparams = tutils.get_default_hparams() + model = TestGAN(hparams) + + trainer_options = dict( + default_root_dir=str(tmpdir), + progress_bar_refresh_rate=0, + max_epochs=1, + train_percent_check=0.4, + val_percent_check=0.2, + distributed_backend='horovod' + ) + + # fit model + trainer = Trainer(**trainer_options) + result = trainer.fit(model) + assert result == 1, 'model failed to complete' + + assert len(trainer.optimizers) == 2 + for i, optimizer in enumerate(trainer.optimizers): + assert hasattr(optimizer, 'synchronize'), 'optimizer has not been wrapped into DistributedOptimizer' + + def get_model_params(model): + return set([p for p in model.parameters()]) + + def get_optimizer_params(optimizer): + return set([p for group in optimizer.param_groups for p in group.get('params', [])]) + + assert get_model_params(model.generator) == get_optimizer_params(trainer.optimizers[0]) + assert get_model_params(model.discriminator) == get_optimizer_params(trainer.optimizers[1]) diff --git a/tests/requirements.txt b/tests/requirements.txt index e93ccb51d79c9..d579c97fb3efe 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -2,7 +2,6 @@ -r ../requirements-extra.txt # extended list of dependencies dor development and run lint and tests -tox coverage codecov pytest>=3.0.5