From 74003355730f121e9a9405f5845ec68b82f31084 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Tue, 20 Sep 2022 00:31:14 +0200
Subject: [PATCH 1/9] primitives for input/output dtype conversion

---
 .../plugins/precision/deepspeed.py            | 17 ++++++++----
 .../plugins/precision/double.py               |  4 +++
 .../plugins/precision/native_amp.py           |  8 +++++-
 .../plugins/precision/precision.py            | 13 +++++++++
 .../plugins/precision/tpu_bf16.py             |  6 +++++
 src/lightning_lite/plugins/precision/utils.py | 27 -------------------
 src/lightning_lite/strategies/deepspeed.py    |  7 -----
 src/pytorch_lightning/lite/wrappers.py        | 18 ++-----------
 src/pytorch_lightning/strategies/deepspeed.py |  2 +-
 src/pytorch_lightning/strategies/ipu.py       |  2 +-
 src/pytorch_lightning/strategies/utils.py     | 13 +++++++++
 11 files changed, 59 insertions(+), 58 deletions(-)
 delete mode 100644 src/lightning_lite/plugins/precision/utils.py

diff --git a/src/lightning_lite/plugins/precision/deepspeed.py b/src/lightning_lite/plugins/precision/deepspeed.py
index 8610121863195..4ba949ff721f0 100644
--- a/src/lightning_lite/plugins/precision/deepspeed.py
+++ b/src/lightning_lite/plugins/precision/deepspeed.py
@@ -13,9 +13,11 @@
 # limitations under the License.
 from typing import Any, Optional, TYPE_CHECKING, Union
 
+import torch
 from lightning_utilities.core.imports import RequirementCache
 from torch import Tensor
 from torch.optim import LBFGS, Optimizer
+from typing_extensions import Literal
 
 from lightning_lite.plugins.precision.precision import Precision
 from lightning_lite.utilities.enums import AMPType, PrecisionType
@@ -30,7 +32,7 @@ class DeepSpeedPrecision(Precision):
     """Precision plugin for DeepSpeed integration.
 
     Args:
-        precision: Double precision (64), full precision (32), half precision (16) or bfloat16 precision (bf16).
+        precision: Full precision (32), half precision (16) or bfloat16 precision (bf16).
         amp_type: The mixed precision backend to use ("native" or "apex").
         amp_level: The optimization level to use (O1, O2, etc...). By default it will be set to "O2"
             if ``amp_type`` is set to "apex".
@@ -43,7 +45,7 @@ class DeepSpeedPrecision(Precision):
         If unsupported ``precision`` is provided.
     """
 
-    def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optional[str] = None) -> None:
+    def __init__(self, precision: Literal[16, 32, "bf16"], amp_type: str, amp_level: Optional[str] = None) -> None:
         if amp_type == AMPType.APEX:
             if not _APEX_AVAILABLE:
                 raise ImportError(
@@ -53,11 +55,11 @@ def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optiona
 
             amp_level = amp_level or "O2"
 
-        supported_precision = (PrecisionType.HALF, PrecisionType.FLOAT, PrecisionType.BFLOAT)
-        if precision not in supported_precision:
+        supported_precision = ("16", "32", "bf16")
+        if str(precision) not in supported_precision:
             raise ValueError(
                 f"`precision={precision!r})` is not supported in DeepSpeed."
-                f" `precision` must be one of: {(x.value for x in supported_precision)}."
+                f" `precision` must be one of: {', '.join(supported_precision)}."
             )
         super().__init__()
@@ -65,6 +67,11 @@ def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optiona
         self.amp_type = amp_type
         self.amp_level = amp_level
 
+    def convert_input(self, data: Tensor) -> Tensor:
+        precision_to_type = {"bf16": torch.bfloat16, 16: torch.float16, 32: torch.float32}
+        to_type = precision_to_type[self.precision]
+        return data.to(to_type) if torch.is_floating_point(data) else data
+
     def backward(self, tensor: Tensor, model: Optional["deepspeed.DeepSpeedEngine"], *args: Any, **kwargs: Any) -> None:
         """Performs back-propagation using DeepSpeed's engine."""
         if model is None:
diff --git a/src/lightning_lite/plugins/precision/double.py b/src/lightning_lite/plugins/precision/double.py
index 13f5909deac9d..8a2437b73eb3e 100644
--- a/src/lightning_lite/plugins/precision/double.py
+++ b/src/lightning_lite/plugins/precision/double.py
@@ -15,6 +15,7 @@
 from typing import Generator
 
 import torch
+from torch import Tensor
 
 from lightning_lite.plugins.precision import Precision
 
@@ -34,3 +35,6 @@ def forward_context(self) -> Generator[None, None, None]:
         torch.set_default_dtype(torch.float64)
         yield
         torch.set_default_dtype(default_dtype)
+
+    def convert_input(self, data: Tensor) -> Tensor:
+        return data.to(torch.float64) if torch.is_floating_point(data) else data
diff --git a/src/lightning_lite/plugins/precision/native_amp.py b/src/lightning_lite/plugins/precision/native_amp.py
index bd54cdb846299..46c33ca47f619 100644
--- a/src/lightning_lite/plugins/precision/native_amp.py
+++ b/src/lightning_lite/plugins/precision/native_amp.py
@@ -18,6 +18,7 @@
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import LBFGS, Optimizer
+from typing_extensions import Literal
 
 from lightning_lite.plugins.precision.mixed import MixedPrecision
 from lightning_lite.utilities.enums import AMPType
@@ -41,7 +42,7 @@ class NativeMixedPrecision(MixedPrecision):
     backend = AMPType.NATIVE
 
     def __init__(
-        self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
+        self, precision: Literal[16, "bf16"], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
     ) -> None:
         super().__init__()
         if precision == "bf16" and not _TORCH_GREATER_EQUAL_1_10:
@@ -59,6 +60,11 @@ def forward_context(self) -> Generator[None, None, None]:
         with self._autocast_context_manager():
             yield
 
+    def convert_input(self, data: Tensor) -> Tensor:
+        precision_to_type = {"bf16": torch.bfloat16, 16: torch.float16}
+        to_type = precision_to_type[self.precision]
+        return data.to(to_type) if torch.is_floating_point(data) else data
+
     def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None:
         if self.scaler is not None:
             tensor = self.scaler.scale(tensor)
diff --git a/src/lightning_lite/plugins/precision/precision.py b/src/lightning_lite/plugins/precision/precision.py
index 881db89e88f56..14a043bd0a37d 100644
--- a/src/lightning_lite/plugins/precision/precision.py
+++ b/src/lightning_lite/plugins/precision/precision.py
@@ -14,6 +14,7 @@
 import contextlib
 from typing import Any, Dict, Generator, Optional, Union
 
+import torch
 from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
@@ -34,6 +35,18 @@ def forward_context(self) -> Generator[None, None, None]:
         """A contextmanager for managing model forward/training_step/evaluation_step/predict_step."""
         yield
 
+    def convert_input(self, data: Tensor) -> Tensor:
+        """Convert model inputs (forward) to the floating point precision type of this plugin.
+        This is a no-op for tensors that are not of floating-point type or already have the desired type.
+        """
+        return data.to(torch.float32) if torch.is_floating_point(data) else data
+
+    def convert_output(self, data: Tensor) -> Tensor:
+        """Convert model outputs (forward) back to the default floating point precision type.
+        This is a no-op for tensors that are not of floating-point type or already have the desired type.
+        """
+        return data.to(torch.get_default_dtype()) if torch.is_floating_point(data) else data
+
     def pre_backward(self, tensor: Tensor, module: Optional[Module]) -> None:
         """Runs before precision plugin executes backward.
diff --git a/src/lightning_lite/plugins/precision/tpu_bf16.py b/src/lightning_lite/plugins/precision/tpu_bf16.py
index d388a9ae175ac..d5a369a6cf45c 100644
--- a/src/lightning_lite/plugins/precision/tpu_bf16.py
+++ b/src/lightning_lite/plugins/precision/tpu_bf16.py
@@ -13,6 +13,9 @@
 # limitations under the License.
 import os
 
+import torch
+from torch import Tensor
+
 from lightning_lite.plugins.precision import TPUPrecision
 
 
@@ -25,5 +28,8 @@ def __init__(self) -> None:
         super().__init__()
         os.environ["XLA_USE_BF16"] = "1"
 
+    def convert_input(self, data: Tensor) -> Tensor:
+        return data.to(torch.bfloat16) if torch.is_floating_point(data) else data
+
     def teardown(self) -> None:
         os.environ.pop("XLA_USE_BF16", None)
diff --git a/src/lightning_lite/plugins/precision/utils.py b/src/lightning_lite/plugins/precision/utils.py
deleted file mode 100644
index f9af7de5baf75..0000000000000
--- a/src/lightning_lite/plugins/precision/utils.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-
-from lightning_lite.utilities.enums import PrecisionType
-
-
-def _fp_to_half(tensor: torch.Tensor, precision: PrecisionType) -> torch.Tensor:
-    if torch.is_floating_point(tensor):
-        if precision == PrecisionType.HALF:
-            return tensor.half()
-        if precision == PrecisionType.BFLOAT:
-            return tensor.bfloat16()
-
-    return tensor
diff --git a/src/lightning_lite/strategies/deepspeed.py b/src/lightning_lite/strategies/deepspeed.py
index 985989752f771..f3d46616728ba 100644
--- a/src/lightning_lite/strategies/deepspeed.py
+++ b/src/lightning_lite/strategies/deepspeed.py
@@ -23,16 +23,13 @@
 import torch
 from lightning_utilities.core.imports import RequirementCache
 from lightning_utilities.core.rank_zero import rank_zero_only
-from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer
 
 from lightning_lite.accelerators import Accelerator
 from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
 from lightning_lite.plugins.precision import Precision
-from lightning_lite.plugins.precision.utils import _fp_to_half
 from lightning_lite.strategies.ddp import DDPStrategy
-from lightning_lite.utilities.apply_func import apply_to_collection
 from lightning_lite.utilities.distributed import log
 from lightning_lite.utilities.enums import AMPType, PrecisionType
 from lightning_lite.utilities.rank_zero import rank_zero_info
@@ -369,10 +366,6 @@ def load_module_state_dict(self, module: Module, checkpoint: Mapping[str, Any])
         self.module_to_device(module)
         self._restore_zero_state(module, checkpoint)
 
-    def batch_to_device(self, batch: Any, device: Optional[torch.device] = None) -> Any:
-        batch = apply_to_collection(batch, Tensor, function=_fp_to_half, precision=self.precision_plugin.precision)
-        return super().batch_to_device(batch, device)
-
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
         strategy_registry.register("deepspeed", cls, description="Default DeepSpeed Strategy")
diff --git a/src/pytorch_lightning/lite/wrappers.py b/src/pytorch_lightning/lite/wrappers.py
index 7b7a304413476..6801adc752fcd 100644
--- a/src/pytorch_lightning/lite/wrappers.py
+++ b/src/pytorch_lightning/lite/wrappers.py
@@ -91,26 +91,12 @@ def module(self) -> nn.Module:
 
     def forward(self, *args: Any, **kwargs: Any) -> Any:
         """Casts all inputs to the right precision and handles autocast for operations in the module forward method."""
-        precision = self._precision_plugin.precision
-        precision_to_type = {
-            "bf16": torch.bfloat16,
-            16: torch.float16,
-            32: torch.float32,
-            64: torch.float64,
-        }
-        # TODO: let the precision plugin handle the conversion
-        to_type = precision_to_type[precision]
-
-        def _convert_float_tensor(t: Tensor) -> Tensor:
-            return t.to(to_type) if torch.is_floating_point(t) else t
-
-        args, kwargs = apply_to_collection([args, kwargs], function=_convert_float_tensor, dtype=Tensor)
+        args, kwargs = apply_to_collection([args, kwargs], function=self._precision_plugin.convert_input, dtype=Tensor)
 
         with self._precision_plugin.forward_context():
             output = self._forward_module(*args, **kwargs)
 
-        to_type = torch.get_default_dtype()
-        output = apply_to_collection(output, function=_convert_float_tensor, dtype=Tensor)
+        output = apply_to_collection(output, function=self._precision_plugin.convert_output, dtype=Tensor)
         return output
 
     @overload
diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py
index 2596a6fa19666..23bf850b568bd 100644
--- a/src/pytorch_lightning/strategies/deepspeed.py
+++ b/src/pytorch_lightning/strategies/deepspeed.py
@@ -31,7 +31,6 @@
 import pytorch_lightning as pl
 from lightning_lite.plugins import ClusterEnvironment
-from lightning_lite.plugins.precision.utils import _fp_to_half
 from lightning_lite.utilities.enums import AMPType, PrecisionType
 from lightning_lite.utilities.optimizer import optimizers_to_device
 from lightning_lite.utilities.seed import reset_seed
@@ -41,6 +40,7 @@
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.ddp import DDPStrategy
+from pytorch_lightning.strategies.utils import _fp_to_half
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import GradClipAlgorithmType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py
index a363d143e523e..ed28546b5fd56 100644
--- a/src/pytorch_lightning/strategies/ipu.py
+++ b/src/pytorch_lightning/strategies/ipu.py
@@ -22,13 +22,13 @@
 import pytorch_lightning as pl
 from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
-from lightning_lite.plugins.precision.utils import _fp_to_half
 from lightning_lite.utilities.cloud_io import get_filesystem
 from lightning_lite.utilities.enums import PrecisionType
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.parallel import ParallelStrategy
 from pytorch_lightning.strategies.strategy import TBroadcast
+from pytorch_lightning.strategies.utils import _fp_to_half
 from pytorch_lightning.trainer.states import RunningStage, TrainerFn
 from pytorch_lightning.utilities import _IPU_AVAILABLE, _POPTORCH_AVAILABLE, rank_zero_warn
 from pytorch_lightning.utilities.data import _get_dataloader_init_args_and_kwargs, _reinstantiate_wrapped_cls
diff --git a/src/pytorch_lightning/strategies/utils.py b/src/pytorch_lightning/strategies/utils.py
index fa360d3770cd7..d4941f3fced46 100644
--- a/src/pytorch_lightning/strategies/utils.py
+++ b/src/pytorch_lightning/strategies/utils.py
@@ -15,7 +15,10 @@
 import os
 from inspect import getmembers, isclass
 
+import torch
+
 from lightning_lite.strategies import _StrategyRegistry
+from lightning_lite.utilities.enums import PrecisionType
 from lightning_lite.utilities.registry import _is_register_method_overridden
 from pytorch_lightning.strategies.strategy import Strategy
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
@@ -34,3 +37,13 @@ def _call_register_strategies(registry: _StrategyRegistry, base_module: str) ->
     for _, mod in getmembers(module, isclass):
         if issubclass(mod, Strategy) and _is_register_method_overridden(mod, Strategy, "register_strategies"):
             mod.register_strategies(registry)
+
+
+def _fp_to_half(tensor: torch.Tensor, precision: PrecisionType) -> torch.Tensor:
+    if torch.is_floating_point(tensor):
+        if precision == PrecisionType.HALF:
+            return tensor.half()
+        if precision == PrecisionType.BFLOAT:
+            return tensor.bfloat16()
+
+    return tensor
From 0692a416ea9579872b379791206e7312139bb506 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 19 Sep 2022 22:34:25 +0000
Subject: [PATCH 2/9] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning_lite/plugins/precision/precision.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/lightning_lite/plugins/precision/precision.py b/src/lightning_lite/plugins/precision/precision.py
index 14a043bd0a37d..3a4d5eaae0225 100644
--- a/src/lightning_lite/plugins/precision/precision.py
+++ b/src/lightning_lite/plugins/precision/precision.py
@@ -37,12 +37,14 @@ def forward_context(self) -> Generator[None, None, None]:
 
     def convert_input(self, data: Tensor) -> Tensor:
         """Convert model inputs (forward) to the floating point precision type of this plugin.
+
         This is a no-op for tensors that are not of floating-point type or already have the desired type.
         """
         return data.to(torch.float32) if torch.is_floating_point(data) else data
 
     def convert_output(self, data: Tensor) -> Tensor:
         """Convert model outputs (forward) back to the default floating point precision type.
+
         This is a no-op for tensors that are not of floating-point type or already have the desired type.
         """
         return data.to(torch.get_default_dtype()) if torch.is_floating_point(data) else data

From 47ff2459ee05fb485551340f8fb4af504b94e90c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Thu, 22 Sep 2022 02:17:55 +0200
Subject: [PATCH 3/9] Convert output too

---
 src/lightning_lite/plugins/precision/precision.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning_lite/plugins/precision/precision.py b/src/lightning_lite/plugins/precision/precision.py
index 8ee923b356480..61be46928fff6 100644
--- a/src/lightning_lite/plugins/precision/precision.py
+++ b/src/lightning_lite/plugins/precision/precision.py
@@ -48,7 +48,7 @@ def convert_output(self, data: Tensor) -> Tensor:
         This is a no-op for tensors that are not of floating-point type or already have the desired type.
         """
-        return data.to(torch.get_default_dtype()) if torch.is_floating_point(data) else data
+        return _convert_fp_tensor(data, torch.get_default_dtype())
 
     def pre_backward(self, tensor: Tensor, module: Optional[Module]) -> None:
         """Runs before precision plugin executes backward.
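Note: the hunk above starts routing the output cast through a `_convert_fp_tensor` helper whose definition is not included in these hunks. A minimal sketch of what such a helper could look like (an assumption, not the actual source):

import torch
from torch import Tensor


def _convert_fp_tensor(tensor: Tensor, dst_type: torch.dtype) -> Tensor:
    # Assumed behaviour: cast floating-point tensors to ``dst_type``, pass everything else through.
    return tensor.to(dst_type) if torch.is_floating_point(tensor) else tensor


assert _convert_fp_tensor(torch.rand(2), torch.bfloat16).dtype is torch.bfloat16
assert _convert_fp_tensor(torch.arange(2), torch.bfloat16).dtype is torch.int64
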
From 0c6b99afd1e601a21d17287c211378793946dc71 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Thu, 22 Sep 2022 15:57:12 +0200
Subject: [PATCH 4/9] typing updates

---
 src/lightning_lite/connector.py    |  8 +++++---
 src/lightning_lite/lite.py         |  4 ++--
 src/pytorch_lightning/lite/lite.py | 11 ++++++++---
 3 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/src/lightning_lite/connector.py b/src/lightning_lite/connector.py
index 7b8d4a9330df8..cafc1b024c9ce 100644
--- a/src/lightning_lite/connector.py
+++ b/src/lightning_lite/connector.py
@@ -16,6 +16,7 @@
 from typing import Dict, List, Optional, Union
 
 import torch
+from typing_extensions import Literal
 
 from lightning_lite.accelerators import ACCELERATOR_REGISTRY
 from lightning_lite.accelerators.accelerator import Accelerator
@@ -58,6 +59,7 @@
 
 _PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO]
 _PLUGIN_INPUT = Union[_PLUGIN, str]
+_PRECISION_INPUT = Literal[16, 32, 64, 16, "bf16"]
 
 
 class _Connector:
@@ -97,7 +99,7 @@ def __init__(
         strategy: Optional[Union[str, Strategy]] = None,
         devices: Optional[Union[List[int], str, int]] = None,
         num_nodes: int = 1,
-        precision: Union[int, str] = 32,
+        precision: _PRECISION_INPUT = 32,
         plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
     ) -> None:
         # 1. Parsing flags
@@ -111,7 +113,7 @@ def __init__(
         #    For devices: Assign gpus, ipus, etc. to the accelerator flag and devices flag
         self._strategy_flag: Optional[Union[Strategy, str]] = None
         self._accelerator_flag: Optional[Union[Accelerator, str]] = None
-        self._precision_flag: Optional[Union[int, str]] = None
+        self._precision_flag: Optional[_PRECISION_INPUT] = None
         self._precision_plugin_flag: Optional[Precision] = None
         self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None
         self._parallel_devices: List[Union[int, torch.device, str]] = []
@@ -154,7 +156,7 @@ def _check_config_and_set_final_flags(
         self,
         strategy: Optional[Union[str, Strategy]],
         accelerator: Optional[Union[str, Accelerator]],
-        precision: Union[int, str],
+        precision: _PRECISION_INPUT,
         plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]],
     ) -> None:
         """This method checks:
diff --git a/src/lightning_lite/lite.py b/src/lightning_lite/lite.py
index 4a1f4e8004ba1..5af7c39d1698e 100644
--- a/src/lightning_lite/lite.py
+++ b/src/lightning_lite/lite.py
@@ -27,7 +27,7 @@
 from torch.utils.data import BatchSampler, DataLoader, DistributedSampler
 
 from lightning_lite.accelerators.accelerator import Accelerator
-from lightning_lite.connector import _Connector, _PLUGIN_INPUT
+from lightning_lite.connector import _Connector, _PLUGIN_INPUT, _PRECISION_INPUT
 from lightning_lite.plugins import Precision
 from lightning_lite.strategies import DeepSpeedStrategy, Strategy, XLAStrategy
 from lightning_lite.strategies.strategy import TBroadcast
@@ -74,7 +74,7 @@ def __init__(
         strategy: Optional[Union[str, Strategy]] = None,
         devices: Optional[Union[List[int], str, int]] = None,
         num_nodes: int = 1,
-        precision: Union[int, str] = 32,
+        precision: _PRECISION_INPUT = 32,
         plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
     ) -> None:
         self._connector = _Connector(
diff --git a/src/pytorch_lightning/lite/lite.py b/src/pytorch_lightning/lite/lite.py
index c34561b702a25..be6f2108249d0 100644
--- a/src/pytorch_lightning/lite/lite.py
+++ b/src/pytorch_lightning/lite/lite.py
@@ -18,6 +18,7 @@
 from lightning_utilities.core.rank_zero import rank_zero_deprecation, rank_zero_warn
 
 from lightning_lite.connector import _PLUGIN_INPUT as _LITE_PLUGIN_INPUT
+from lightning_lite.connector import _PRECISION_INPUT
 from lightning_lite.lite import LightningLite as _NewLightningLite
 from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
 from lightning_lite.plugins import DeepSpeedPrecision as LiteDeepSpeedPrecision
@@ -98,7 +99,7 @@ def __init__(
         strategy: Optional[Union[str, PLStrategy]] = None,
         devices: Optional[Union[List[int], str, int]] = None,
         num_nodes: int = 1,
-        precision: Union[int, str] = 32,
+        precision: _PRECISION_INPUT = 32,
         plugins: Optional[Union[_PL_PLUGIN_INPUT, List[_PL_PLUGIN_INPUT]]] = None,
         gpus: Optional[Union[List[int], str, int]] = None,
         tpu_cores: Optional[Union[List[int], str, int]] = None,
@@ -286,13 +287,17 @@ def _to_lite_precision_plugin(plugin: Optional[PLPrecisionPlugin]) -> LitePrecis
             return LitePrecision()
 
         if type(plugin) is PLNativeMixedPrecisionPlugin:
-            return LiteNativeMixedPrecision(precision=plugin.precision, device=plugin.device, scaler=plugin.scaler)
+            return LiteNativeMixedPrecision(
+                precision=plugin.precision, device=plugin.device, scaler=plugin.scaler  # type: ignore[arg-type]
+            )
 
         if type(plugin) is PLDoublePrecisionPlugin:
             return LiteDoublePrecision()
 
         if type(plugin) is PLDeepSpeedPrecisionPlugin:
-            return LiteDeepSpeedPrecision(precision=plugin.precision, amp_type=plugin.amp_type, amp_level=plugin.amp_level)
+            return LiteDeepSpeedPrecision(
+                precision=plugin.precision, amp_type=plugin.amp_type, amp_level=plugin.amp_level  # type: ignore[arg-type]
+            )
 
         if type(plugin) is PLTPUPrecisionPlugin:
             return LiteTPUPrecision()
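Note: for context on the `Literal` aliases used above, a type checker narrows the accepted values at call sites while runtime behaviour is unchanged. A small sketch (the function name is illustrative, not from the patch):

from typing_extensions import Literal

_PRECISION_INPUT = Literal[16, 32, 64, "bf16"]


def describe_precision(precision: _PRECISION_INPUT = 32) -> str:
    # runtime still sees a plain int or str; only static checkers enforce the Literal
    return f"running with precision={precision!r}"


describe_precision(16)        # accepted
describe_precision("bf16")    # accepted
# describe_precision("bf32")  # flagged by mypy/pyright: not a member of the Literal
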
From 304edf4b867f8e9a0306f081b75d1503a3a92773 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 22 Sep 2022 10:11:05 -0400
Subject: [PATCH 5/9] Update src/lightning_lite/connector.py

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
---
 src/lightning_lite/connector.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/lightning_lite/connector.py b/src/lightning_lite/connector.py
index cafc1b024c9ce..3e9a7560d6472 100644
--- a/src/lightning_lite/connector.py
+++ b/src/lightning_lite/connector.py
@@ -59,7 +59,7 @@
 
 _PLUGIN = Union[Precision, ClusterEnvironment, CheckpointIO]
 _PLUGIN_INPUT = Union[_PLUGIN, str]
-_PRECISION_INPUT = Literal[16, 32, 64, 16, "bf16"]
+_PRECISION_INPUT = Literal[16, 32, 64, "bf16"]
 
 
 class _Connector:

From f313ee6209bfd5b0e05052f7fd7634405db3cc50 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Fri, 23 Sep 2022 01:35:01 +0200
Subject: [PATCH 6/9] remove convert_output method

---
 src/lightning_lite/plugins/precision/precision.py | 7 -------
 src/lightning_lite/wrappers.py                    | 5 ++++-
 2 files changed, 4 insertions(+), 8 deletions(-)

diff --git a/src/lightning_lite/plugins/precision/precision.py b/src/lightning_lite/plugins/precision/precision.py
index b2b73382ae055..cf2275039a80c 100644
--- a/src/lightning_lite/plugins/precision/precision.py
+++ b/src/lightning_lite/plugins/precision/precision.py
@@ -43,13 +43,6 @@ def convert_input(self, data: Tensor) -> Tensor:
         """
         return _convert_fp_tensor(data, torch.float32)
 
-    def convert_output(self, data: Tensor) -> Tensor:
-        """Convert model outputs (forward) back to the default floating point precision type.
-
-        This is a no-op for tensors that are not of floating-point type or already have the desired type.
-        """
-        return _convert_fp_tensor(data, torch.get_default_dtype())
-
     def pre_backward(self, tensor: Tensor, module: Optional[Module]) -> Any:
         """Runs before precision plugin executes backward.
diff --git a/src/lightning_lite/wrappers.py b/src/lightning_lite/wrappers.py
index 8e9588843b7a1..c13f9a450b054 100644
--- a/src/lightning_lite/wrappers.py
+++ b/src/lightning_lite/wrappers.py
@@ -22,6 +22,7 @@
 from torch.utils.data import DataLoader
 
 from lightning_lite.plugins import Precision
+from lightning_lite.plugins.precision.utils import _convert_fp_tensor
 from lightning_lite.strategies import Strategy
 from lightning_lite.utilities import move_data_to_device
 from lightning_lite.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
@@ -101,7 +102,9 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
         with self._precision_plugin.forward_context():
             output = self._forward_module(*args, **kwargs)
 
-        output = apply_to_collection(output, function=self._precision_plugin.convert_output, dtype=Tensor)
+        output = apply_to_collection(
+            output, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype()
+        )
         return output
 
     @overload
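Note: the replacement call above relies on `apply_to_collection` forwarding extra keyword arguments (here `dst_type`) to the conversion function for every tensor in a possibly nested output. A hand-rolled stand-in for that forwarding behaviour, for illustration only (the real utility handles many more container types):

from typing import Any, Callable

import torch
from torch import Tensor


def apply_to_tensors(data: Any, function: Callable[..., Tensor], **kwargs: Any) -> Any:
    # Recurse through lists/tuples/dicts and apply ``function(tensor, **kwargs)`` to each tensor.
    if isinstance(data, Tensor):
        return function(data, **kwargs)
    if isinstance(data, (list, tuple)):
        return type(data)(apply_to_tensors(x, function, **kwargs) for x in data)
    if isinstance(data, dict):
        return {k: apply_to_tensors(v, function, **kwargs) for k, v in data.items()}
    return data


def _convert_fp_tensor(tensor: Tensor, dst_type: torch.dtype) -> Tensor:
    return tensor.to(dst_type) if torch.is_floating_point(tensor) else tensor


output = {"logits": torch.rand(2, 4, dtype=torch.float16), "ids": torch.arange(2)}
restored = apply_to_tensors(output, _convert_fp_tensor, dst_type=torch.get_default_dtype())
assert restored["logits"].dtype is torch.float32 and restored["ids"].dtype is torch.int64
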
From 3e0eb2684359cd0d56682bd749abc71bdd4224cc Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Wed, 28 Sep 2022 00:34:26 +0200
Subject: [PATCH 7/9] revert

---
 src/lightning_lite/plugins/precision/deepspeed.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/lightning_lite/plugins/precision/deepspeed.py b/src/lightning_lite/plugins/precision/deepspeed.py
index 79f0c9ed757ed..0e72f0c555230 100644
--- a/src/lightning_lite/plugins/precision/deepspeed.py
+++ b/src/lightning_lite/plugins/precision/deepspeed.py
@@ -19,7 +19,7 @@
 from typing_extensions import Literal
 
 from lightning_lite.plugins.precision.precision import Precision
-from lightning_lite.plugins.precision.utils import _convert_fp_tensor
+from lightning_lite.utilities.enums import AMPType, PrecisionType
 from lightning_lite.utilities.enums import AMPType
 from lightning_lite.utilities.imports import _APEX_AVAILABLE
 from lightning_lite.utilities.types import Steppable
@@ -56,11 +56,11 @@ def __init__(self, precision: Literal[16, 32, "bf16"], amp_type: str, amp_level:
 
             amp_level = amp_level or "O2"
 
-        supported_precision = ("16", "32", "bf16")
-        if str(precision) not in supported_precision:
+        supported_precision = (PrecisionType.HALF, PrecisionType.FLOAT, PrecisionType.BFLOAT)
+        if precision not in supported_precision:
             raise ValueError(
                 f"`precision={precision!r})` is not supported in DeepSpeed."
-                f" `precision` must be one of: {', '.join(supported_precision)}."
+                f" `precision` must be one of: {(x.value for x in supported_precision)}."
             )
         super().__init__()
From f4d0b33e1f3b0d4e9ac1b20fafb9a69d17d8c0ec Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Wed, 28 Sep 2022 00:35:28 +0200
Subject: [PATCH 8/9] fix imports

---
 src/lightning_lite/plugins/precision/deepspeed.py | 2 +-
 src/lightning_lite/plugins/precision/double.py    | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/lightning_lite/plugins/precision/deepspeed.py b/src/lightning_lite/plugins/precision/deepspeed.py
index 0e72f0c555230..3817bf3aa3f04 100644
--- a/src/lightning_lite/plugins/precision/deepspeed.py
+++ b/src/lightning_lite/plugins/precision/deepspeed.py
@@ -19,8 +19,8 @@
 from typing_extensions import Literal
 
 from lightning_lite.plugins.precision.precision import Precision
+from lightning_lite.plugins.precision.utils import _convert_fp_tensor
 from lightning_lite.utilities.enums import AMPType, PrecisionType
-from lightning_lite.utilities.enums import AMPType
 from lightning_lite.utilities.imports import _APEX_AVAILABLE
 from lightning_lite.utilities.types import Steppable
diff --git a/src/lightning_lite/plugins/precision/double.py b/src/lightning_lite/plugins/precision/double.py
index 5d0bb86fc3be9..dd0aa73eee5f0 100644
--- a/src/lightning_lite/plugins/precision/double.py
+++ b/src/lightning_lite/plugins/precision/double.py
@@ -15,8 +15,8 @@
 from typing import Generator
 
 import torch
-from torch.nn import Module
 from torch import Tensor
+from torch.nn import Module
 
 from lightning_lite.plugins.precision.precision import Precision
 from lightning_lite.plugins.precision.utils import _convert_fp_tensor

From 461d18f92e1a10b48b4fbcc312889ba0b8ed3be6 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Thu, 29 Sep 2022 22:58:40 +0200
Subject: [PATCH 9/9] bf16 fix

---
 .../tests_lite/plugins/precision/test_native_amp_integration.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/tests_lite/plugins/precision/test_native_amp_integration.py b/tests/tests_lite/plugins/precision/test_native_amp_integration.py
index f657236069342..94d8c399679cf 100644
--- a/tests/tests_lite/plugins/precision/test_native_amp_integration.py
+++ b/tests/tests_lite/plugins/precision/test_native_amp_integration.py
@@ -67,6 +67,6 @@ def after_backward(self, model):
     ],
 )
 def test_native_mixed_precision(accelerator, precision, expected_dtype):
-    lite = NativeMixedPrecisionBoringLite(accelerator=accelerator, precision=16)
+    lite = NativeMixedPrecisionBoringLite(accelerator=accelerator, precision=precision)
     lite.expected_dtype = expected_dtype
    lite.run()
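
Note: the final one-line fix matters because the test is parametrized over both 16 and "bf16"; with the hard-coded 16, the bf16 case silently re-ran the fp16 path. A condensed sketch of the pattern (the test and helper names here are illustrative, not the repository's):

import pytest
import torch


def dtype_for(precision):
    # mirrors the {"bf16": torch.bfloat16, 16: torch.float16} mapping used by the native AMP plugin's convert_input
    return {16: torch.float16, "bf16": torch.bfloat16}[precision]


@pytest.mark.parametrize(("precision", "expected_dtype"), [(16, torch.float16), ("bf16", torch.bfloat16)])
def test_precision_mapping(precision, expected_dtype):
    # forwarding the parametrized value (instead of a hard-coded 16) exercises both branches
    assert dtype_for(precision) is expected_dtype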