Introduce primitives for input/output dtype conversion in Lite Precision #14792

Merged: 18 commits, Sep 30, 2022
Changes from 15 commits
8 changes: 5 additions & 3 deletions src/lightning_lite/connector.py
@@ -16,6 +16,7 @@
from typing import Dict, List, Optional, Union

import torch
from typing_extensions import Literal

from lightning_lite.accelerators import ACCELERATOR_REGISTRY
from lightning_lite.accelerators.accelerator import Accelerator
@@ -58,6 +59,7 @@

_PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO]
_PLUGIN_INPUT = Union[_PLUGIN, str]
_PRECISION_INPUT = Literal[16, 32, 64, "bf16"]


class _Connector:
@@ -97,7 +99,7 @@ def __init__(
strategy: Optional[Union[str, Strategy]] = None,
devices: Optional[Union[List[int], str, int]] = None,
num_nodes: int = 1,
precision: Union[int, str] = 32,
precision: _PRECISION_INPUT = 32,
plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
) -> None:
# 1. Parsing flags
@@ -111,7 +113,7 @@ def __init__(
# For devices: Assign gpus, ipus, etc. to the accelerator flag and devices flag
self._strategy_flag: Optional[Union[Strategy, str]] = None
self._accelerator_flag: Optional[Union[Accelerator, str]] = None
self._precision_flag: Optional[Union[int, str]] = None
self._precision_flag: Optional[_PRECISION_INPUT] = None
self._precision_plugin_flag: Optional[Precision] = None
self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None
self._parallel_devices: List[Union[int, torch.device, str]] = []
@@ -154,7 +156,7 @@ def _check_config_and_set_final_flags(
self,
strategy: Optional[Union[str, Strategy]],
accelerator: Optional[Union[str, Accelerator]],
precision: Union[int, str],
precision: _PRECISION_INPUT,
plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]],
) -> None:
"""This method checks:
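For orientation, a minimal standalone sketch of what the new _PRECISION_INPUT alias buys callers (the configure function below is purely illustrative): a static type checker can now reject precision values outside 16, 32, 64 and "bf16" before runtime.

# Sketch only; re-declares the alias from this diff next to a hypothetical caller.
from typing_extensions import Literal

_PRECISION_INPUT = Literal[16, 32, 64, "bf16"]

def configure(precision: _PRECISION_INPUT = 32) -> str:
    # mypy accepts configure(16), configure("bf16"), configure(64),
    # and flags configure(8) or configure("mixed") at check time
    return f"precision set to {precision}"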
4 changes: 2 additions & 2 deletions src/lightning_lite/lite.py
@@ -27,7 +27,7 @@
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler

from lightning_lite.accelerators.accelerator import Accelerator
from lightning_lite.connector import _Connector, _PLUGIN_INPUT
from lightning_lite.connector import _Connector, _PLUGIN_INPUT, _PRECISION_INPUT
from lightning_lite.plugins import Precision
from lightning_lite.strategies import DeepSpeedStrategy, Strategy, XLAStrategy
from lightning_lite.strategies.strategy import TBroadcast
@@ -74,7 +74,7 @@ def __init__(
strategy: Optional[Union[str, Strategy]] = None,
devices: Optional[Union[List[int], str, int]] = None,
num_nodes: int = 1,
precision: Union[int, str] = 32,
precision: _PRECISION_INPUT = 32,
plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
) -> None:
self._connector = _Connector(
12 changes: 10 additions & 2 deletions src/lightning_lite/plugins/precision/deepspeed.py
@@ -11,12 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, TYPE_CHECKING, Union
from typing import Any, Optional, TYPE_CHECKING

import torch
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
from typing_extensions import Literal

from lightning_lite.plugins.precision.precision import Precision
from lightning_lite.plugins.precision.utils import _convert_fp_tensor
from lightning_lite.utilities.enums import AMPType, PrecisionType
from lightning_lite.utilities.imports import _APEX_AVAILABLE
from lightning_lite.utilities.types import Steppable
@@ -43,7 +46,7 @@ class DeepSpeedPrecision(Precision):
If unsupported ``precision`` is provided.
"""

def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optional[str] = None) -> None:
def __init__(self, precision: Literal[16, 32, "bf16"], amp_type: str, amp_level: Optional[str] = None) -> None:
if amp_type == AMPType.APEX:
if not _APEX_AVAILABLE:
raise ImportError(
@@ -65,6 +68,11 @@ def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optiona
self.amp_type = amp_type
self.amp_level = amp_level

def convert_input(self, data: Tensor) -> Tensor:
precision_to_type = {"bf16": torch.bfloat16, 16: torch.float16, 32: torch.float32}
dst_type = precision_to_type[self.precision]
return _convert_fp_tensor(data, dst_type)

def backward(self, tensor: Tensor, model: "deepspeed.DeepSpeedEngine", *args: Any, **kwargs: Any) -> None:
"""Performs back-propagation using DeepSpeed's engine."""
model.backward(tensor, *args, **kwargs)
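As a rough, self-contained illustration of the conversion the new convert_input hook performs for this plugin (the helper below inlines the same logic and is not the plugin class itself):

# Standalone sketch of the DeepSpeed plugin's convert_input logic; illustrative only.
import torch

def _convert_like_deepspeed_precision(data: torch.Tensor, precision) -> torch.Tensor:
    precision_to_type = {"bf16": torch.bfloat16, 16: torch.float16, 32: torch.float32}
    dst_type = precision_to_type[precision]
    # floating-point tensors are cast; integer/bool tensors pass through unchanged
    return data.to(dst_type) if torch.is_floating_point(data) else data

assert _convert_like_deepspeed_precision(torch.rand(2), "bf16").dtype is torch.bfloat16
assert _convert_like_deepspeed_precision(torch.arange(2), 16).dtype is torch.int64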
5 changes: 5 additions & 0 deletions src/lightning_lite/plugins/precision/double.py
@@ -15,9 +15,11 @@
from typing import Generator

import torch
from torch import Tensor
from torch.nn import Module

from lightning_lite.plugins.precision.precision import Precision
from lightning_lite.plugins.precision.utils import _convert_fp_tensor


class DoublePrecision(Precision):
@@ -38,3 +40,6 @@ def forward_context(self) -> Generator[None, None, None]:
torch.set_default_dtype(torch.float64)
yield
torch.set_default_dtype(default_dtype)

def convert_input(self, data: Tensor) -> Tensor:
return _convert_fp_tensor(data, torch.double)
9 changes: 8 additions & 1 deletion src/lightning_lite/plugins/precision/native_amp.py
@@ -18,8 +18,10 @@
from torch import Tensor
from torch.nn import Module
from torch.optim import LBFGS
from typing_extensions import Literal

from lightning_lite.plugins.precision.precision import Precision
from lightning_lite.plugins.precision.utils import _convert_fp_tensor
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_10
from lightning_lite.utilities.types import Steppable

@@ -39,7 +41,7 @@ class NativeMixedPrecision(Precision):
"""

def __init__(
self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
self, precision: Literal[16, "bf16"], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
) -> None:
super().__init__()
if precision == "bf16" and not _TORCH_GREATER_EQUAL_1_10:
@@ -57,6 +59,11 @@ def forward_context(self) -> Generator[None, None, None]:
with self._autocast_context_manager():
yield

def convert_input(self, data: Tensor) -> Tensor:
precision_to_type = {"bf16": torch.bfloat16, 16: torch.float16}
dst_type = precision_to_type[self.precision]
return _convert_fp_tensor(data, dst_type)

def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None:
if self.scaler is not None:
tensor = self.scaler.scale(tensor)
9 changes: 9 additions & 0 deletions src/lightning_lite/plugins/precision/precision.py
@@ -14,10 +14,12 @@
import contextlib
from typing import Any, Dict, Generator, Optional, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

from lightning_lite.plugins.precision.utils import _convert_fp_tensor
from lightning_lite.utilities.types import _PARAMETERS, Steppable


@@ -41,6 +43,13 @@ def forward_context(self) -> Generator[None, None, None]:
"""A contextmanager for managing model forward/training_step/evaluation_step/predict_step."""
yield

def convert_input(self, data: Tensor) -> Tensor:
"""Convert model inputs (forward) to the floating point precision type of this plugin.

This is a no-op for tensors that are not of floating-point type or already have the desired type.
"""
return _convert_fp_tensor(data, torch.float32)

def pre_backward(self, tensor: Tensor, module: Optional[Module]) -> Any:
"""Runs before precision plugin executes backward.

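A short, hedged usage sketch of the default behaviour added here, assuming the import path shown in the lite.py hunk above: the base plugin normalises floating-point inputs to torch.float32 and leaves non-floating tensors untouched.

# Illustrative only; exercises the default convert_input added in this hunk.
import torch
from lightning_lite.plugins import Precision

plugin = Precision()
assert plugin.convert_input(torch.rand(4, dtype=torch.float64)).dtype is torch.float32  # cast
assert plugin.convert_input(torch.randint(0, 10, (4,))).dtype is torch.int64            # no-op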
7 changes: 7 additions & 0 deletions src/lightning_lite/plugins/precision/tpu_bf16.py
@@ -13,7 +13,11 @@
# limitations under the License.
import os

import torch
from torch import Tensor

from lightning_lite.plugins.precision import TPUPrecision
from lightning_lite.plugins.precision.utils import _convert_fp_tensor


class TPUBf16Precision(TPUPrecision):
@@ -25,5 +29,8 @@ def __init__(self) -> None:
super().__init__()
os.environ["XLA_USE_BF16"] = "1"

def convert_input(self, data: Tensor) -> Tensor:
return _convert_fp_tensor(data, torch.bfloat16)

def teardown(self) -> None:
os.environ.pop("XLA_USE_BF16", None)
10 changes: 0 additions & 10 deletions src/lightning_lite/plugins/precision/utils.py
@@ -15,16 +15,6 @@

import torch

from lightning_lite.utilities.enums import PrecisionType


def _fp_to_half(tensor: torch.Tensor, precision: PrecisionType) -> torch.Tensor:
if precision == PrecisionType.HALF:
return _convert_fp_tensor(tensor, torch.half)
if precision == PrecisionType.BFLOAT:
return _convert_fp_tensor(tensor, torch.bfloat16)
return tensor


def _convert_fp_tensor(tensor: torch.Tensor, dst_type: Union[str, torch.dtype]) -> torch.Tensor:
return tensor.to(dst_type) if torch.is_floating_point(tensor) else tensor
7 changes: 0 additions & 7 deletions src/lightning_lite/strategies/deepspeed.py
@@ -23,16 +23,13 @@
import torch
from lightning_utilities.core.imports import RequirementCache
from lightning_utilities.core.rank_zero import rank_zero_only
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

from lightning_lite.accelerators import Accelerator
from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
from lightning_lite.plugins.precision import Precision
from lightning_lite.plugins.precision.utils import _fp_to_half
from lightning_lite.strategies.ddp import DDPStrategy
from lightning_lite.utilities.apply_func import apply_to_collection
from lightning_lite.utilities.distributed import log
from lightning_lite.utilities.enums import AMPType, PrecisionType
from lightning_lite.utilities.rank_zero import rank_zero_info
@@ -369,10 +366,6 @@ def load_module_state_dict(self, module: Module, checkpoint: Mapping[str, Any])
self.module_to_device(module)
self._restore_zero_state(module, checkpoint)

def batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
batch = apply_to_collection(batch, Tensor, function=_fp_to_half, precision=self.precision_plugin.precision)
return super().batch_to_device(batch, device)

@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register("deepspeed", cls, description="Default DeepSpeed Strategy")
14 changes: 2 additions & 12 deletions src/lightning_lite/wrappers.py
@@ -97,23 +97,13 @@ def module(self) -> nn.Module:
def forward(self, *args: Any, **kwargs: Any) -> Any:
"""Casts all inputs to the right precision and handles autocast for operations in the module forward
method."""
precision = self._precision_plugin.precision
precision_to_type = {
"bf16": torch.bfloat16,
16: torch.float16,
32: torch.float32,
64: torch.float64,
}
# TODO: let the precision plugin handle the conversion
args, kwargs = apply_to_collection(
[args, kwargs], dtype=Tensor, function=_convert_fp_tensor, dst_type=precision_to_type[precision]
)
args, kwargs = apply_to_collection([args, kwargs], function=self._precision_plugin.convert_input, dtype=Tensor)

with self._precision_plugin.forward_context():
output = self._forward_module(*args, **kwargs)

output = apply_to_collection(
output, dtype=Tensor, function=_convert_fp_tensor, dst_type=torch.get_default_dtype()
output, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype()
)
return output

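For context, a minimal sketch of the resulting data flow in the wrapper's forward after this change (wrapped_forward and its parameters are illustrative stand-ins, not the actual _LiteModule API): the precision plugin casts floating-point inputs, the forward runs inside the plugin's context, and floating-point outputs are cast back to the default dtype.

# Data-flow sketch only; wrapped_forward is a made-up stand-in for _LiteModule.forward.
import torch
from torch import Tensor

from lightning_lite.plugins.precision.utils import _convert_fp_tensor
from lightning_lite.utilities.apply_func import apply_to_collection


def wrapped_forward(forward_module, precision_plugin, *args, **kwargs):
    # 1. let the precision plugin cast floating-point inputs to its dtype
    args, kwargs = apply_to_collection(
        [args, kwargs], function=precision_plugin.convert_input, dtype=Tensor
    )
    # 2. run the forward under the plugin's forward/autocast context
    with precision_plugin.forward_context():
        output = forward_module(*args, **kwargs)
    # 3. cast floating-point outputs back to the default dtype (float32 unless changed)
    return apply_to_collection(
        output, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype()
    )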
11 changes: 8 additions & 3 deletions src/pytorch_lightning/lite/lite.py
@@ -18,6 +18,7 @@
from lightning_utilities.core.rank_zero import rank_zero_deprecation, rank_zero_warn

from lightning_lite.connector import _PLUGIN_INPUT as _LITE_PLUGIN_INPUT
from lightning_lite.connector import _PRECISION_INPUT
from lightning_lite.lite import LightningLite as _NewLightningLite
from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
from lightning_lite.plugins import DeepSpeedPrecision as LiteDeepSpeedPrecision
@@ -98,7 +99,7 @@ def __init__(
strategy: Optional[Union[str, PLStrategy]] = None,
devices: Optional[Union[List[int], str, int]] = None,
num_nodes: int = 1,
precision: Union[int, str] = 32,
precision: _PRECISION_INPUT = 32,
plugins: Optional[Union[_PL_PLUGIN_INPUT, List[_PL_PLUGIN_INPUT]]] = None,
gpus: Optional[Union[List[int], str, int]] = None,
tpu_cores: Optional[Union[List[int], str, int]] = None,
@@ -286,13 +287,17 @@ def _to_lite_precision_plugin(plugin: Optional[PLPrecisionPlugin]) -> LitePrecis
return LitePrecision()

if type(plugin) is PLNativeMixedPrecisionPlugin:
return LiteNativeMixedPrecision(precision=plugin.precision, device=plugin.device, scaler=plugin.scaler)
return LiteNativeMixedPrecision(
precision=plugin.precision, device=plugin.device, scaler=plugin.scaler # type: ignore[arg-type]
)

if type(plugin) is PLDoublePrecisionPlugin:
return LiteDoublePrecision()

if type(plugin) is PLDeepSpeedPrecisionPlugin:
return LiteDeepSpeedPrecision(precision=plugin.precision, amp_type=plugin.amp_type, amp_level=plugin.amp_level)
return LiteDeepSpeedPrecision(
precision=plugin.precision, amp_type=plugin.amp_type, amp_level=plugin.amp_level # type: ignore[arg-type]
)

if type(plugin) is PLTPUPrecisionPlugin:
return LiteTPUPrecision()
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/deepspeed.py
@@ -31,7 +31,6 @@

import pytorch_lightning as pl
from lightning_lite.plugins import ClusterEnvironment
from lightning_lite.plugins.precision.utils import _fp_to_half
from lightning_lite.utilities.enums import AMPType, PrecisionType
from lightning_lite.utilities.optimizer import optimizers_to_device
from lightning_lite.utilities.seed import reset_seed
@@ -41,6 +40,7 @@
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.strategies.utils import _fp_to_half
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/ipu.py
@@ -22,12 +22,12 @@

import pytorch_lightning as pl
from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
from lightning_lite.plugins.precision.utils import _fp_to_half
from lightning_lite.utilities.cloud_io import get_filesystem
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.strategies.utils import _fp_to_half
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.data import _get_dataloader_init_args_and_kwargs, _reinstantiate_wrapped_cls
12 changes: 12 additions & 0 deletions src/pytorch_lightning/strategies/utils.py
@@ -15,7 +15,11 @@
import os
from inspect import getmembers, isclass

import torch

from lightning_lite.plugins.precision.utils import _convert_fp_tensor
from lightning_lite.strategies import _StrategyRegistry
from lightning_lite.utilities.enums import PrecisionType
from lightning_lite.utilities.registry import _is_register_method_overridden
from pytorch_lightning.strategies.strategy import Strategy
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
@@ -34,3 +38,11 @@ def _call_register_strategies(registry: _StrategyRegistry, base_module: str) ->
for _, mod in getmembers(module, isclass):
if issubclass(mod, Strategy) and _is_register_method_overridden(mod, Strategy, "register_strategies"):
mod.register_strategies(registry)


def _fp_to_half(tensor: torch.Tensor, precision: PrecisionType) -> torch.Tensor:
if precision == PrecisionType.HALF:
return _convert_fp_tensor(tensor, torch.half)
if precision == PrecisionType.BFLOAT:
return _convert_fp_tensor(tensor, torch.bfloat16)
return tensor
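For reference, a small usage sketch of the relocated _fp_to_half helper, mirroring the apply_to_collection call removed from the Lite DeepSpeed strategy above (the batch contents are made up):

# Illustrative usage of the relocated helper; the batch dictionary is made up.
import torch
from torch import Tensor

from lightning_lite.utilities.apply_func import apply_to_collection
from lightning_lite.utilities.enums import PrecisionType
from pytorch_lightning.strategies.utils import _fp_to_half

batch = {"pixels": torch.rand(2, 3), "labels": torch.tensor([0, 1])}
batch = apply_to_collection(batch, Tensor, function=_fp_to_half, precision=PrecisionType.HALF)
assert batch["pixels"].dtype is torch.float16 and batch["labels"].dtype is torch.int64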