Skip to content

Commit

Permalink
Validate the precision input earlier (#9763)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Oct 15, 2021
1 parent 6429de8 commit db4e770
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 6 deletions.
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ def __init__(
self.sync_batchnorm = sync_batchnorm
self.benchmark = benchmark
self.replace_sampler_ddp = replace_sampler_ddp
if not PrecisionType.supported_type(precision):
raise MisconfigurationException(
f"Precision {repr(precision)} is invalid. Allowed precision values: {PrecisionType.supported_types()}"
)
self.precision = precision
self.amp_type = amp_type.lower() if isinstance(amp_type, str) else None
self.amp_level = amp_level
Expand Down Expand Up @@ -657,10 +661,6 @@ def select_precision_plugin(self) -> PrecisionPlugin:

return ApexMixedPrecisionPlugin(self.amp_level)

raise MisconfigurationException(
f"Precision {self.precision} is invalid. Allowed precision values: {PrecisionType.supported_types()}"
)

def select_training_type_plugin(self) -> TrainingTypePlugin:
if (
isinstance(self.distributed_backend, Accelerator)
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class PrecisionType(LightningEnum):
FLOAT = "32"
FULL = "64"
BFLOAT = "bf16"
MIXED = "mixed"

@staticmethod
def supported_type(precision: Union[str, int]) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion tests/accelerators/test_accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ def test_device_type_when_training_plugin_gpu_passed(tmpdir, plugin):
@pytest.mark.parametrize("precision", [1, 12, "invalid"])
def test_validate_precision_type(tmpdir, precision):

with pytest.raises(MisconfigurationException, match=f"Precision {precision} is invalid"):
with pytest.raises(MisconfigurationException, match=f"Precision {repr(precision)} is invalid"):
Trainer(precision=precision)


Expand Down
2 changes: 1 addition & 1 deletion tests/utilities/test_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_consistency():


def test_precision_supported_types():
assert PrecisionType.supported_types() == ["16", "32", "64", "bf16"]
assert PrecisionType.supported_types() == ["16", "32", "64", "bf16", "mixed"]
assert PrecisionType.supported_type(16)
assert PrecisionType.supported_type("16")
assert not PrecisionType.supported_type(1)
Expand Down

0 comments on commit db4e770

Please sign in to comment.