diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index dbe648fe6dcf0..807eb979e4943 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -14,7 +14,8 @@
 
 import logging
 import os
-from typing import List, Optional, Union
+from collections import Counter
+from typing import Dict, List, Optional, Union
 
 import torch
 
@@ -304,6 +305,7 @@ def _check_config_and_set_final_flags(
         self._precision_flag = precision
 
         if plugins:
+            plugins_flags_types: Dict[str, int] = Counter()
             for plugin in plugins:
                 if isinstance(plugin, Strategy) or isinstance(plugin, str) and plugin in self._registered_strategies:
                     self._strategy_flag = plugin
@@ -311,15 +313,17 @@
                         f"Passing {plugin} `strategy` to the `plugins` flag in Trainer has been deprecated"
                         f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={plugin})` instead."
                     )
+                    plugins_flags_types[Strategy.__name__] += 1
                 elif isinstance(plugin, PrecisionPlugin):
                     self._precision_plugin_flag = plugin
-                elif isinstance(plugin, str) and plugin in self._precision_types:
-                    self._precision_flag = plugin
+                    plugins_flags_types[PrecisionPlugin.__name__] += 1
                 elif isinstance(plugin, CheckpointIO):
                     self.checkpoint_io = plugin
+                    plugins_flags_types[CheckpointIO.__name__] += 1
                 elif isinstance(plugin, ClusterEnvironment):
                     self._cluster_environment_flag = plugin
+                    plugins_flags_types[ClusterEnvironment.__name__] += 1
                 elif isinstance(plugin, LayerSync):
                     if sync_batchnorm and not isinstance(plugin, NativeSyncBatchNorm):
                         raise MisconfigurationException(
@@ -327,11 +331,20 @@
                             " plugin, but this is not allowed. Choose one or the other."
                         )
                     self._layer_sync = plugin
+                    plugins_flags_types[NativeSyncBatchNorm.__name__] += 1
                 else:
                     raise MisconfigurationException(
-                        f"Found invalid type for plugin {plugin}. Expected a precision plugin or training strategy."
+                        f"Found invalid type for plugin {plugin}. Expected one of: PrecisionPlugin, "
+                        "CheckpointIO, ClusterEnvironment, LayerSync, or Strategy."
                     )
+
+            duplicated_plugin_key = [k for k, v in plugins_flags_types.items() if v > 1]
+            if duplicated_plugin_key:
+                raise MisconfigurationException(
+                    f"Received multiple values for {', '.join(duplicated_plugin_key)} flags in `plugins`."
+                    " Expected one value for each type at most."
+                )
 
         # handle the case when the user passes in a strategy instance which has an accelerator, precision,
         # checkpoint io or cluster env set up
         # TODO: @awaelchli improve the error messages below
diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py
index 8e835fbdf43cd..8e79ce1caa6b8 100644
--- a/tests/accelerators/test_accelerator_connector.py
+++ b/tests/accelerators/test_accelerator_connector.py
@@ -26,13 +26,14 @@
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.accelerators.cpu import CPUAccelerator
 from pytorch_lightning.accelerators.gpu import GPUAccelerator
-from pytorch_lightning.plugins import LayerSync, NativeSyncBatchNorm, PrecisionPlugin
+from pytorch_lightning.plugins import DoublePrecisionPlugin, LayerSync, NativeSyncBatchNorm, PrecisionPlugin
 from pytorch_lightning.plugins.environments import (
     KubeflowEnvironment,
     LightningEnvironment,
     SLURMEnvironment,
     TorchElasticEnvironment,
 )
+from pytorch_lightning.plugins.io import TorchCheckpointIO
 from pytorch_lightning.strategies import (
     DataParallelStrategy,
     DDP2Strategy,
@@ -1019,3 +1020,19 @@ def __init__(self, **kwargs):
     assert strategy._layer_sync is None
     Trainer(strategy=strategy, sync_batchnorm=True)
     assert isinstance(strategy._layer_sync, NativeSyncBatchNorm)
+
+
+@pytest.mark.parametrize(
+    ["plugins", "expected"],
+    [
+        ([LightningEnvironment(), SLURMEnvironment()], "ClusterEnvironment"),
+        ([TorchCheckpointIO(), TorchCheckpointIO()], "CheckpointIO"),
+        (
+            [PrecisionPlugin(), DoublePrecisionPlugin(), LightningEnvironment(), SLURMEnvironment()],
+            "PrecisionPlugin, ClusterEnvironment",
+        ),
+    ],
+)
+def test_plugin_only_one_instance_for_one_type(plugins, expected):
+    with pytest.raises(MisconfigurationException, match=f"Received multiple values for {expected}"):
+        Trainer(plugins=plugins)
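
For reference, a minimal usage sketch of the new validation (not part of the diff; assumes a pytorch_lightning build that includes this change). Passing two plugins of the same base type now raises a MisconfigurationException naming the duplicated type, mirroring the parametrized test above:

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins.environments import LightningEnvironment, SLURMEnvironment
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    try:
        # Both plugins hit the same `elif isinstance(plugin, ClusterEnvironment)` branch,
        # so the Counter records ClusterEnvironment twice and validation fails.
        Trainer(plugins=[LightningEnvironment(), SLURMEnvironment()])
    except MisconfigurationException as err:
        print(err)
        # Received multiple values for ClusterEnvironment flags in `plugins`.
        # Expected one value for each type at most.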