Skip to content

Commit

Permalink
rebase and remove changelog
Browse files Browse the repository at this point in the history
  • Loading branch information
four4fish committed Mar 4, 2022
1 parent 69aaa1f commit b5a6d18
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,40 +305,40 @@ def _check_config_and_set_final_flags(
self._precision_flag = precision

if plugins:
plugins_flags_types_list = []
plugins_flags_types = Counter()
for plugin in plugins:
if isinstance(plugin, Strategy) or isinstance(plugin, str) and plugin in self._registered_strategies:
self._strategy_flag = plugin
rank_zero_deprecation(
f"Passing {plugin} `strategy` to the `plugins` flag in Trainer has been deprecated"
f" in v1.5 and will be removed in v1.7. Use `Trainer(strategy={plugin})` instead."
)
plugins_flags_types_list.append(Strategy.__name__)
plugins_flags_types[Strategy.__name__] += 1

elif isinstance(plugin, PrecisionPlugin):
self._precision_plugin_flag = plugin
plugins_flags_types_list.append(PrecisionPlugin.__name__)
plugins_flags_types[PrecisionPlugin.__name__] += 1
elif isinstance(plugin, CheckpointIO):
self.checkpoint_io = plugin
plugins_flags_types_list.append(CheckpointIO.__name__)
plugins_flags_types[CheckpointIO.__name__] += 1
elif isinstance(plugin, ClusterEnvironment):
self._cluster_environment_flag = plugin
plugins_flags_types_list.append("ClusterEnvironment")
plugins_flags_types[ClusterEnvironment.__name__] += 1
elif isinstance(plugin, LayerSync):
if sync_batchnorm and not isinstance(plugin, NativeSyncBatchNorm):
raise MisconfigurationException(
f"You set `Trainer(sync_batchnorm=True)` and provided a `{plugin.__class__.__name__}`"
" plugin, but this is not allowed. Choose one or the other."
)
self._layer_sync = plugin
plugins_flags_types_list.append(ClusterEnvironment.__name__)
plugins_flags_types[NativeSyncBatchNorm.__name__] += 1
else:
raise MisconfigurationException(
f"Found invalid type for plugin {plugin}. Expected PrecisionPlugin, "
"CheckpointIO plugin, ClusterEnviroment plugin or a Strategy."
)

duplicated_plugin_key = [k for k, v in Counter(plugins_flags_types_list).items() if v > 1]
duplicated_plugin_key = [k for k, v in plugins_flags_types.items() if v > 1]
if duplicated_plugin_key:
raise MisconfigurationException(
f"Received multiple values for {', '.join(duplicated_plugin_key)} flags in `plugins`."
Expand Down

0 comments on commit b5a6d18

Please sign in to comment.