Skip to content

Commit db4e770

Browse files
authored
Validate the precision input earlier (#9763)
1 parent 6429de8 commit db4e770

File tree

4 files changed

+7
-6
lines changed

4 files changed

+7
-6
lines changed

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,10 @@ def __init__(
128128
self.sync_batchnorm = sync_batchnorm
129129
self.benchmark = benchmark
130130
self.replace_sampler_ddp = replace_sampler_ddp
131+
if not PrecisionType.supported_type(precision):
132+
raise MisconfigurationException(
133+
f"Precision {repr(precision)} is invalid. Allowed precision values: {PrecisionType.supported_types()}"
134+
)
131135
self.precision = precision
132136
self.amp_type = amp_type.lower() if isinstance(amp_type, str) else None
133137
self.amp_level = amp_level
@@ -657,10 +661,6 @@ def select_precision_plugin(self) -> PrecisionPlugin:
657661

658662
return ApexMixedPrecisionPlugin(self.amp_level)
659663

660-
raise MisconfigurationException(
661-
f"Precision {self.precision} is invalid. Allowed precision values: {PrecisionType.supported_types()}"
662-
)
663-
664664
def select_training_type_plugin(self) -> TrainingTypePlugin:
665665
if (
666666
isinstance(self.distributed_backend, Accelerator)

pytorch_lightning/utilities/enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class PrecisionType(LightningEnum):
6262
FLOAT = "32"
6363
FULL = "64"
6464
BFLOAT = "bf16"
65+
MIXED = "mixed"
6566

6667
@staticmethod
6768
def supported_type(precision: Union[str, int]) -> bool:

tests/accelerators/test_accelerator_connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,7 @@ def test_device_type_when_training_plugin_gpu_passed(tmpdir, plugin):
709709
@pytest.mark.parametrize("precision", [1, 12, "invalid"])
710710
def test_validate_precision_type(tmpdir, precision):
711711

712-
with pytest.raises(MisconfigurationException, match=f"Precision {precision} is invalid"):
712+
with pytest.raises(MisconfigurationException, match=f"Precision {repr(precision)} is invalid"):
713713
Trainer(precision=precision)
714714

715715

tests/utilities/test_enums.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def test_consistency():
2727

2828

2929
def test_precision_supported_types():
30-
assert PrecisionType.supported_types() == ["16", "32", "64", "bf16"]
30+
assert PrecisionType.supported_types() == ["16", "32", "64", "bf16", "mixed"]
3131
assert PrecisionType.supported_type(16)
3232
assert PrecisionType.supported_type("16")
3333
assert not PrecisionType.supported_type(1)

0 commit comments

Comments
 (0)