Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added check for apex AMP and unit tests for Horovod + AMP #3404

Merged
merged 3 commits into from
Sep 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed setting batch size in `LightningModule.datamodule` when using `auto_scale_batch_size` ([#3266](https://github.com/PyTorchLightning/pytorch-lightning/pull/3266))

- Fixed Horovod distributed backend compatibility with native AMP ([#3404](https://github.com/PyTorchLightning/pytorch-lightning/pull/3404))

## [0.9.0] - YYYY-MM-DD

### Added
Expand Down
10 changes: 5 additions & 5 deletions pytorch_lightning/accelerators/horovod_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,6 @@ def setup(self, model):
if isinstance(scheduler, _LRScheduler):
scheduler.base_lrs = [lr * hvd.size() for lr in scheduler.base_lrs]

if self.trainer.amp_backend:
model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
self.trainer.optimizers = optimizers
self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)

# Horovod: broadcast parameters & optimizer state to ensure consistent initialization
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
for optimizer in self.trainer.optimizers:
Expand All @@ -92,6 +87,11 @@ def filter_named_parameters(model, optimizer):
for optimizer in self.trainer.optimizers
]

if self.trainer.amp_backend == AMPType.APEX:
model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
self.trainer.optimizers = optimizers
self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)

# Update logger rank info from Horovod to avoid race conditions from different ranks
# creating directories / writing files in the same locations.
self.trainer.global_rank = hvd.rank()
Expand Down
25 changes: 23 additions & 2 deletions tests/models/test_horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def _nccl_available():
def _run_horovod(trainer_options, on_gpu=False):
"""Execute the training script across multiple workers in parallel."""
num_processes = trainer_options.get('gpus', 2)
# gpus trainer argument does not apply for horovod
trainer_options.update(gpus=None)
# for Horovod, we interpret `gpus` to be set per worker
trainer_options.update(gpus=1 if on_gpu else None)
tutils.reset_seed()
cmdline = [
'horovodrun',
Expand Down Expand Up @@ -110,6 +110,27 @@ def test_horovod_multi_gpu(tmpdir):
_run_horovod(trainer_options, on_gpu=True)


@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
@pytest.mark.skipif(not _nccl_available(), reason="test requires Horovod with NCCL support")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_horovod_amp(tmpdir):
"""Test Horovod with multi-GPU support."""
trainer_options = dict(
default_root_dir=str(tmpdir),
weights_save_path=str(tmpdir),
gradient_clip_val=1.0,
progress_bar_refresh_rate=0,
max_epochs=1,
limit_train_batches=0.4,
limit_val_batches=0.2,
gpus=2,
deterministic=True,
distributed_backend='horovod',
precision=16,
)
_run_horovod(trainer_options, on_gpu=True)


@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
@pytest.mark.skipif(not _nccl_available(), reason="test requires Horovod with NCCL support")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
Expand Down