Introduce primitives for input/output dtype conversion in Lite Precision #14792

Merged (18 commits) on Sep 30, 2022
Changes from 5 commits
22 changes: 15 additions & 7 deletions src/lightning_lite/plugins/precision/deepspeed.py
@@ -11,13 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, TYPE_CHECKING, Union
from typing import Any, Optional, TYPE_CHECKING

import torch
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
from typing_extensions import Literal

from lightning_lite.plugins.precision.precision import Precision
from lightning_lite.utilities.enums import AMPType, PrecisionType
from lightning_lite.plugins.precision.utils import _convert_fp_tensor
from lightning_lite.utilities.enums import AMPType
from lightning_lite.utilities.imports import _APEX_AVAILABLE
from lightning_lite.utilities.types import Steppable

@@ -30,7 +33,7 @@ class DeepSpeedPrecision(Precision):
"""Precision plugin for DeepSpeed integration.

Args:
precision: Double precision (64), full precision (32), half precision (16) or bfloat16 precision (bf16).
precision: Full precision (32), half precision (16) or bfloat16 precision (bf16).
amp_type: The mixed precision backend to use ("native" or "apex").
amp_level: The optimization level to use (O1, O2, etc...). By default it will be set to "O2"
if ``amp_type`` is set to "apex".
@@ -43,7 +46,7 @@ class DeepSpeedPrecision(Precision):
If unsupported ``precision`` is provided.
"""

def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optional[str] = None) -> None:
def __init__(self, precision: Literal[16, 32, "bf16"], amp_type: str, amp_level: Optional[str] = None) -> None:
if amp_type == AMPType.APEX:
if not _APEX_AVAILABLE:
raise ImportError(
@@ -53,18 +56,23 @@ def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optiona

amp_level = amp_level or "O2"

supported_precision = (PrecisionType.HALF, PrecisionType.FLOAT, PrecisionType.BFLOAT)
if precision not in supported_precision:
supported_precision = ("16", "32", "bf16")
if str(precision) not in supported_precision:
raise ValueError(
f"`precision={precision!r})` is not supported in DeepSpeed."
f" `precision` must be one of: {(x.value for x in supported_precision)}."
f" `precision` must be one of: {', '.join(supported_precision)}."
)

super().__init__()
self.precision = precision
self.amp_type = amp_type
self.amp_level = amp_level

def convert_input(self, data: Tensor) -> Tensor:
precision_to_type = {"bf16": torch.bfloat16, 16: torch.float16, 32: torch.float32}
dst_type = precision_to_type[self.precision]
return _convert_fp_tensor(data, dst_type)

def backward(self, tensor: Tensor, model: "deepspeed.DeepSpeedEngine", *args: Any, **kwargs: Any) -> None:
"""Performs back-propagation using DeepSpeed's engine."""
model.backward(tensor, *args, **kwargs)
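
To make the new hook concrete, a minimal usage sketch (assuming the lightning_lite package from this branch and its dependencies are importable; the tensors and the amp_type="native" choice are illustrative):

import torch

from lightning_lite.plugins.precision.deepspeed import DeepSpeedPrecision

# "native" amp avoids the Apex requirement; precision=16 selects torch.float16.
plugin = DeepSpeedPrecision(precision=16, amp_type="native")

batch = torch.rand(4, 8)            # floating point, created as torch.float32
labels = torch.randint(0, 2, (4,))  # integer tensor

print(plugin.convert_input(batch).dtype)   # torch.float16
print(plugin.convert_input(labels).dtype)  # torch.int64 -- non-floating tensors pass through unchanged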
5 changes: 5 additions & 0 deletions src/lightning_lite/plugins/precision/double.py
@@ -15,8 +15,10 @@
from typing import Generator

import torch
from torch import Tensor

from lightning_lite.plugins.precision import Precision
from lightning_lite.plugins.precision.utils import _convert_fp_tensor


class DoublePrecision(Precision):
@@ -34,3 +36,6 @@ def forward_context(self) -> Generator[None, None, None]:
torch.set_default_dtype(torch.float64)
yield
torch.set_default_dtype(default_dtype)

def convert_input(self, data: Tensor) -> Tensor:
return _convert_fp_tensor(data, torch.double)
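
For comparison, a small sketch of how the double-precision plugin's new hook pairs with its existing forward_context (same installation assumption; the tensor is illustrative):

import torch

from lightning_lite.plugins.precision.double import DoublePrecision

plugin = DoublePrecision()
x = torch.rand(2, 2)                  # torch.float32

print(plugin.convert_input(x).dtype)  # torch.float64

with plugin.forward_context():
    print(torch.get_default_dtype())  # torch.float64 while inside the context
print(torch.get_default_dtype())      # restored to the previous default afterwards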
9 changes: 8 additions & 1 deletion src/lightning_lite/plugins/precision/native_amp.py
@@ -18,8 +18,10 @@
from torch import Tensor
from torch.nn import Module
from torch.optim import LBFGS
from typing_extensions import Literal

from lightning_lite.plugins.precision import Precision
from lightning_lite.plugins.precision.utils import _convert_fp_tensor
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10
from lightning_lite.utilities.types import Steppable

@@ -39,7 +41,7 @@ class NativeMixedPrecision(Precision):
"""

def __init__(
self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
self, precision: Literal[16, "bf16"], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
) -> None:
super().__init__()
if precision == "bf16" and not _TORCH_GREATER_EQUAL_1_10:
@@ -57,6 +59,11 @@ def forward_context(self) -> Generator[None, None, None]:
with self._autocast_context_manager():
yield

def convert_input(self, data: Tensor) -> Tensor:
precision_to_type = {"bf16": torch.bfloat16, 16: torch.float16}
dst_type = precision_to_type[self.precision]
return _convert_fp_tensor(data, dst_type)

def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None:
if self.scaler is not None:
tensor = self.scaler.scale(tensor)
16 changes: 16 additions & 0 deletions src/lightning_lite/plugins/precision/precision.py
@@ -14,10 +14,12 @@
import contextlib
from typing import Any, Dict, Generator, Optional, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

from lightning_lite.plugins.precision.utils import _convert_fp_tensor
from lightning_lite.utilities.types import _PARAMETERS, Steppable


@@ -34,6 +36,20 @@ def forward_context(self) -> Generator[None, None, None]:
"""A contextmanager for managing model forward/training_step/evaluation_step/predict_step."""
yield

def convert_input(self, data: Tensor) -> Tensor:
"""Convert model inputs (forward) to the floating point precision type of this plugin.

This is a no-op for tensors that are not of floating-point type or already have the desired type.
"""
return _convert_fp_tensor(data, torch.float32)

def convert_output(self, data: Tensor) -> Tensor:
"""Convert model outputs (forward) back to the default floating point precision type.

This is a no-op for tensors that are not of floating-point type or already have the desired type.
"""
return _convert_fp_tensor(data, torch.get_default_dtype())

def pre_backward(self, tensor: Tensor, module: Optional[Module]) -> None:
"""Runs before precision plugin executes backward.

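
A quick round-trip through the base-class defaults shown above (a sketch; the float64 input only serves to make both casts visible):

import torch

from lightning_lite.plugins.precision.precision import Precision

plugin = Precision()
x = torch.rand(2, 2, dtype=torch.float64)

y = plugin.convert_input(x)     # cast to torch.float32 for the forward pass
out = plugin.convert_output(y)  # cast back to torch.get_default_dtype()
print(y.dtype, out.dtype)       # torch.float32 torch.float32 (float32 is the default dtype here)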
7 changes: 7 additions & 0 deletions src/lightning_lite/plugins/precision/tpu_bf16.py
@@ -13,7 +13,11 @@
# limitations under the License.
import os

import torch
from torch import Tensor

from lightning_lite.plugins.precision import TPUPrecision
from lightning_lite.plugins.precision.utils import _convert_fp_tensor


class TPUBf16Precision(TPUPrecision):
@@ -25,5 +29,8 @@ def __init__(self) -> None:
super().__init__()
os.environ["XLA_USE_BF16"] = "1"

def convert_input(self, data: Tensor) -> Tensor:
return _convert_fp_tensor(data, torch.bfloat16)

def teardown(self) -> None:
os.environ.pop("XLA_USE_BF16", None)
10 changes: 0 additions & 10 deletions src/lightning_lite/plugins/precision/utils.py
@@ -15,16 +15,6 @@

import torch

from lightning_lite.utilities.enums import PrecisionType


def _fp_to_half(tensor: torch.Tensor, precision: PrecisionType) -> torch.Tensor:
if precision == PrecisionType.HALF:
return _convert_fp_tensor(tensor, torch.half)
if precision == PrecisionType.BFLOAT:
return _convert_fp_tensor(tensor, torch.bfloat16)
return tensor


def _convert_fp_tensor(tensor: torch.Tensor, dst_type: Union[str, torch.dtype]) -> torch.Tensor:
return tensor.to(dst_type) if torch.is_floating_point(tensor) else tensor
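
The helper that backs all of the hooks above is deliberately tiny; a sketch of its behaviour:

import torch

from lightning_lite.plugins.precision.utils import _convert_fp_tensor

print(_convert_fp_tensor(torch.rand(3), torch.half).dtype)    # torch.float16
print(_convert_fp_tensor(torch.arange(3), torch.half).dtype)  # torch.int64 -- not floating point, returned as-is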
7 changes: 0 additions & 7 deletions src/lightning_lite/strategies/deepspeed.py
@@ -23,16 +23,13 @@
import torch
from lightning_utilities.core.imports import RequirementCache
from lightning_utilities.core.rank_zero import rank_zero_only
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

from lightning_lite.accelerators import Accelerator
from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
from lightning_lite.plugins.precision import Precision
from lightning_lite.plugins.precision.utils import _fp_to_half
from lightning_lite.strategies.ddp import DDPStrategy
from lightning_lite.utilities.apply_func import apply_to_collection
from lightning_lite.utilities.distributed import log
from lightning_lite.utilities.enums import AMPType, PrecisionType
from lightning_lite.utilities.rank_zero import rank_zero_info
@@ -369,10 +366,6 @@ def load_module_state_dict(self, module: Module, checkpoint: Mapping[str, Any])
self.module_to_device(module)
self._restore_zero_state(module, checkpoint)

def batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
batch = apply_to_collection(batch, Tensor, function=_fp_to_half, precision=self.precision_plugin.precision)
return super().batch_to_device(batch, device)

@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register("deepspeed", cls, description="Default DeepSpeed Strategy")
17 changes: 2 additions & 15 deletions src/lightning_lite/wrappers.py
@@ -22,7 +22,6 @@
from torch.utils.data import DataLoader

from lightning_lite.plugins import Precision
from lightning_lite.plugins.precision.utils import _convert_fp_tensor
from lightning_lite.strategies import Strategy
from lightning_lite.utilities import move_data_to_device
from lightning_lite.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
@@ -97,24 +96,12 @@ def module(self) -> nn.Module:
def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Casts all inputs to the right precision and handles autocast for operations in the module forward
method."""
precision = self._precision_plugin.precision
precision_to_type = {
"bf16": torch.bfloat16,
16: torch.float16,
32: torch.float32,
64: torch.float64,
}
# TODO: let the precision plugin handle the conversion
args, kwargs = apply_to_collection(
[args, kwargs], dtype=Tensor, function=_convert_fp_tensor, dst_type=precision_to_type[precision]
)
args, kwargs = apply_to_collection([args, kwargs], function=self._precision_plugin.convert_input, dtype=Tensor)

with self._precision_plugin.forward_context():
output = self._forward_module(*args, **kwargs)

output = apply_to_collection(
output, dtype=Tensor, function=_convert_fp_tensor, dst_type=torch.get_default_dtype()
)
output = apply_to_collection(output, function=self._precision_plugin.convert_output, dtype=Tensor)
return output

@overload
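
The wrapper now leaves the dtype choice entirely to the plugin. A sketch of the same apply_to_collection pattern outside the wrapper (import paths as they appear elsewhere in this diff; the nested inputs are illustrative):

import torch
from torch import Tensor

from lightning_lite.plugins.precision.precision import Precision
from lightning_lite.utilities.apply_func import apply_to_collection

plugin = Precision()
inputs = (torch.rand(2, 2, dtype=torch.float64), {"mask": torch.ones(2, dtype=torch.bool)})

# Every Tensor in the nested collection goes through convert_input; the boolean mask passes through untouched.
converted = apply_to_collection(inputs, dtype=Tensor, function=plugin.convert_input)
print(converted[0].dtype, converted[1]["mask"].dtype)  # torch.float32 torch.bool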
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/deepspeed.py
@@ -31,7 +31,6 @@

import pytorch_lightning as pl
from lightning_lite.plugins import ClusterEnvironment
from lightning_lite.plugins.precision.utils import _fp_to_half
from lightning_lite.utilities.enums import AMPType, PrecisionType
from lightning_lite.utilities.optimizer import optimizers_to_device
from lightning_lite.utilities.seed import reset_seed
@@ -41,6 +40,7 @@
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.strategies.utils import _fp_to_half
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/ipu.py
@@ -22,12 +22,12 @@

import pytorch_lightning as pl
from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
from lightning_lite.plugins.precision.utils import _fp_to_half
from lightning_lite.utilities.cloud_io import get_filesystem
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.strategies.utils import _fp_to_half
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.data import _get_dataloader_init_args_and_kwargs, _reinstantiate_wrapped_cls
12 changes: 12 additions & 0 deletions src/pytorch_lightning/strategies/utils.py
@@ -15,7 +15,11 @@
import os
from inspect import getmembers, isclass

import torch

from lightning_lite.plugins.precision.utils import _convert_fp_tensor
from lightning_lite.strategies import _StrategyRegistry
from lightning_lite.utilities.enums import PrecisionType
from lightning_lite.utilities.registry import _is_register_method_overridden
from pytorch_lightning.strategies.strategy import Strategy
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
@@ -34,3 +38,11 @@ def _call_register_strategies(registry: _StrategyRegistry, base_module: str) ->
for _, mod in getmembers(module, isclass):
if issubclass(mod, Strategy) and _is_register_method_overridden(mod, Strategy, "register_strategies"):
mod.register_strategies(registry)


def _fp_to_half(tensor: torch.Tensor, precision: PrecisionType) -> torch.Tensor:
if precision == PrecisionType.HALF:
return _convert_fp_tensor(tensor, torch.half)
if precision == PrecisionType.BFLOAT:
return _convert_fp_tensor(tensor, torch.bfloat16)
return tensor
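
Finally, a sketch of the relocated helper (it is private API, shown here only to illustrate the precision-to-dtype mapping; assumes pytorch_lightning from this branch):

import torch

from lightning_lite.utilities.enums import PrecisionType
from pytorch_lightning.strategies.utils import _fp_to_half

t = torch.rand(2, 2)
print(_fp_to_half(t, PrecisionType.HALF).dtype)    # torch.float16
print(_fp_to_half(t, PrecisionType.BFLOAT).dtype)  # torch.bfloat16
print(_fp_to_half(t, PrecisionType.FLOAT).dtype)   # torch.float32 -- any other precision is returned unchanged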