Skip to content

Commit

Permalink
Only allow one value for each plugin type in plugins flag (#12083)
Browse files Browse the repository at this point in the history
  • Loading branch information
four4fish authored Mar 11, 2022
1 parent c90174c commit 4d74f37
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 5 deletions.
21 changes: 17 additions & 4 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

import logging
import os
from typing import List, Optional, Union
from collections import Counter
from typing import Dict, List, Optional, Union

import torch

Expand Down Expand Up @@ -304,34 +305,46 @@ def _check_config_and_set_final_flags(
self._precision_flag = precision

if plugins:
plugins_flags_types: Dict[str, int] = Counter()
for plugin in plugins:
if isinstance(plugin, Strategy) or isinstance(plugin, str) and plugin in self._registered_strategies:
self._strategy_flag = plugin
rank_zero_deprecation(
f"Passing {plugin} `strategy` to the `plugins` flag in Trainer has been deprecated"
f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={plugin})` instead."
)
plugins_flags_types[Strategy.__name__] += 1

elif isinstance(plugin, PrecisionPlugin):
self._precision_plugin_flag = plugin
elif isinstance(plugin, str) and plugin in self._precision_types:
self._precision_flag = plugin
plugins_flags_types[PrecisionPlugin.__name__] += 1
elif isinstance(plugin, CheckpointIO):
self.checkpoint_io = plugin
plugins_flags_types[CheckpointIO.__name__] += 1
elif isinstance(plugin, ClusterEnvironment):
self._cluster_environment_flag = plugin
plugins_flags_types[ClusterEnvironment.__name__] += 1
elif isinstance(plugin, LayerSync):
if sync_batchnorm and not isinstance(plugin, NativeSyncBatchNorm):
raise MisconfigurationException(
f"You set `Trainer(sync_batchnorm=True)` and provided a `{plugin.__class__.__name__}`"
" plugin, but this is not allowed. Choose one or the other."
)
self._layer_sync = plugin
plugins_flags_types[NativeSyncBatchNorm.__name__] += 1
else:
raise MisconfigurationException(
f"Found invalid type for plugin {plugin}. Expected a precision plugin or training strategy."
f"Found invalid type for plugin {plugin}. Expected one of: PrecisionPlugin, "
"CheckpointIO, ClusterEnviroment, LayerSync, or Strategy."
)

duplicated_plugin_key = [k for k, v in plugins_flags_types.items() if v > 1]
if duplicated_plugin_key:
raise MisconfigurationException(
f"Received multiple values for {', '.join(duplicated_plugin_key)} flags in `plugins`."
" Expected one value for each type at most."
)

# handle the case when the user passes in a strategy instance which has an accelerator, precision,
# checkpoint io or cluster env set up
# TODO: @awaelchli improve the error messages below
Expand Down
19 changes: 18 additions & 1 deletion tests/accelerators/test_accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.cpu import CPUAccelerator
from pytorch_lightning.accelerators.gpu import GPUAccelerator
from pytorch_lightning.plugins import LayerSync, NativeSyncBatchNorm, PrecisionPlugin
from pytorch_lightning.plugins import DoublePrecisionPlugin, LayerSync, NativeSyncBatchNorm, PrecisionPlugin
from pytorch_lightning.plugins.environments import (
KubeflowEnvironment,
LightningEnvironment,
SLURMEnvironment,
TorchElasticEnvironment,
)
from pytorch_lightning.plugins.io import TorchCheckpointIO
from pytorch_lightning.strategies import (
DataParallelStrategy,
DDP2Strategy,
Expand Down Expand Up @@ -1019,3 +1020,19 @@ def __init__(self, **kwargs):
assert strategy._layer_sync is None
Trainer(strategy=strategy, sync_batchnorm=True)
assert isinstance(strategy._layer_sync, NativeSyncBatchNorm)


@pytest.mark.parametrize(
    "plugins, expected",
    [
        ([LightningEnvironment(), SLURMEnvironment()], "ClusterEnvironment"),
        ([TorchCheckpointIO(), TorchCheckpointIO()], "CheckpointIO"),
        (
            [PrecisionPlugin(), DoublePrecisionPlugin(), LightningEnvironment(), SLURMEnvironment()],
            "PrecisionPlugin, ClusterEnvironment",
        ),
    ],
)
def test_plugin_only_one_instance_for_one_type(plugins, expected):
    """Passing more than one plugin of the same type to `Trainer(plugins=...)` must raise."""
    error_pattern = f"Received multiple values for {expected}"
    with pytest.raises(MisconfigurationException, match=error_pattern):
        Trainer(plugins=plugins)

0 comments on commit 4d74f37

Please sign in to comment.