
Commit

ref: refactored horovod backend (#3121)
* refactored horovod backend

* refactored horovod backend
williamFalcon authored Aug 24, 2020
1 parent 8d7ca5c commit 8ebf4fe
Showing 4 changed files with 121 additions and 68 deletions.
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/__init__.py
@@ -5,3 +5,4 @@
from pytorch_lightning.accelerators.dp_backend import DataParallelBackend
from pytorch_lightning.accelerators.gpu_backend import GPUBackend
from pytorch_lightning.accelerators.tpu_backend import TPUBackend
from pytorch_lightning.accelerators.horovod_backend import HorovodBackend
115 changes: 115 additions & 0 deletions pytorch_lightning/accelerators/horovod_backend.py
@@ -0,0 +1,115 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import ExitStack
import torch
from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.accelerators.base_backend import Accelerator
from pytorch_lightning.utilities.distributed import rank_zero_only
from torch.optim.lr_scheduler import _LRScheduler

try:
    from apex import amp
except ImportError:
    amp = None


try:
    import horovod.torch as hvd
except (ModuleNotFoundError, ImportError):
    HOROVOD_AVAILABLE = False
else:
    HOROVOD_AVAILABLE = True


class HorovodBackend(Accelerator):
    amp_backend: AMPType

    def __init__(self, trainer):
        super().__init__(trainer)

    def setup(self, model):
        # call setup after the ddp process has connected
        self.trainer.call_setup_hook(model)

        if torch.cuda.is_available() and self.trainer.on_gpu:
            # Horovod: pin GPU to local rank
            assert self.trainer.root_gpu == hvd.local_rank()
            torch.cuda.set_device(self.trainer.root_gpu)
            model.cuda(self.trainer.root_gpu)

        # avoid duplicating progress bar
        if hvd.rank() != 0 and self.trainer.progress_bar_callback is not None:
            self.trainer.progress_bar_callback.disable()

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
        self.trainer.optimizers = optimizers
        self.trainer.lr_schedulers = lr_schedulers
        self.trainer.optimizer_frequencies = optimizer_frequencies

        # Horovod: scale the learning rate by the number of workers to account for
        # increased total batch size
        for optimizer in self.trainer.optimizers:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= hvd.size()

        # Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR
        for scheduler in self.trainer.lr_schedulers:
            scheduler = scheduler['scheduler']
            if isinstance(scheduler, _LRScheduler):
                scheduler.base_lrs = [lr * hvd.size() for lr in scheduler.base_lrs]

        if self.trainer.amp_backend:
            model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
            self.trainer.optimizers = optimizers
            self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)

        # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        for optimizer in self.trainer.optimizers:
            hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        def filter_named_parameters(model, optimizer):
            opt_params = set([p for group in optimizer.param_groups for p in group.get('params', [])])
            return [(name, p) for name, p in model.named_parameters() if p in opt_params]

        # Horovod: wrap optimizers to perform gradient aggregation via allreduce
        self.trainer.optimizers = [
            hvd.DistributedOptimizer(optimizer, named_parameters=filter_named_parameters(model, optimizer))
            for optimizer in self.trainer.optimizers
        ]

        # Update logger rank info from Horovod to avoid race conditions from different ranks
        # creating directories / writing files in the same locations.
        self.trainer.global_rank = hvd.rank()
        rank_zero_only.rank = self.trainer.global_rank

        self.trainer.model = model

    def train(self):
        with ExitStack() as stack:
            for optimizer in self.trainer.optimizers:
                # Synchronization will be performed explicitly following backward()
                stack.enter_context(optimizer.skip_synchronize())

            result = self.trainer.run_pretrain_routine(self.trainer.model)

        # Make sure all workers have finished training before returning to the user
        hvd.join()
        return result

    def teardown(self):
        pass
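
Note on the hvd.DistributedOptimizer wrapping step above: each wrapped optimizer should only allreduce the parameters it actually owns, which is what filter_named_parameters computes and why the multiple-optimizer case works. Below is a minimal, Horovod-free sketch of that filtering; the Net module and the two-optimizer split are hypothetical, introduced only for illustration.

import torch
from torch import nn


def filter_named_parameters(model, optimizer):
    # Same logic as the backend: keep only the (name, param) pairs whose
    # parameter tensors appear in this optimizer's param groups.
    opt_params = set(p for group in optimizer.param_groups for p in group.get('params', []))
    return [(name, p) for name, p in model.named_parameters() if p in opt_params]


class Net(nn.Module):
    # Hypothetical module whose parameters are split across two optimizers.
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)
        self.head = nn.Linear(4, 2)


model = Net()
opt_encoder = torch.optim.SGD(model.encoder.parameters(), lr=0.1)
opt_head = torch.optim.SGD(model.head.parameters(), lr=0.1)

print([name for name, _ in filter_named_parameters(model, opt_encoder)])  # ['encoder.weight', 'encoder.bias']
print([name for name, _ in filter_named_parameters(model, opt_head)])     # ['head.weight', 'head.bias']

This keeps each wrapped optimizer's gradient aggregation limited to its own parameter subset.
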
66 changes: 0 additions & 66 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -157,72 +157,6 @@ def __transfer_batch_to_device(self, batch: Any, device: torch.device):
            return model.transfer_batch_to_device(batch, device)
        return move_data_to_device(batch, device)

    def horovod_train(self, model):
        # call setup after the ddp process has connected
        self.call_setup_hook(model)

        if torch.cuda.is_available() and self.on_gpu:
            # Horovod: pin GPU to local rank
            assert self.root_gpu == hvd.local_rank()
            torch.cuda.set_device(self.root_gpu)
            model.cuda(self.root_gpu)

        # avoid duplicating progress bar
        if hvd.rank() != 0 and self.progress_bar_callback is not None:
            self.progress_bar_callback.disable()

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

        # Horovod: scale the learning rate by the number of workers to account for
        # increased total batch size
        for optimizer in self.optimizers:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= hvd.size()

        # Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR
        for scheduler in self.lr_schedulers:
            scheduler = scheduler['scheduler']
            if isinstance(scheduler, _LRScheduler):
                scheduler.base_lrs = [lr * hvd.size() for lr in scheduler.base_lrs]

        if self.amp_backend:
            model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
            self.optimizers = optimizers
            self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)

        # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        for optimizer in self.optimizers:
            hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        def filter_named_parameters(model, optimizer):
            opt_params = set([p for group in optimizer.param_groups for p in group.get('params', [])])
            return [(name, p) for name, p in model.named_parameters() if p in opt_params]

        # Horovod: wrap optimizers to perform gradient aggregation via allreduce
        self.optimizers = [
            hvd.DistributedOptimizer(optimizer, named_parameters=filter_named_parameters(model, optimizer))
            for optimizer in self.optimizers
        ]

        # Update logger rank info from Horovod to avoid race conditions from different ranks
        # creating directories / writing files in the same locations.
        self.global_rank = hvd.rank()
        rank_zero_only.rank = self.global_rank

        with ExitStack() as stack:
            for optimizer in self.optimizers:
                # Synchronization will be performed explicitly following backward()
                stack.enter_context(optimizer.skip_synchronize())

            result = self.run_pretrain_routine(model)

        # Make sure all workers have finished training before returning to the user
        hvd.join()
        return result


def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]:
    if isinstance(s, str):
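
Both the removed horovod_train above and the new HorovodBackend.setup apply the same linear scaling rule: with hvd.size() workers the effective batch size grows by that factor, so each optimizer's learning rate and each scheduler's base_lrs are multiplied by it. A self-contained sketch of that adjustment, with the world size stubbed to an assumed value of 4 rather than queried from Horovod:

import torch
from torch.optim.lr_scheduler import StepLR, _LRScheduler

world_size = 4  # stand-in for hvd.size(); assumed value for illustration only

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

# Scale the optimizer learning rate by the number of workers
for param_group in optimizer.param_groups:
    param_group['lr'] *= world_size

# Keep the scheduler's reference LRs consistent with the scaled optimizer
if isinstance(scheduler, _LRScheduler):
    scheduler.base_lrs = [lr * world_size for lr in scheduler.base_lrs]

print(optimizer.param_groups[0]['lr'])  # 0.4
print(scheduler.base_lrs)               # [0.4]

This mirrors the comment in the diff: the schedulers' base LRs are kept consistent with the scaled initial LR of the optimizers.
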
7 changes: 5 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -23,7 +23,7 @@
 from torch.utils.data import DataLoader

 from pytorch_lightning.accelerators import (
-    GPUBackend, TPUBackend, CPUBackend, DDPSpawnBackend, DataParallelBackend, DDPBackend, DDP2Backend)
+    GPUBackend, TPUBackend, CPUBackend, DDPSpawnBackend, DataParallelBackend, DDPBackend, DDP2Backend, HorovodBackend)
 from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.lightning import LightningModule
@@ -1066,7 +1066,10 @@ def fit(
             self.accelerator_backend.teardown()

         elif self.use_horovod:
-            results = self.horovod_train(model)
+            self.accelerator_backend = HorovodBackend(self)
+            self.accelerator_backend.setup(model)
+            results = self.accelerator_backend.train()
+            self.accelerator_backend.teardown()

         elif self.use_single_gpu:
             self.accelerator_backend = GPUBackend(self)
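For context on how the refactored path is exercised end to end: at the time of this commit, Horovod is selected with the distributed_backend='horovod' Trainer flag and the script is launched with horovodrun, one process per GPU. The sketch below is a hypothetical example, not part of the commit; the model, epoch count, and process count are placeholders.

# train.py -- minimal, hypothetical example of how the refactored path is exercised.
# Assumes Horovod is installed and uses the Lightning API as of this commit.
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class BoringModel(pl.LightningModule):
    # Tiny placeholder model, only here to make the launch example concrete.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.02)

    def train_dataloader(self):
        x = torch.randn(256, 32)
        y = torch.randint(0, 2, (256,))
        return DataLoader(TensorDataset(x, y), batch_size=32)


if __name__ == '__main__':
    trainer = pl.Trainer(
        max_epochs=2,
        distributed_backend='horovod',  # routes Trainer.fit() through HorovodBackend
        gpus=1,                         # one GPU per Horovod process (omit for CPU-only Horovod)
    )
    trainer.fit(BoringModel())

# Launched with, e.g.:
#   horovodrun -np 4 python train.py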
