Added check for apex AMP and unit tests for Horovod + AMP (#3404)
* Added check for apex AMP and unit tests for Horovod + AMP

* Changelog

* Fixed order of Horovod and Apex optimizer wrapping
tgaddair authored Sep 9, 2020
1 parent aaf26d7 commit 091d37f
Showing 3 changed files with 30 additions and 7 deletions.
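
The third bullet in the commit message is the substantive fix: Apex AMP must be configured only after Horovod has wrapped the trainer's optimizers, as the horovod_backend.py diff below shows. For orientation, here is a minimal standalone sketch of that ordering outside Lightning; the model, optimizer, learning rate, and opt_level are placeholders rather than values taken from this commit.

```python
# Minimal sketch of the ordering this commit adopts: broadcast state,
# wrap the optimizer with Horovod, then initialize Apex AMP last.
import torch
import horovod.torch as hvd
from apex import amp

hvd.init()
torch.cuda.set_device(hvd.local_rank())

model = torch.nn.Linear(32, 2).cuda()                      # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01 * hvd.size())

# Horovod: broadcast parameters & optimizer state so every rank starts identically
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

# Horovod first: wrap the optimizer so step() allreduces gradients across workers
optimizer = hvd.DistributedOptimizer(
    optimizer, named_parameters=model.named_parameters()
)

# Apex second: initialize AMP against the already-wrapped optimizer,
# mirroring the new ordering in horovod_backend.py
model, optimizer = amp.initialize(model, optimizer, opt_level='O2')
```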
CHANGELOG.md (2 additions, 0 deletions)
@@ -39,6 +39,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed setting batch size in `LightningModule.datamodule` when using `auto_scale_batch_size` ([#3266](https://github.com/PyTorchLightning/pytorch-lightning/pull/3266))
 
+- Fixed Horovod distributed backend compatibility with native AMP ([#3404](https://github.com/PyTorchLightning/pytorch-lightning/pull/3404))
+
 ## [0.9.0] - YYYY-MM-DD
 
 ### Added
pytorch_lightning/accelerators/horovod_backend.py (5 additions, 5 deletions)
@@ -72,11 +72,6 @@ def setup(self, model):
             if isinstance(scheduler, _LRScheduler):
                 scheduler.base_lrs = [lr * hvd.size() for lr in scheduler.base_lrs]
 
-        if self.trainer.amp_backend:
-            model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
-            self.trainer.optimizers = optimizers
-            self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
-
         # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
         hvd.broadcast_parameters(model.state_dict(), root_rank=0)
         for optimizer in self.trainer.optimizers:
@@ -92,6 +87,11 @@ def filter_named_parameters(model, optimizer):
             for optimizer in self.trainer.optimizers
         ]
 
+        if self.trainer.amp_backend == AMPType.APEX:
+            model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
+            self.trainer.optimizers = optimizers
+            self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
+
         # Update logger rank info from Horovod to avoid race conditions from different ranks
         # creating directories / writing files in the same locations.
         self.trainer.global_rank = hvd.rank()
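
The other half of the change is the guard itself: `trainer.amp_backend` is set for native AMP as well as for Apex, so the old `if self.trainer.amp_backend:` also fired when `precision=16` ran on the native backend and incorrectly called `configure_apex`. The tiny helper below is a hypothetical illustration of the narrowed condition, assuming the `AMPType` enum (with `NATIVE` and `APEX` members) importable from `pytorch_lightning.utilities`, which the diff above relies on.

```python
from pytorch_lightning.utilities import AMPType  # assumed import path for the enum

def should_configure_apex(amp_backend) -> bool:
    """Hypothetical helper: configure Apex only when the Apex backend is active."""
    # The old guard (`if amp_backend:`) was also truthy under native AMP,
    # which is what broke Horovod + native AMP before this commit.
    return amp_backend == AMPType.APEX

assert should_configure_apex(AMPType.APEX)
assert not should_configure_apex(AMPType.NATIVE)
assert not should_configure_apex(None)  # 32-bit training: no AMP backend selected
```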
tests/models/test_horovod.py (23 additions, 2 deletions)
@@ -42,8 +42,8 @@ def _nccl_available():
 def _run_horovod(trainer_options, on_gpu=False):
     """Execute the training script across multiple workers in parallel."""
     num_processes = trainer_options.get('gpus', 2)
-    # gpus trainer argument does not apply for horovod
-    trainer_options.update(gpus=None)
+    # for Horovod, we interpret `gpus` to be set per worker
+    trainer_options.update(gpus=1 if on_gpu else None)
     tutils.reset_seed()
     cmdline = [
         'horovodrun',
@@ -110,6 +110,27 @@ def test_horovod_multi_gpu(tmpdir):
     _run_horovod(trainer_options, on_gpu=True)
 
 
+@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
+@pytest.mark.skipif(not _nccl_available(), reason="test requires Horovod with NCCL support")
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+def test_horovod_amp(tmpdir):
+    """Test Horovod with multi-GPU support using 16-bit precision."""
+    trainer_options = dict(
+        default_root_dir=str(tmpdir),
+        weights_save_path=str(tmpdir),
+        gradient_clip_val=1.0,
+        progress_bar_refresh_rate=0,
+        max_epochs=1,
+        limit_train_batches=0.4,
+        limit_val_batches=0.2,
+        gpus=2,
+        deterministic=True,
+        distributed_backend='horovod',
+        precision=16,
+    )
+    _run_horovod(trainer_options, on_gpu=True)
+
+
 @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
 @pytest.mark.skipif(not _nccl_available(), reason="test requires Horovod with NCCL support")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
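
For context, the configuration the new `test_horovod_amp` exercises looks roughly like the sketch below when written against the Trainer API of this release and launched under `horovodrun` (for example `horovodrun -np 2 python script.py`); the `run` function and its `model` argument are illustrative, not part of the test harness.

```python
import pytorch_lightning as pl

def run(model: pl.LightningModule, out_dir: str):
    # One GPU per Horovod worker, Horovod backend, 16-bit precision:
    # the same knobs test_horovod_amp passes through _run_horovod.
    trainer = pl.Trainer(
        default_root_dir=out_dir,
        gradient_clip_val=1.0,
        max_epochs=1,
        limit_train_batches=0.4,
        limit_val_batches=0.2,
        gpus=1,                          # per worker, as _run_horovod now sets it
        deterministic=True,
        distributed_backend='horovod',
        precision=16,
    )
    trainer.fit(model)
```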
