Default to precision=bf16 on CPU when precision=16 is passed #10033

Merged
merged 13 commits on Oct 20, 2021
Changes from 6 commits
6 changes: 6 additions & 0 deletions pytorch_lightning/plugins/precision/apex_amp.py
@@ -20,6 +20,7 @@
import pytorch_lightning as pl
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _PARAMETERS

if _APEX_AVAILABLE:
@@ -30,6 +31,11 @@ class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
"""Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)"""

def __init__(self, amp_level: str = "O2") -> None:
if not _APEX_AVAILABLE:
raise MisconfigurationException(
"You have asked for Apex AMP but you have not installed it."
" Install `apex` using this guide: https://github.com/NVIDIA/apex"
)
super().__init__()
self.backend = AMPType.APEX
self.amp_level = amp_level
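The availability check previously lived only in the accelerator connector; moving it into the plugin's constructor means any code path that builds the plugin fails fast with the same message. A minimal sketch of the resulting behaviour (not part of the diff; assumes a PyTorch Lightning checkout at this revision and no `apex` install):

```python
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
    # Raises immediately at construction time when apex is not importable,
    # instead of surfacing later during accelerator/plugin selection.
    plugin = ApexMixedPrecisionPlugin(amp_level="O2")
except MisconfigurationException as err:
    print(f"Apex is unavailable: {err}")
```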
13 changes: 1 addition & 12 deletions pytorch_lightning/plugins/precision/native_amp.py
@@ -21,7 +21,7 @@

import pytorch_lightning as pl
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
from pytorch_lightning.utilities import _TORCH_BFLOAT_AVAILABLE, _TORCH_CPU_AMP_AVAILABLE, AMPType
from pytorch_lightning.utilities import _TORCH_BFLOAT_AVAILABLE, AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException


@@ -34,13 +34,6 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):

def __init__(self, precision: Union[int, str] = 16, use_cpu: bool = False) -> None:
super().__init__()

if use_cpu and not _TORCH_CPU_AMP_AVAILABLE:
raise MisconfigurationException(
"You have asked for native AMP on CPU, but AMP is only available on GPU for PyTorch 1.9 "
"and lower. To use native AMP on CPU, install PyTorch 1.10 or later."
)

self.use_cpu = use_cpu
self._dtype = self._select_precision_dtype(precision)
self.backend = AMPType.NATIVE
@@ -54,10 +47,6 @@ def _select_precision_dtype(self, precision: Union[int, str] = 16) -> torch.dtype:
"To use bfloat16 with native amp you must install torch greater or equal to 1.10."
)
return torch.bfloat16
elif self.use_cpu:
raise MisconfigurationException(
"CPU native amp only supports bfloat16. Please pass precision='bf16' to the Trainer."
)
return torch.float16

@property
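With the CPU-specific errors deleted, the dtype selection in this plugin reduces to a simple mapping. The sketch below is a simplified standalone version (hypothetical helper, not the plugin method itself; assumes torch >= 1.10 so the bfloat16 branch does not raise):

```python
import torch

def select_precision_dtype(precision) -> torch.dtype:
    # "bf16" maps to torch.bfloat16; the remaining supported value (16) maps to
    # torch.float16. The old "CPU native amp only supports bfloat16" error is no
    # longer needed because the connector now converts precision=16 to "bf16" on
    # CPU before this plugin is ever constructed.
    return torch.bfloat16 if precision == "bf16" else torch.float16
```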
47 changes: 35 additions & 12 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -72,7 +72,6 @@
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import (
_APEX_AVAILABLE,
_HOROVOD_AVAILABLE,
_IPU_AVAILABLE,
_TORCH_GREATER_EQUAL_1_7,
@@ -624,9 +623,26 @@ def select_precision_plugin(self) -> PrecisionPlugin:
return PrecisionPlugin()
if self.precision == 64:
return DoublePrecisionPlugin()
if self.precision in (16, "bf16"):

# maybe convert the precision value
if self.precision == 16 and self.use_cpu:
Contributor: I think we support Trainer(precision="16"), do we?

Contributor Author: I don't think so, because we don't convert the value to a PrecisionType (which works for both int and str). IMO we should, but that should be done in a follow-up. The previous code here had the same check.
if self.amp_type == AMPType.APEX:
# apex was explicitly passed, not a good idea to silently switch to native AMP
raise MisconfigurationException(
"You passed `Trainer(accelerator='cpu', precision=16, amp_type='apex')`"
" but apex AMP not supported on CPU."
)
# this automatic switch is to ease transition between accelerator environments
rank_zero_warn(
"You passed `Trainer(accelerator='cpu', precision=16)` but native AMP is not supported on CPU."
" Using `precision='bf16'` instead."
)
self.precision = "bf16"

if self.precision == 16:
rank_zero_info(f"Using 16bit {self.amp_type.value} Native Mixed Precision (AMP)")

if self.amp_type == AMPType.NATIVE:
log.info(f"Using native {self.precision} bit Automatic Mixed Precision")
if self._is_sharded_training_type:
return ShardedNativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
if self._is_fully_sharded_training_type:
Expand All @@ -635,21 +651,28 @@ def select_precision_plugin(self) -> PrecisionPlugin:
return NativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)

if self.amp_type == AMPType.APEX:
if not _APEX_AVAILABLE:
raise MisconfigurationException(
"You have asked for Apex AMP but you have not installed it yet."
" Install apex first using this guide: https://github.com/NVIDIA/apex#linux"
)
if self._is_sharded_training_type or self._is_fully_sharded_training_type:
raise MisconfigurationException(
"Sharded Plugin is not supported with Apex AMP, please using native AMP for 16-bit precision."
"Sharded plugins are not supported with apex, please switch to `amp_backend='native'`."
)
log.info("Using APEX 16bit precision.")

self.amp_level = self.amp_level or "O2"

return ApexMixedPrecisionPlugin(self.amp_level)

if self.precision == "bf16":
if self.amp_type != AMPType.NATIVE:
raise MisconfigurationException(
"You passed `Trainer(amp_type='apex', precision='bf16')` but it's not supported."
" Try using `amp_type='native'` instead."
)
rank_zero_info("Using bfloat16 precision")
if self._is_sharded_training_type:
return ShardedNativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
if self._is_fully_sharded_training_type:
return FullyShardedNativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)
return NativeMixedPrecisionPlugin(self.precision, use_cpu=self.use_cpu)

raise RuntimeError("No precision set")

def select_training_type_plugin(self) -> TrainingTypePlugin:
if (
isinstance(self.distributed_backend, Accelerator)
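In user terms, the connector change means a CPU run that previously errored with precision=16 now degrades gracefully to bf16, while an explicit apex request still fails loudly. A hedged sketch of the two paths (assumes this PR's revision and PyTorch 1.10+ so CPU autocast is available):

```python
from pytorch_lightning import Trainer

# Warns that native AMP is unsupported on CPU and transparently switches to
# `precision='bf16'` native mixed precision.
trainer = Trainer(accelerator="cpu", precision=16)

# Apex was requested explicitly, so nothing is switched silently; this raises a
# MisconfigurationException instead:
# Trainer(accelerator="cpu", precision=16, amp_backend="apex")
```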
31 changes: 6 additions & 25 deletions tests/models/test_amp.py
@@ -22,8 +22,8 @@
import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.utilities import _TORCH_BFLOAT_AVAILABLE, _TORCH_CPU_AMP_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_DEV_1_10
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.runif import RunIf

@@ -69,7 +69,7 @@ def _assert_autocast_enabled(self):
assert torch.is_autocast_enabled()


@pytest.mark.skipif(not _TORCH_CPU_AMP_AVAILABLE, reason="CPU AMP not available")
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_DEV_1_10, reason="Needs bfloat16 support")
@pytest.mark.parametrize(
"strategy",
[
@@ -78,13 +78,7 @@ def _assert_autocast_enabled(self):
"ddp_spawn",
],
)
@pytest.mark.parametrize(
"precision",
[
pytest.param(16, marks=pytest.mark.skip("CPU precision 16 is not supported in PyTorch yet.")), # TODO
"bf16",
],
)
@pytest.mark.parametrize("precision", [16, "bf16"])
@pytest.mark.parametrize("num_processes", [1, 2])
def test_amp_cpus(tmpdir, strategy, precision, num_processes):
"""Make sure combinations of AMP and training types work if supported."""
@@ -95,7 +89,6 @@ def test_amp_cpus(tmpdir, strategy, precision, num_processes):
)

model = AMPTestModel()
# tutils.run_model_test(trainer_options, model)
trainer.fit(model)
trainer.test(model)
trainer.predict(model, DataLoader(RandomDataset(32, 64)))
@@ -104,20 +97,9 @@


@RunIf(min_gpus=2)
@pytest.mark.parametrize(
"strategy",
[None, "dp", "ddp_spawn"],
)
@pytest.mark.parametrize(
"precision",
[
16,
pytest.param(
"bf16",
marks=pytest.mark.skipif(not _TORCH_BFLOAT_AVAILABLE, reason="torch.bfloat16 not available"),
),
],
)
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_DEV_1_10, reason="Needs bfloat16 support")
@pytest.mark.parametrize("strategy", [None, "dp", "ddp_spawn"])
@pytest.mark.parametrize("precision", [16, "bf16"])
@pytest.mark.parametrize("gpus", [1, 2])
def test_amp_gpus(tmpdir, strategy, precision, gpus):
"""Make sure combinations of AMP and training types work if supported."""
@@ -126,7 +108,6 @@ def test_amp_gpus(tmpdir, strategy, precision, gpus):
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, gpus=gpus, strategy=strategy, precision=precision)

model = AMPTestModel()
# tutils.run_model_test(trainer_options, model)
trainer.fit(model)
trainer.test(model)
trainer.predict(model, DataLoader(RandomDataset(32, 64)))
59 changes: 27 additions & 32 deletions tests/plugins/test_amp_plugins.py
@@ -178,28 +178,6 @@ def test_amp_apex_ddp_spawn_fit(amp_level, tmpdir):
trainer.fit(model)


@RunIf(min_gpus=1, max_torch="1.9")
def test_amp_precision_16_bfloat_throws_error(tmpdir):
with pytest.raises(
MisconfigurationException,
match="To use bfloat16 with native amp you must install torch greater or equal to 1.10",
):
Trainer(
default_root_dir=tmpdir,
precision="bf16",
gpus=1,
)


@RunIf(max_torch="1.9")
def test_cpu_amp_precision_throws_error(tmpdir):
with pytest.raises(
MisconfigurationException,
match="To use native AMP on CPU, install PyTorch 1.10 or later.",
):
NativeMixedPrecisionPlugin(use_cpu=True)


@pytest.mark.skipif(not _TORCH_CPU_AMP_AVAILABLE, reason="Torch CPU AMP is not available.")
def test_cpu_amp_precision_context_manager(tmpdir):
"""Test to ensure that the context manager correctly is set to CPU + bfloat16, and a scaler isn't set."""
@@ -212,15 +190,32 @@ def test_cpu_amp_precision_context_manager(tmpdir):
assert context_manager.fast_dtype == torch.bfloat16


@pytest.mark.skipif(not _TORCH_CPU_AMP_AVAILABLE, reason="Torch CPU AMP is not available.")
def test_cpu_amp_precision_16_throws_error(tmpdir):
"""Throw error when using 16 as Native CPU AMP only supports bfloat16."""

def test_precision_selection_raises(monkeypatch):
with pytest.raises(
MisconfigurationException,
match="CPU native amp only supports bfloat16. Please pass precision='bf16' to the Trainer.",
MisconfigurationException, match=r"precision=16, amp_type='apex'\)` but apex AMP not supported on CPU"
):
Trainer(amp_backend="apex", precision=16)

with pytest.warns(
UserWarning, match=r"precision=16\)` but native AMP is not supported on CPU. Using `precision='bf16"
), pytest.raises(MisconfigurationException, match="must install torch greater or equal to 1.10"):
Trainer(precision=16)

with pytest.raises(MisconfigurationException, match="must install torch greater or equal to 1.10"):
Trainer(precision="bf16")

with pytest.raises(MisconfigurationException, match=r"amp_type='apex', precision='bf16'\)` but it's not supported"):
Trainer(amp_backend="apex", precision="bf16")

with mock.patch("torch.cuda.device_count", return_value=1), pytest.raises(
MisconfigurationException, match="Sharded plugins are not supported with apex"
):
Trainer(amp_backend="apex", precision=16, gpus=1, accelerator="ddp_fully_sharded")

import pytorch_lightning.plugins.precision.apex_amp as apex

monkeypatch.setattr(apex, "_APEX_AVAILABLE", False)
with mock.patch("torch.cuda.device_count", return_value=1), pytest.raises(
MisconfigurationException, match="asked for Apex AMP but you have not installed it"
):
Trainer(
default_root_dir=tmpdir,
precision=16,
)
Trainer(amp_backend="apex", precision=16, gpus=1)
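Because the availability check now lives in the plugin itself, it can also be exercised without building a Trainer at all. A hypothetical, smaller test in the same spirit as the monkeypatch above:

```python
import pytest

import pytorch_lightning.plugins.precision.apex_amp as apex_amp
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def test_apex_plugin_requires_apex(monkeypatch):
    # Simulate a missing apex install; the constructor reads the module-level flag
    # at call time, so patching it is enough to hit the new guard.
    monkeypatch.setattr(apex_amp, "_APEX_AVAILABLE", False)
    with pytest.raises(MisconfigurationException, match="asked for Apex AMP"):
        ApexMixedPrecisionPlugin()
```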