Commit 6d7c6d6

Update Accelerator Connector for Registry (#7214)
kaushikb11 authored May 3, 2021
1 parent b7a4448 commit 6d7c6d6
Showing 3 changed files with 28 additions and 14 deletions.
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/plugins_registry.py
@@ -82,6 +82,7 @@ def register(

        def do_register(plugin: Callable) -> Callable:
            data["plugin"] = plugin
+           data["distributed_backend"] = plugin.distributed_backend
            self[name] = data
            return plugin

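With this change, a registry entry records the plugin's distributed_backend attribute alongside "plugin", "description" and "init_params". A minimal sketch of registration after the change follows; the plugin name, backend string and init parameters are invented for illustration, and the decorator-style register() call follows the usage pattern in the updated test rather than documenting the full signature.

from pytorch_lightning.plugins.plugins_registry import TrainingTypePluginsRegistry


@TrainingTypePluginsRegistry.register(
    "my_plugin", description="example plugin", param1="abc", param2=123
)
class MyPlugin:
    # New with this commit: do_register() copies this attribute into the
    # registry entry as data["distributed_backend"].
    distributed_backend = "my_backend"

    def __init__(self, param1, param2):
        self.param1 = param1
        self.param2 = param2


entry = TrainingTypePluginsRegistry["my_plugin"]
assert entry["distributed_backend"] == "my_backend"
assert entry["init_params"] == {"param1": "abc", "param2": 123}
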
38 changes: 24 additions & 14 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -112,6 +112,16 @@ def __init__(
        self._training_type_plugin: Optional[TrainingTypePlugin] = None
        self._cluster_environment: Optional[ClusterEnvironment] = None

+       plugins = plugins if plugins is not None else []
+
+       if isinstance(plugins, str):
+           plugins = [plugins]
+
+       if not isinstance(plugins, Sequence):
+           plugins = [plugins]
+
+       self.plugins = plugins
+
        # for gpus allow int, string and gpu list
        if auto_select_gpus and isinstance(gpus, int):
            self.gpus = pick_multiple_gpus(gpus)
@@ -121,7 +131,7 @@ def __init__(
        self.set_distributed_mode()
        self.configure_slurm_ddp()

-       self.handle_given_plugins(plugins)
+       self.handle_given_plugins()

        self.accelerator = self.select_accelerator()

@@ -148,22 +158,13 @@ def __init__(

        self.replace_sampler_ddp = replace_sampler_ddp

-   def handle_given_plugins(
-       self, plugins: Optional[Union[ClusterEnvironment, TrainingTypePlugin, PrecisionPlugin, Sequence]]
-   ):
-       plugins = plugins if plugins is not None else []
-
-       if isinstance(plugins, str):
-           plugins = [plugins]
-
-       if not isinstance(plugins, Sequence):
-           plugins = [plugins]
+   def handle_given_plugins(self) -> None:

        training_type = None
        precision = None
        cluster_environment = None

-       for plug in plugins:
+       for plug in self.plugins:
            if isinstance(plug, str) and plug in TrainingTypePluginsRegistry:
                if training_type is None:
                    training_type = TrainingTypePluginsRegistry.get(plug)
@@ -173,7 +174,7 @@ def handle_given_plugins(
                        ' Found more than 1 training type plugin:'
                        f' {TrainingTypePluginsRegistry[plug]["plugin"]} registered to {plug}'
                    )
-           elif isinstance(plug, str):
+           if isinstance(plug, str):
                # Reset the distributed type as the user has overridden training type
                # via the plugins argument
                self._distrib_type = None
@@ -310,6 +311,10 @@ def parallel_devices(self) -> List[Union[torch.device, int]]:
    def root_gpu(self) -> Optional[int]:
        return self.accelerator.root_device.index if not isinstance(self.accelerator, TPUAccelerator) else None

+   @property
+   def is_training_type_in_plugins(self) -> bool:
+       return any(isinstance(plug, str) and plug in TrainingTypePluginsRegistry for plug in self.plugins)
+
    @property
    def is_using_torchelastic(self) -> bool:
        """
@@ -492,7 +497,12 @@ def select_cluster_environment(self) -> ClusterEnvironment:

    def set_distributed_mode(self, distributed_backend: Optional[str] = None):

-       if distributed_backend is not None:
+       if distributed_backend is None and self.is_training_type_in_plugins:
+           return
+
+       if distributed_backend is not None and distributed_backend in TrainingTypePluginsRegistry:
+           self.distributed_backend = TrainingTypePluginsRegistry[distributed_backend]["distributed_backend"]
+       elif distributed_backend is not None:
            self.distributed_backend = distributed_backend

        if isinstance(self.distributed_backend, Accelerator):
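
Taken together, the connector changes let a registered training type alias passed through plugins drive set_distributed_mode() via the registry's distributed_backend field, while a bare set_distributed_mode() call is skipped when such an alias is present. Below is a hedged sketch of the intended user-facing behaviour; the "deepspeed_stage_2" alias and the gpus/precision values are assumptions for illustration and are not part of this diff.

from pytorch_lightning import Trainer

# Assumed to be one of the DeepSpeed aliases in TrainingTypePluginsRegistry;
# any registered training type alias would exercise the same code path.
trainer = Trainer(gpus=2, precision=16, plugins="deepspeed_stage_2")
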
3 changes: 3 additions & 0 deletions tests/plugins/test_plugins_registry.py
@@ -23,6 +23,8 @@ def test_training_type_plugins_registry_with_new_plugin():

    class TestPlugin:

+       distributed_backend = "test_plugin"
+
        def __init__(self, param1, param2):
            self.param1 = param1
            self.param2 = param2
@@ -37,6 +39,7 @@ def __init__(self, param1, param2):
    assert plugin_name in TrainingTypePluginsRegistry
    assert TrainingTypePluginsRegistry[plugin_name]["description"] == plugin_description
    assert TrainingTypePluginsRegistry[plugin_name]["init_params"] == {"param1": "abc", "param2": 123}
+   assert TrainingTypePluginsRegistry[plugin_name]["distributed_backend"] == "test_plugin"
    assert isinstance(TrainingTypePluginsRegistry.get(plugin_name), TestPlugin)

    TrainingTypePluginsRegistry.remove(plugin_name)
