Accelerator Refactor: Precision Plugins #5718

Merged · 14 commits · Jan 31, 2021
42 changes: 22 additions & 20 deletions pytorch_lightning/accelerators/accelerator.py
@@ -17,9 +17,15 @@
from torch.optim import Optimizer

from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins import TrainingTypePlugin
from pytorch_lightning.plugins.training_type import TrainingTypePlugin, HorovodPlugin
from pytorch_lightning.plugins.precision import (
PrecisionPlugin,
MixedPrecisionPlugin,
ApexMixedPrecisionPlugin,
NativeMixedPrecisionPlugin,
)
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.enums import LightningEnum
from pytorch_lightning.utilities.enums import AMPType, LightningEnum


class Accelerator(object):
@@ -39,7 +45,7 @@ class Accelerator(object):

def __init__(
self,
precision_plugin, #: PrecisionPlugin # fixme
precision_plugin: PrecisionPlugin,
training_type_plugin: TrainingTypePlugin,
) -> None:
"""
@@ -230,9 +236,8 @@ def backward(
)

# TODO: this is a hack, find a better solution for this (hook?)
# fixme: uncomment when this class is added
# if isinstance(self.training_type_plugin, HorovodPlugin):
# optimizer.synchronize()
if isinstance(self.training_type_plugin, HorovodPlugin):
optimizer.synchronize()

return output

@@ -256,11 +261,9 @@ def optimizer_step(
"""
model_ref = self.lightning_module
is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
# fixme: uncomment when this class is added
# is_native_amp = (
# isinstance(self.precision_plugin, MixedPrecisionPlugin) and self.precision_plugin.backend == AMPType.NATIVE
# )
is_native_amp = False
native_amp = (
isinstance(self.precision_plugin, MixedPrecisionPlugin) and self.precision_plugin.backend == AMPType.NATIVE
)

self.precision_plugin.pre_optimizer_step(optimizer, opt_idx)
self.training_type_plugin.pre_optimizer_step(optimizer, opt_idx)
@@ -273,7 +276,7 @@
optimizer_idx=opt_idx,
optimizer_closure=lambda_closure,
on_tpu=False, # TPUAccelerator class sets this as True
using_native_amp=is_native_amp,
using_native_amp=native_amp,
using_lbfgs=is_lbfgs,
)

@@ -326,7 +329,7 @@ def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: Lightn
"""
plugin.connect(model)

def connect_precision_plugin(self, plugin): #: PrecisionPlugin # fixme
def connect_precision_plugin(self, plugin: PrecisionPlugin):
"""Attaches the precision plugin to the accelerator"""
model, optimizers, schedulers = plugin.connect(self.model, self.optimizers, self.lr_schedulers)
self.model = model
@@ -339,13 +342,12 @@ def to_device(self, batch: Any) -> Any:

@property
def amp_backend(self) -> Optional[LightningEnum]:
# fixme: uncomment when this class is added
# if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin):
# return AMPType.APEX
# elif isinstance(self.precision_plugin, NativeMixedPrecisionPlugin):
# return AMPType.NATIVE
# return None
pass
if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin):
return AMPType.APEX
elif isinstance(self.precision_plugin, NativeMixedPrecisionPlugin):
return AMPType.NATIVE
else:
return None

@property
def precision(self) -> int:
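
For orientation, here is a small sketch (not part of the diff) of the branching that the restored `amp_backend` property performs. It uses only the classes and the `AMPType` enum imported at the top of `accelerator.py` above, and assumes `ApexMixedPrecisionPlugin` can be constructed standalone with just an `amp_level`, as its `__init__` further down suggests.

```python
# Sketch only: mirrors the amp_backend branching restored in this hunk.
from pytorch_lightning.plugins.precision import (
    ApexMixedPrecisionPlugin,
    NativeMixedPrecisionPlugin,
)
from pytorch_lightning.utilities.enums import AMPType

# __init__ only stores backend and amp_level (see apex_amp.py below).
precision_plugin = ApexMixedPrecisionPlugin(amp_level="O2")

if isinstance(precision_plugin, ApexMixedPrecisionPlugin):
    amp_backend = AMPType.APEX      # Apex-based mixed precision
elif isinstance(precision_plugin, NativeMixedPrecisionPlugin):
    amp_backend = AMPType.NATIVE    # torch.cuda.amp-based mixed precision
else:
    amp_backend = None              # no AMP backend attached

assert amp_backend is AMPType.APEX
```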
24 changes: 13 additions & 11 deletions pytorch_lightning/plugins/base_plugin.py
@@ -12,46 +12,48 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from abc import ABC, abstractmethod
from typing import Any, Generator, Optional, overload, Sequence, Tuple

import torch


class Plugin(object):
class Plugin(ABC):
"""Basic Plugin class to derive precision and training type plugins from."""

def connect(self, model: torch.nn.Module, *args, **kwargs):
@abstractmethod
def connect(self, model: torch.nn.Module, *args: Sequence, **kwargs: Sequence) -> Optional[Tuple[torch.nn.Module, Sequence, Sequence]]:
"""Connects the plugin with the accelerator (and thereby with trainer and model).
Will be called by the accelerator.
"""
pass

def pre_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int):
def pre_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None:
"""Hook to do something before each optimizer step."""
pass

def post_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int):
def post_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None:
"""Hook to do something after each optimizer step."""
pass

def pre_training(self):
def pre_training(self) -> None:
"""Hook to do something before the training starts."""
pass

def post_training(self):
def post_training(self) -> None:
"""Hook to do something after the training finishes."""
pass

@contextlib.contextmanager
def train_step_context(self):
def train_step_context(self) -> Generator:
"""A contextmanager for the trainstep"""
yield

@contextlib.contextmanager
def val_step_context(self):
def val_step_context(self) -> Generator:
"""A contextmanager for the validation step"""
yield

@contextlib.contextmanager
def test_step_context(self):
def test_step_context(self) -> Generator:
"""A contextmanager for the teststep"""
yield
yield
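
Since `Plugin` is now an `ABC` and `connect` is marked `@abstractmethod`, custom plugins have to implement `connect` before they can be instantiated, while the remaining hooks keep their no-op defaults. A minimal sketch under that assumption; the `NoOpPlugin` name is invented for illustration.

```python
import torch

from pytorch_lightning.plugins.base_plugin import Plugin  # module path as in this diff


class NoOpPlugin(Plugin):
    """Hypothetical plugin that hands the model back unchanged."""

    def connect(self, model: torch.nn.Module, *args, **kwargs):
        # Returning None is permitted by the new Optional[...] return annotation.
        return None


plugin = NoOpPlugin()       # ok: connect is implemented
plugin.pre_training()       # inherited hooks are still no-ops
with plugin.train_step_context():
    pass

# Plugin() itself would now raise TypeError, because connect is abstract.
```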
7 changes: 6 additions & 1 deletion pytorch_lightning/plugins/precision/__init__.py
@@ -1 +1,6 @@

from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.tpu_bfloat import TPUHalfPrecisionPlugin
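
With these re-exports in place, the precision plugins resolve from the package as well as from their defining modules. A quick usage note (not from the diff), assuming the guarded apex import below keeps the class importable even without apex installed:

```python
# Both import paths should now resolve to the same class objects.
from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, PrecisionPlugin
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin as DirectApex

assert ApexMixedPrecisionPlugin is DirectApex
assert issubclass(ApexMixedPrecisionPlugin, PrecisionPlugin)
```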
146 changes: 146 additions & 0 deletions pytorch_lightning/plugins/precision/apex_amp.py
@@ -0,0 +1,146 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple

import torch
from torch.optim import Optimizer

from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType, rank_zero_warn

if _APEX_AVAILABLE:
from apex import amp


class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
"""Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)"""

def __init__(self, amp_level: str):
self.backend = AMPType.APEX
self.amp_level = amp_level

def master_params(self, optimizer: torch.optim.Optimizer):
return amp.master_params(optimizer)

def connect(self, model: torch.nn.Module, optimizers, lr_schedulers):
"""Connects the precision plugin to the training process,
configures apex and reinits the schedulers
"""
model, optimizers = self.configure_apex(amp, model, optimizers, self.amp_level)
self.reinit_scheduler_properties(optimizers, lr_schedulers)
return model, optimizers, lr_schedulers

def backward(
self,
model: LightningModule,
closure_loss: torch.Tensor,
optimizer: torch.optim.Optimizer,
opt_idx: int,
should_accumulate: bool,
*args,
**kwargs,
):
"""performs the actual backpropagation

Args:
model: the model to be optimized
closure_loss: the loss value obtained from the closure
optimizer: the optimizer to perform the step later on
opt_idx: the optimizer's index
should_accumulate: whether to accumulate gradients or not

"""
closure_loss = amp.scale_loss(closure_loss, optimizer)

# enter apex context
context = closure_loss
closure_loss = closure_loss.__enter__()

# do backward pass
# TODO: not entirely sure why we need this
if model is not None and isinstance(model, LightningModule):
model.backward(closure_loss, optimizer, opt_idx)
else:
closure_loss.backward(*args, **kwargs)

# exit amp context
a, b, c = None, None, None
error = context.__exit__(a, b, c)
if error:
rank_zero_warn(a, b, c)
raise Exception("apex unscale error")

# once backward has been applied, release graph
closure_loss = closure_loss.detach()
return closure_loss

def configure_apex(
self,
amp: object,
model: LightningModule,
optimizers: List[Optimizer],
amp_level: str,
) -> Tuple[LightningModule, List[Optimizer]]:
r"""
Override to init AMP your own way.
Must return a model and list of optimizers.

Args:
amp: pointer to amp library object.
model: pointer to current :class:`LightningModule`.
optimizers: list of optimizers passed in :meth:`configure_optimizers`.
amp_level: AMP mode chosen ('O1', 'O2', etc...)

Return:
Apex wrapped model and optimizers

Examples:
.. code-block:: python

# Default implementation used by Trainer.
def configure_apex(self, amp, model, optimizers, amp_level):
model, optimizers = amp.initialize(
model, optimizers, opt_level=amp_level,
)

return model, optimizers
"""
model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
return model, optimizers

@staticmethod
def reinit_scheduler_properties(optimizers: list, schedulers: list):
"""Reinitializes schedulers with correct properties"""
# Reinitialize optimizer.step properties added by schedulers
for scheduler in schedulers:
scheduler = scheduler["scheduler"]

for optimizer in optimizers:
state = None
idx = 0

# check that we don't mix the user's optimizers and schedulers
if scheduler.optimizer == optimizer:
# Find the mro belonging to the base lr scheduler class
for i, mro in enumerate(scheduler.__class__.__mro__):
if mro in (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
idx = i
state = scheduler.state_dict()
else:
state = None

scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
if state is not None:
scheduler.load_state_dict(state)
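
A hedged sketch of attaching `ApexMixedPrecisionPlugin` outside the `Accelerator`, based only on the `connect` signature above. The toy model and optimizer are invented for illustration, and it assumes apex is installed (`_APEX_AVAILABLE`) and a CUDA device is available, since `amp.initialize` expects CUDA parameters.

```python
import torch

from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin

model = torch.nn.Linear(4, 2).cuda()  # apex expects parameters on a CUDA device
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

plugin = ApexMixedPrecisionPlugin(amp_level="O1")
# connect() runs amp.initialize(...) via configure_apex and re-inits any
# scheduler properties; it hands back the wrapped model, optimizers and schedulers.
model, optimizers, lr_schedulers = plugin.connect(model, [optimizer], [])

# Plain nn.Module here, so backward() falls into the closure_loss.backward() branch.
loss = model(torch.randn(8, 4, device="cuda")).sum()
plugin.backward(model, loss, optimizers[0], opt_idx=0, should_accumulate=False)
optimizers[0].step()
```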
23 changes: 23 additions & 0 deletions pytorch_lightning/plugins/precision/mixed.py
@@ -0,0 +1,23 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import AMPType


class MixedPrecisionPlugin(PrecisionPlugin):
"""Base Class for mixed precision"""

EPSILON = 1e-5
backend: AMPType
precision = "mixed"
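
Note that `backend: AMPType` above is only a class-level annotation; concrete subclasses are expected to assign it, as `ApexMixedPrecisionPlugin.__init__` does. A small hypothetical subclass following that pattern (it also implements `connect`, since the base `Plugin` now declares it abstract):

```python
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
from pytorch_lightning.utilities import AMPType


class MyMixedPrecisionPlugin(MixedPrecisionPlugin):
    """Hypothetical subclass; the real native-AMP plugin lives in native_amp.py."""

    def __init__(self):
        self.backend = AMPType.NATIVE  # satisfy the class-level annotation

    def connect(self, model, optimizers, lr_schedulers):
        # hand everything back unchanged; real plugins may wrap the model here
        return model, optimizers, lr_schedulers


plugin = MyMixedPrecisionPlugin()
assert plugin.precision == "mixed"
assert plugin.EPSILON == 1e-5
```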