Commit
ref: refactored horovod backend (#3121)
* refactored horovod backend
* refactored horovod backend
1 parent 8d7ca5c · commit 8ebf4fe
Showing 4 changed files with 121 additions and 68 deletions.
@@ -0,0 +1,115 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import ExitStack
import torch
from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.accelerators.base_backend import Accelerator
from pytorch_lightning.utilities.distributed import rank_zero_only
from torch.optim.lr_scheduler import _LRScheduler

try:
    from apex import amp
except ImportError:
    amp = None


try:
    import horovod.torch as hvd
except (ModuleNotFoundError, ImportError):
    HOROVOD_AVAILABLE = False
else:
    HOROVOD_AVAILABLE = True


class HorovodBackend(Accelerator):
    amp_backend: AMPType

    def __init__(self, trainer):
        super().__init__(trainer)

    def setup(self, model):
        # call setup after the ddp process has connected
        self.trainer.call_setup_hook(model)

        if torch.cuda.is_available() and self.trainer.on_gpu:
            # Horovod: pin GPU to local rank
            assert self.trainer.root_gpu == hvd.local_rank()
            torch.cuda.set_device(self.trainer.root_gpu)
            model.cuda(self.trainer.root_gpu)

        # avoid duplicating progress bar
        if hvd.rank() != 0 and self.trainer.progress_bar_callback is not None:
            self.trainer.progress_bar_callback.disable()

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
        self.trainer.optimizers = optimizers
        self.trainer.lr_schedulers = lr_schedulers
        self.trainer.optimizer_frequencies = optimizer_frequencies

        # Horovod: scale the learning rate by the number of workers to account for
        # increased total batch size
        for optimizer in self.trainer.optimizers:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= hvd.size()

        # Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR
        for scheduler in self.trainer.lr_schedulers:
            scheduler = scheduler['scheduler']
            if isinstance(scheduler, _LRScheduler):
                scheduler.base_lrs = [lr * hvd.size() for lr in scheduler.base_lrs]

        if self.trainer.amp_backend:
            model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
            self.trainer.optimizers = optimizers
            self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)

        # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        for optimizer in self.trainer.optimizers:
            hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        def filter_named_parameters(model, optimizer):
            opt_params = set([p for group in optimizer.param_groups for p in group.get('params', [])])
            return [(name, p) for name, p in model.named_parameters() if p in opt_params]

        # Horovod: wrap optimizers to perform gradient aggregation via allreduce
        self.trainer.optimizers = [
            hvd.DistributedOptimizer(optimizer, named_parameters=filter_named_parameters(model, optimizer))
            for optimizer in self.trainer.optimizers
        ]

        # Update logger rank info from Horovod to avoid race conditions from different ranks
        # creating directories / writing files in the same locations.
        self.trainer.global_rank = hvd.rank()
        rank_zero_only.rank = self.trainer.global_rank

        self.trainer.model = model

    def train(self):
        with ExitStack() as stack:
            for optimizer in self.trainer.optimizers:
                # Synchronization will be performed explicitly following backward()
                stack.enter_context(optimizer.skip_synchronize())

            result = self.trainer.run_pretrain_routine(self.trainer.model)

        # Make sure all workers have finished training before returning to the user
        hvd.join()
        return result

    def teardown(self):
        pass
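
For context, a minimal usage sketch of how this backend is typically exercised (not part of this commit, and assuming the Trainer API of this Lightning version): the Horovod accelerator is selected with distributed_backend='horovod', each Horovod process drives at most one GPU (pinned to its local rank by setup() above), and the script is launched with Horovod's launcher, e.g. `horovodrun -np 2 python train.py`. The ToyModel module is hypothetical and exists only to make the sketch self-contained.

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class ToyModel(pl.LightningModule):
    """Tiny model used only to exercise the Horovod code path."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        return {'loss': loss}

    def train_dataloader(self):
        x = torch.randn(64, 32)
        y = torch.randint(0, 2, (64,))
        return DataLoader(TensorDataset(x, y), batch_size=8)

    def configure_optimizers(self):
        # The backend multiplies this base lr by hvd.size() and wraps the
        # optimizer in hvd.DistributedOptimizer during setup().
        return torch.optim.SGD(self.parameters(), lr=0.01)


if __name__ == '__main__':
    # distributed_backend='horovod' selects the Horovod accelerator; use one
    # GPU per Horovod process when CUDA is available, otherwise run on CPU.
    gpus = 1 if torch.cuda.is_available() else None
    trainer = pl.Trainer(distributed_backend='horovod', gpus=gpus, max_epochs=1)
    trainer.fit(ToyModel())

Because every Horovod worker runs the full script, the learning-rate scaling, the parameter/optimizer-state broadcasts, and the DistributedOptimizer wrapping in setup() are what keep the workers consistent and aggregate gradients across processes.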