Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CLI] Shorthand notation to instantiate callbacks [3/3] #8815

Merged
merged 88 commits into from
Sep 17, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
88 commits
Select commit Hold shift + click to select a range
d7f00be
add registries
tchaton Aug 9, 2021
a07d305
simplify LightningCLI with defaults
tchaton Aug 9, 2021
ce39c47
cleanup
tchaton Aug 9, 2021
3081475
update
tchaton Aug 9, 2021
51f82d5
updates
tchaton Aug 10, 2021
7197d6e
cleanup
tchaton Aug 10, 2021
9a6e81e
update on comments
tchaton Aug 10, 2021
41f5d78
update
tchaton Aug 10, 2021
06e4999
cleanup
tchaton Aug 10, 2021
e91ea47
update on comments
tchaton Aug 10, 2021
3f35ecd
Merge branch 'master' into lightning_cli_registries
tchaton Aug 10, 2021
705c0bd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 10, 2021
78a4398
add docs
tchaton Aug 10, 2021
e96dc28
doc updates
tchaton Aug 10, 2021
631aa72
update
tchaton Aug 10, 2021
2fc3c0a
update
tchaton Aug 10, 2021
43dd8b4
resolve comments
tchaton Aug 10, 2021
c6ae669
comment
tchaton Aug 10, 2021
5c21b1c
add comment
tchaton Aug 10, 2021
b370deb
typo
tchaton Aug 10, 2021
f8e7ca7
update on comments
tchaton Aug 11, 2021
e428d2f
resolve bug
tchaton Aug 11, 2021
3e97905
typo
tchaton Aug 11, 2021
3b1bdb6
update
tchaton Aug 11, 2021
0d1db29
resolve comments
tchaton Aug 11, 2021
3d35c82
add unittesting
tchaton Aug 11, 2021
4c0f960
resolve tests
tchaton Aug 11, 2021
d3a62ca
resolve comments
tchaton Aug 12, 2021
39781a1
update on comments
tchaton Aug 13, 2021
68c03de
doc updates
tchaton Aug 13, 2021
b01828b
update
tchaton Aug 13, 2021
d213c73
Merge branch 'master' into lightning_cli_registries
tchaton Aug 13, 2021
5935ec4
update on comments
tchaton Aug 17, 2021
b6616f0
Merge branch 'lightning_cli_registries' of https://github.com/PyTorch…
tchaton Aug 17, 2021
0d89423
Merge branch 'master' into lightning_cli_registries
carmocca Aug 19, 2021
37fd679
Fix mypy
carmocca Aug 19, 2021
f16db3d
Revert unrelated change which had broken mypy
carmocca Aug 19, 2021
572488c
Convert to staticmethod
carmocca Aug 19, 2021
2fc4608
Replace context managers for functional static transformations
carmocca Aug 19, 2021
9f383dc
Split tests
carmocca Aug 19, 2021
2a7dfa8
Refactor optimizer tests
carmocca Aug 19, 2021
423ab7b
Cleaning tests
carmocca Aug 19, 2021
7c2e39e
Delete broken test
carmocca Aug 19, 2021
048e159
Docs improvements
carmocca Aug 19, 2021
86fce55
Docs improvements
carmocca Aug 19, 2021
624b0d8
Restructure docs
carmocca Aug 19, 2021
2cc0dc5
Docs for callbacks
carmocca Aug 19, 2021
f9b49fe
Add reload test when add_optimizer_args is added by the user
carmocca Aug 19, 2021
afcc4ba
Add failing config test - needs to be fixed
carmocca Aug 19, 2021
9f41b88
Merge branch 'master' into lightning_cli_registries
carmocca Aug 28, 2021
0ed4ae8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 28, 2021
4dd0732
Use property
carmocca Aug 19, 2021
e0fae4f
Fixes after merge
carmocca Aug 28, 2021
4f053bb
Merge branch 'master' into lightning_cli_registries
carmocca Sep 15, 2021
a22fdb3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 15, 2021
160b3f6
Update jsonargparse version
carmocca Sep 15, 2021
f185c2d
Use properties in registry
carmocca Sep 15, 2021
803385c
Keep hacks together
carmocca Sep 15, 2021
8eb8b05
Add FIXMEs
carmocca Sep 15, 2021
9d84127
add_class_choices
carmocca Sep 15, 2021
33ff2f4
Merge branch 'master' into lightning_cli_registries
carmocca Sep 15, 2021
cf82e1a
Remove contains registry. Avoid nested_key clash for optimizers and l…
carmocca Sep 15, 2021
b1cd083
Remove sanitize argv
carmocca Sep 15, 2021
95d31a7
Better support for new callback format
carmocca Sep 16, 2021
231e0ed
Avoid evaluating
carmocca Sep 16, 2021
2af596f
Minor cleaning
carmocca Sep 16, 2021
6add619
Mark argv as private
carmocca Sep 16, 2021
525358a
Fix mypy
carmocca Sep 16, 2021
84b8120
Fix mypy
carmocca Sep 16, 2021
7e48c0e
Fix mypy
carmocca Sep 16, 2021
40ce3c7
Merge branch 'master' into lightning_cli_registries
carmocca Sep 16, 2021
3e77e8e
Support shorthand notation to instantiate optimizers and learning rat…
carmocca Sep 16, 2021
1512a80
Update CHANGELOG
carmocca Sep 16, 2021
c6b86b1
Fix install
carmocca Sep 16, 2021
6f1600c
Fix install
carmocca Sep 16, 2021
a3a791f
Use release
carmocca Sep 16, 2021
f67a90f
Merge branch 'feat/cli-shorthand-optimizers' into lightning_cli_regis…
carmocca Sep 16, 2021
fedae46
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 16, 2021
ee7a068
Introduce set_choices
carmocca Sep 16, 2021
6e67617
Undo change
carmocca Sep 16, 2021
e7f6d61
Replace add_class_choices with set_choices
carmocca Sep 16, 2021
8e87359
Replace add_class_choices with set_choices
carmocca Sep 16, 2021
c74426b
Merge
carmocca Sep 16, 2021
66cdb52
Docstrings
carmocca Sep 16, 2021
9217304
Merge branch 'master' into lightning_cli_registries
carmocca Sep 17, 2021
7b50401
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 17, 2021
1406be9
Fix mypy
carmocca Sep 17, 2021
a000446
Undo change
carmocca Sep 17, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 114 additions & 1 deletion pytorch_lightning/utilities/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import os
import sys
from argparse import Namespace
from contextlib import contextmanager
from types import MethodType
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union
from unittest import mock

from torch.optim import Optimizer

from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
from pytorch_lightning.utilities import _JSONARGPARSE_AVAILABLE, warnings
from pytorch_lightning.utilities.cli_registries import CALLBACK_REGISTRIES, OPTIMIZER_REGISTRIES, SCHEDULER_REGISTRIES
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
Expand Down Expand Up @@ -332,6 +337,16 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:

def link_optimizers_and_lr_schedulers(self) -> None:
"""Creates argument links for optimizers and learning rate schedulers that specified a ``link_to``."""
if any(
True for v in sys.argv for optim_name in OPTIMIZER_REGISTRIES.keys() if f"--optimizer={optim_name}" in v
tchaton marked this conversation as resolved.
Show resolved Hide resolved
):
optimizers = tuple(v for v in OPTIMIZER_REGISTRIES.values())
self.parser.add_optimizer_args(optimizers)

if any(True for v in sys.argv for sch_name in SCHEDULER_REGISTRIES.keys() if f"-lr_scheduler={sch_name}" in v):
tchaton marked this conversation as resolved.
Show resolved Hide resolved
lr_schdulers = tuple(v for v in SCHEDULER_REGISTRIES.values())
self.parser.add_lr_scheduler_args(lr_schdulers)

for key, (class_type, link_to) in self.parser.optimizers_and_lr_schedulers.items():
if link_to == "AUTOMATIC":
continue
Expand All @@ -341,9 +356,107 @@ def link_optimizers_and_lr_schedulers(self) -> None:
add_class_path = _add_class_path_generator(class_type)
self.parser.link_arguments(key, link_to, compute_fn=add_class_path)

@contextmanager
def prepare_optimizer(self):
"""
This context manager is used to simplify optimizer instantiation for Lightning users.
"""
optimizer_args = [v for v in sys.argv if v.startswith("--optimizer")]
should_replace = len(optimizer_args) > 0 and not any(v for v in optimizer_args if "class_path" in v)
if should_replace:
optimizer_arg = {}
init_args = {}
for v in optimizer_args:
if "optimizer." in v:
arg_path, value = v.split("=")
init_args[arg_path.split(".")[-1]] = value
else:
class_name = v.split("=")[-1]
optim_cls = OPTIMIZER_REGISTRIES[class_name]
optimizer_arg["class_path"] = optim_cls.__module__ + "." + class_name
optimizer_arg["init_args"] = init_args
argv = [v for v in sys.argv if not v.startswith("--optimizer")] + [
f"--optimizer={json.dumps(optimizer_arg)}"
]
with mock.patch("sys.argv", argv):
yield
else:
yield

@contextmanager
def prepare_callbacks(self):
"""
This context manager is used to simplify callbacks instantiation for Lightning users.
"""
all_callbacks_args = [v for v in sys.argv if v.startswith("--trainer.callbacks")]
callbacks_args = [v for v in sys.argv if v.startswith("--trainer.callbacks=")]
num_callbacks = len(callbacks_args)
should_replace = len(all_callbacks_args) > 0 and not any(v for v in all_callbacks_args if "class_path" in v)
if should_replace:
# FIXME: Add support for combining callbacks.
callbacks_argv = {}
init_args = {}
map_callback_args = {idx: [] for idx in range(num_callbacks)}
counter = -1
callback_out = []
for v in all_callbacks_args:
if "--trainer.callbacks=" in v:
counter += 1
map_callback_args[counter].append(v)
callback_out = []
for callback_idx in range(num_callbacks):
callback_args = map_callback_args[callback_idx]
callbacks_argv = {}
init_args = {}
for callback_arg in callback_args:
if "--trainer.callbacks=" in callback_arg:
class_name = callback_arg.split("=")[-1]
callback_cls = CALLBACK_REGISTRIES[class_name]
callbacks_argv["class_path"] = callback_cls.__module__ + "." + class_name
else:
arg_path, value = callback_arg.split("=")
init_args[arg_path.split(".")[-1]] = value
callbacks_argv["init_args"] = init_args
callback_out.append(callbacks_argv)
argv = [v for v in sys.argv if not v.startswith("--trainer.callbacks")] + [
f"--trainer.callbacks={json.dumps(callback_out)}"
]
with mock.patch("sys.argv", argv):
yield
else:
yield

@contextmanager
def prepare_schedulers(self):
"""
This context manager is used to simplify schedulers instantiation for Lightning users.
"""
lr_scheduler_args = [v for v in sys.argv if v.startswith("--lr_scheduler")]
should_replace = len(lr_scheduler_args) > 0 and not any(v for v in lr_scheduler_args if "class_path" in v)
if should_replace:
lr_scheduler_arg = {}
init_args = {}
for v in lr_scheduler_args:
if "lr_scheduler." in v:
arg_path, value = v.split("=")
init_args[arg_path.split(".")[-1]] = value
else:
class_name = v.split("=")[-1]
optim_cls = SCHEDULER_REGISTRIES[class_name]
lr_scheduler_arg["class_path"] = optim_cls.__module__ + "." + class_name
lr_scheduler_arg["init_args"] = init_args
argv = [v for v in sys.argv if not v.startswith("--lr_scheduler")] + [
f"--lr_scheduler={json.dumps(lr_scheduler_arg)}"
]
with mock.patch("sys.argv", argv):
yield
else:
yield

def parse_arguments(self, parser: LightningArgumentParser) -> None:
"""Parses command line arguments and stores it in ``self.config``."""
self.config = parser.parse_args()
with self.prepare_optimizer(), self.prepare_callbacks(), self.prepare_schedulers():
self.config = parser.parse_args()

def before_instantiate_classes(self) -> None:
"""Implement to run some code before instantiating the classes."""
Expand Down
79 changes: 79 additions & 0 deletions pytorch_lightning/utilities/cli_registries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright The PyTorch Lightning team.
tchaton marked this conversation as resolved.
Show resolved Hide resolved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from collections import UserDict
from typing import Callable, List, Optional, Type

import torch

import pytorch_lightning as pl
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class Registry(UserDict):
tchaton marked this conversation as resolved.
Show resolved Hide resolved
def __call__(
self,
cls: Optional[Type] = None,
key: Optional[str] = None,
override: bool = False,
) -> Callable:
"""
Registers a plugin mapped to a name and with required metadata.

Args:
key : the name that identifies a plugin, e.g. "deepspeed_stage_3"
value : plugin class
"""
tchaton marked this conversation as resolved.
Show resolved Hide resolved
if key is None:
key = cls.__name__
elif not isinstance(key, str):
raise TypeError(f"`key` must be a str, found {key}")

if key in self and not override:
raise MisconfigurationException(f"'{key}' is already present in the registry. HINT: Use `override=True`.")

def do_register(key, cls) -> Callable:
self[key] = cls
return cls

do_register(key, cls)

return do_register

def register_package(self, module, base_cls: Type) -> None:
for obj_name in dir(module):
obj_cls = getattr(module, obj_name)
if inspect.isclass(obj_cls) and issubclass(obj_cls, base_cls):
self(cls=obj_cls)
tchaton marked this conversation as resolved.
Show resolved Hide resolved

def remove(self, name: str) -> None:
"""Removes the registered plugin by name"""
self.pop(name)
tchaton marked this conversation as resolved.
Show resolved Hide resolved

def available_objects(self) -> List:
"""Returns a list of registered plugins"""
return list(self.keys())

def __str__(self) -> str:
return "Registered Plugins: {}".format(", ".join(self.keys()))


CALLBACK_REGISTRIES = Registry()
CALLBACK_REGISTRIES.register_package(pl.callbacks, pl.callbacks.Callback)

OPTIMIZER_REGISTRIES = Registry()
OPTIMIZER_REGISTRIES.register_package(torch.optim, torch.optim.Optimizer)

SCHEDULER_REGISTRIES = Registry()
SCHEDULER_REGISTRIES.register_package(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler)
57 changes: 57 additions & 0 deletions tests/utilities/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.cli import instantiate_class, LightningArgumentParser, LightningCLI, SaveConfigCallback
from pytorch_lightning.utilities.cli_registries import CALLBACK_REGISTRIES
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
from tests.helpers import BoringDataModule, BoringModel
from tests.helpers.runif import RunIf
Expand Down Expand Up @@ -687,3 +688,59 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict):
assert isinstance(cli.model.optim1, torch.optim.Adam)
assert isinstance(cli.model.optim2, torch.optim.SGD)
assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR)


@CALLBACK_REGISTRIES
class CustomCallback(Callback):
pass


def test_registries(tmpdir):

assert CALLBACK_REGISTRIES.available_objects() == [
"BackboneFinetuning",
"BaseFinetuning",
"BasePredictionWriter",
"Callback",
"EarlyStopping",
"GPUStatsMonitor",
"GradientAccumulationScheduler",
"LambdaCallback",
"LearningRateMonitor",
"ModelCheckpoint",
"ModelPruning",
"ProgressBar",
"ProgressBarBase",
"QuantizationAwareTraining",
"StochasticWeightAveraging",
"Timer",
"XLAStatsMonitor",
"CustomCallback",
]

class MyLightningCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
pass

class TestModel(BoringModel):
def __init__(self):
super().__init__()

cli_args = [
f"--trainer.default_root_dir={tmpdir}",
"--trainer.max_epochs=1",
"--optimizer=Adam",
"--optimizer.lr=0.0001",
"--trainer.callbacks=LearningRateMonitor",
"--trainer.callbacks.logging_interval=epoch",
"--trainer.callbacks.log_momentum=True",
"--trainer.callbacks=ModelCheckpoint",
"--trainer.callbacks.monitor=loss",
tchaton marked this conversation as resolved.
Show resolved Hide resolved
"--lr_scheduler=StepLR",
"--lr_scheduler.step_size=50",
]

with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = MyLightningCLI(TestModel)

assert isinstance(cli.trainer.optimizers[0], torch.optim.Adam)