Pass the scaler as an input to NativeMixedPrecisionPlugin #10055

Merged · 17 commits · Oct 28, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -343,6 +343,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Moved the `optimizer_step` and `clip_gradients` hook from the `Accelerator` and `TrainingTypePlugin` into the `PrecisionPlugin` ([#10143](https://github.com/PyTorchLightning/pytorch-lightning/pull/10143), [#10029](https://github.com/PyTorchLightning/pytorch-lightning/pull/10029))


- `NativeMixedPrecisionPlugin` and its subclasses now take an optional `GradScaler` instance ([#10055](https://github.com/PyTorchLightning/pytorch-lightning/pull/10055))


- Updated several places in the loops and trainer to access `training_type_plugin` directly instead of `accelerator` ([#9901](https://github.com/PyTorchLightning/pytorch-lightning/pull/9901))


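For context on the changelog entry above, here is a minimal sketch of the new constructor usage; it is not part of the diff. It assumes a CUDA machine, the `init_scale` value is arbitrary and for illustration only, and the plugin is handed to the `Trainer` the same way the tests touched by this PR do.

```python
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import NativeMixedPrecisionPlugin

# A user-managed scaler, e.g. with a non-default initial scale (the value is illustrative only).
scaler = torch.cuda.amp.GradScaler(init_scale=2**20)

# New signature introduced by this PR: precision, autocast device, and an optional scaler.
precision_plugin = NativeMixedPrecisionPlugin(16, "cuda", scaler=scaler)

# As in the tests below, the plugin is passed to the Trainer via `plugins`.
trainer = Trainer(precision=16, gpus=1, plugins=precision_plugin)
```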
2 changes: 1 addition & 1 deletion docs/source/extensions/accelerators.rst
@@ -26,7 +26,7 @@ One to handle differences from the training routine and one to handle different
from pytorch_lightning.plugins import NativeMixedPrecisionPlugin, DDPPlugin

accelerator = GPUAccelerator(
precision_plugin=NativeMixedPrecisionPlugin(),
precision_plugin=NativeMixedPrecisionPlugin(16, "cuda"),
training_type_plugin=DDPPlugin(),
)
trainer = Trainer(accelerator=accelerator)
@@ -18,9 +18,7 @@


class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
"""Mixed Precision for Full Sharded Training."""

precision = "mixed"
"""Native AMP for Fully Sharded Training."""

def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
# see https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html
60 changes: 30 additions & 30 deletions pytorch_lightning/plugins/precision/native_amp.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, Union
from typing import Any, Callable, Dict, Generator, Optional, Union

import torch
from torch import Tensor
@@ -31,41 +31,39 @@


class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
"""Plugin for native mixed precision training with :mod:`torch.cuda.amp`.
"""Plugin for Native Mixed Precision (AMP) training with ``torch.autocast``.

Args:
precision: Whether to use torch.float16 (16) or torch.bfloat16 (bf16).
precision: Whether to use ``torch.float16`` (``16``) or ``torch.bfloat16`` (``'bf16'``).
device: The device for ``torch.autocast``.
scaler: An optional :class:`torch.cuda.amp.GradScaler` to use.
"""

def __init__(self, precision: Union[int, str] = 16, use_cpu: bool = False) -> None:
super().__init__()
self.use_cpu = use_cpu
self._dtype = self._select_precision_dtype(precision)
self.backend = AMPType.NATIVE
if not self.is_bfloat16:
self.scaler = torch.cuda.amp.GradScaler()

def _select_precision_dtype(self, precision: Union[int, str] = 16) -> torch.dtype:
if precision == "bf16":
if not _TORCH_GREATER_EQUAL_1_10:
raise MisconfigurationException(
"To use bfloat16 with native amp you must install torch greater or equal to 1.10."
)
return torch.bfloat16
return torch.float16
backend = AMPType.NATIVE

@property
def is_bfloat16(self) -> bool:
return self._dtype == torch.bfloat16
def __init__(
self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
) -> None:
super().__init__()
if precision == "bf16" and not _TORCH_GREATER_EQUAL_1_10:
raise MisconfigurationException(
"To use bfloat16 with native amp you must install torch greater or equal to 1.10."
)
if scaler is None and precision == 16:
scaler = torch.cuda.amp.GradScaler()
if scaler is not None and precision == "bf16":
raise MisconfigurationException(f"`precision='bf16'` does not use a scaler, found {scaler}.")
self.precision = precision
self.device = device
self.scaler = scaler

def pre_backward(self, model: "pl.LightningModule", closure_loss: torch.Tensor) -> torch.Tensor:
if self.is_bfloat16:
return super().pre_backward(model, closure_loss)
closure_loss = self.scaler.scale(closure_loss)
if self.scaler is not None:
closure_loss = self.scaler.scale(closure_loss)
return super().pre_backward(model, closure_loss)

def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any) -> None:
if not self.is_bfloat16:
if self.scaler is not None:
tensor = self.scaler.scale(tensor)
super()._run_backward(tensor, model, *args, **kwargs)

@@ -77,7 +75,7 @@ def optimizer_step(
lambda_closure: Callable[[], Any],
**kwargs: Any,
) -> None:
if self.is_bfloat16:
if self.scaler is None:
# skip scaler logic, as bfloat16 does not require scaler
return super().optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
if isinstance(optimizer, LBFGS):
@@ -98,7 +96,9 @@ def optimizer_step(

def autocast_context_manager(self) -> autocast:
if _TORCH_GREATER_EQUAL_1_10:
return autocast("cpu" if self.use_cpu else "cuda", dtype=self._dtype)
# the dtype could be automatically inferred but we need to manually set it due to a bug upstream
# https://github.com/pytorch/pytorch/issues/67233
return autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half)
return autocast()

@contextmanager
@@ -108,9 +108,9 @@ def forward_context(self) -> Generator[None, None, None]:
yield

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
if "native_amp_scaling_state" in checkpoint and not self.is_bfloat16:
if self.scaler is not None and "native_amp_scaling_state" in checkpoint:
self.scaler.load_state_dict(checkpoint["native_amp_scaling_state"])

def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
if not self.is_bfloat16:
if self.scaler is not None:
checkpoint["native_amp_scaling_state"] = self.scaler.state_dict()
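As a quick illustration of the constructor behaviour in the diff above — a `GradScaler` is created automatically for `precision=16` when none is passed, while `'bf16'` forbids one — the following sketch assumes PyTorch 1.10+ (for bf16 autocast) and is not part of the PR itself.

```python
import torch
from pytorch_lightning.plugins import NativeMixedPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# fp16: a GradScaler is created automatically when the caller does not provide one.
fp16_plugin = NativeMixedPrecisionPlugin(16, "cuda")
assert isinstance(fp16_plugin.scaler, torch.cuda.amp.GradScaler)

# bf16: no scaler is kept, and supplying one is rejected.
bf16_plugin = NativeMixedPrecisionPlugin("bf16", "cpu")
assert bf16_plugin.scaler is None

try:
    NativeMixedPrecisionPlugin("bf16", "cpu", scaler=torch.cuda.amp.GradScaler())
except MisconfigurationException as err:
    print(err)  # "`precision='bf16'` does not use a scaler, found ..."
```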
20 changes: 14 additions & 6 deletions pytorch_lightning/plugins/precision/sharded_native_amp.py
@@ -11,23 +11,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
from typing import Optional, Union

import torch

from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _FAIRSCALE_AVAILABLE:
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler


class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
"""Mixed Precision for Sharded Training."""
"""Native AMP for Sharded Training."""

def __init__(self, precision: Union[int, str] = 16, use_cpu: bool = False) -> None:
super().__init__(precision, use_cpu=use_cpu)
if not self.use_cpu:
self.scaler = ShardedGradScaler()
def __init__(
self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
) -> None:
if not _FAIRSCALE_AVAILABLE:
raise MisconfigurationException(
"You have asked for sharded AMP but you have not installed it."
" Install `fairscale` using this guide: https://https://github.com/facebookresearch/fairscale"
)
super().__init__(precision, device, scaler=scaler or ShardedGradScaler())

def clip_grad_by_norm(self, optimizer: "OSS", clip_val: Union[int, float]) -> None:
optimizer.clip_grad_norm(clip_val)
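A minimal sketch of how the sharded plugin above now resolves its scaler — it falls back to fairscale's `ShardedGradScaler` when none is given. This assumes `fairscale` is installed and that the plugin is importable from `pytorch_lightning.plugins`; the keyword argument in the second call is illustrative only.

```python
from fairscale.optim.grad_scaler import ShardedGradScaler

from pytorch_lightning.plugins import ShardedNativeMixedPrecisionPlugin

# With no scaler argument, the plugin defaults to fairscale's ShardedGradScaler.
plugin = ShardedNativeMixedPrecisionPlugin(16, "cuda")
assert isinstance(plugin.scaler, ShardedGradScaler)

# A caller can still pass a pre-configured instance instead.
custom = ShardedNativeMixedPrecisionPlugin(16, "cuda", scaler=ShardedGradScaler(growth_interval=1000))
```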
36 changes: 17 additions & 19 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -638,16 +638,27 @@ def select_precision_plugin(self) -> PrecisionPlugin:
)
self.precision = "bf16"

if self.precision == 16:
rank_zero_info(f"Using 16bit {self.amp_type.value} Automatic Mixed Precision (AMP)")
if self.precision in (16, "bf16"):
if self.precision == "bf16" and self.amp_type != AMPType.NATIVE:
raise MisconfigurationException(
f"You passed `Trainer(amp_type={self.amp_type.value!r}, precision='bf16')` but it's not supported."
" Try using `amp_type='native'` instead."
)

rank_zero_info(
f"Using 16bit {self.amp_type.value} Automatic Mixed Precision (AMP)"
if self.precision == 16
else "Using bfloat16 Automatic Mixed Precision (AMP)"
)

if self.amp_type == AMPType.NATIVE:
device = "cpu" if self.use_cpu else "cuda"

if self._is_sharded_training_type:
return ShardedNativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
return ShardedNativeMixedPrecisionPlugin(self.precision, device)
if self._is_fully_sharded_training_type:
return FullyShardedNativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)

return NativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
return FullyShardedNativeMixedPrecisionPlugin(self.precision, device)
return NativeMixedPrecisionPlugin(self.precision, device)

if self.amp_type == AMPType.APEX:
if self._is_sharded_training_type or self._is_fully_sharded_training_type:
@@ -657,19 +668,6 @@ def select_precision_plugin(self) -> PrecisionPlugin:
self.amp_level = self.amp_level or "O2"
return ApexMixedPrecisionPlugin(self.amp_level)

if self.precision == "bf16":
if self.amp_type != AMPType.NATIVE:
raise MisconfigurationException(
"You passed `Trainer(amp_type='apex', precision='bf16')` but it's not supported."
" Try using `amp_type='native'` instead."
)
rank_zero_info("Using bfloat16 Automatic Mixed Precision (AMP)")
if self._is_sharded_training_type:
return ShardedNativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
if self._is_fully_sharded_training_type:
return FullyShardedNativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
return NativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)

raise RuntimeError("No precision set")

def select_training_type_plugin(self) -> TrainingTypePlugin:
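To illustrate the connector logic above from the user's side, the sketch below assumes PyTorch 1.10+ and shows which plugin the `Trainer` selects for bf16 with the default native backend; it is not part of the diff.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import NativeMixedPrecisionPlugin

# With the default native backend, bf16 selects the native plugin and no GradScaler is used.
trainer = Trainer(precision="bf16")
assert isinstance(trainer.precision_plugin, NativeMixedPrecisionPlugin)
assert trainer.precision_plugin.scaler is None

# Combining bf16 with apex is now rejected up front:
# Trainer(amp_backend="apex", precision="bf16")  # raises MisconfigurationException
```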
22 changes: 11 additions & 11 deletions tests/models/test_amp.py
@@ -27,41 +27,41 @@


class AMPTestModel(BoringModel):
def _step(self, batch, batch_idx):
def _step(self, batch):
self._assert_autocast_enabled()
output = self(batch)
bfloat16 = self.trainer.precision_plugin.is_bfloat16
assert output.dtype == torch.float16 if not bfloat16 else torch.bfloat16
is_bfloat16 = self.trainer.precision_plugin.precision == "bf16"
assert output.dtype == torch.float16 if not is_bfloat16 else torch.bfloat16
loss = self.loss(batch, output)
return loss

def loss(self, batch, prediction):
# todo (sean): convert bfloat16 to float32 as mse loss for cpu amp is currently not supported
if self.trainer.precision_plugin.use_cpu:
if self.trainer.precision_plugin.device == "cpu":
prediction = prediction.float()
return super().loss(batch, prediction)

def training_step(self, batch, batch_idx):
output = self._step(batch, batch_idx)
output = self._step(batch)
return {"loss": output}

def validation_step(self, batch, batch_idx):
output = self._step(batch, batch_idx)
output = self._step(batch)
return {"x": output}

def test_step(self, batch, batch_idx):
output = self._step(batch, batch_idx)
output = self._step(batch)
return {"y": output}

def predict(self, batch, batch_idx, dataloader_idx=None):
def predict_step(self, batch, batch_idx, dataloader_idx=None):
self._assert_autocast_enabled()
output = self(batch)
bfloat16 = self.trainer.precision_plugin.is_bfloat16
assert output.dtype == torch.float16 if not bfloat16 else torch.bfloat16
is_bfloat16 = self.trainer.precision_plugin.precision == "bf16"
assert output.dtype == torch.float16 if not is_bfloat16 else torch.bfloat16
return output

def _assert_autocast_enabled(self):
if self.trainer.precision_plugin.use_cpu:
if self.trainer.precision_plugin.device == "cpu":
assert torch.is_autocast_cpu_enabled()
else:
assert torch.is_autocast_enabled()
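The test changes above also show the migration path for user code: the `is_bfloat16` and `use_cpu` attributes are gone, and the equivalent checks read from `precision` and `device` instead. The helper names below are hypothetical, written only to show the before/after.

```python
from pytorch_lightning import Trainer


def uses_bfloat16(trainer: Trainer) -> bool:
    # Replaces the removed `precision_plugin.is_bfloat16` property.
    return trainer.precision_plugin.precision == "bf16"


def autocasts_on_cpu(trainer: Trainer) -> bool:
    # Replaces the removed `precision_plugin.use_cpu` flag.
    return trainer.precision_plugin.device == "cpu"
```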
28 changes: 13 additions & 15 deletions tests/plugins/test_amp_plugins.py
@@ -20,7 +20,6 @@

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
@@ -47,7 +46,7 @@ class MyApexPlugin(ApexMixedPrecisionPlugin):
},
)
@mock.patch("torch.cuda.device_count", return_value=2)
@pytest.mark.parametrize("ddp_backend,gpus", [("ddp", 2), ("ddp2", 2), ("ddp_spawn", 2)])
@pytest.mark.parametrize("strategy,gpus", [("ddp", 2), ("ddp2", 2), ("ddp_spawn", 2)])
@pytest.mark.parametrize(
"amp,custom_plugin,plugin_cls",
[
Expand All @@ -57,21 +56,19 @@ class MyApexPlugin(ApexMixedPrecisionPlugin):
pytest.param("apex", True, MyApexPlugin, marks=RunIf(amp_apex=True)),
],
)
def test_amp_apex_ddp(
mocked_device_count, ddp_backend: str, gpus: int, amp: str, custom_plugin: bool, plugin_cls: MixedPrecisionPlugin
):

def test_amp_apex_ddp(mocked_device_count, strategy, gpus, amp, custom_plugin, plugin_cls):
plugin = None
if custom_plugin:
plugin = plugin_cls(16, "cpu") if amp == "native" else plugin_cls()
trainer = Trainer(
fast_dev_run=True,
precision=16,
amp_backend=amp,
gpus=gpus,
strategy=ddp_backend,
plugins=[plugin_cls()] if custom_plugin else None,
strategy=strategy,
plugins=plugin,
)
assert isinstance(trainer.precision_plugin, plugin_cls)
if amp == "native":
assert not trainer.precision_plugin.is_bfloat16


class GradientUnscaleBoringModel(BoringModel):
@@ -179,13 +176,14 @@ def test_amp_apex_ddp_spawn_fit(amp_level, tmpdir):

@RunIf(min_torch="1.10")
def test_cpu_amp_precision_context_manager(tmpdir):
"""Test to ensure that the context manager correctly is set to CPU + bfloat16, and a scaler isn't set."""
plugin = NativeMixedPrecisionPlugin(precision="bf16", use_cpu=True)
assert plugin.use_cpu
assert not hasattr(plugin, "scaler")
"""Test to ensure that the context manager correctly is set to CPU + bfloat16."""
plugin = NativeMixedPrecisionPlugin("bf16", "cpu")
assert plugin.device == "cpu"
assert plugin.scaler is None
context_manager = plugin.autocast_context_manager()
assert isinstance(context_manager, torch.autocast)
assert context_manager.fast_dtype == torch.bfloat16
# check with str due to a bug upstream: https://github.com/pytorch/pytorch/issues/65786
assert str(context_manager.fast_dtype) == str(torch.bfloat16)


def test_precision_selection_raises(monkeypatch):