[CLI] Shorthand notation to instantiate callbacks [3/3] #8815

Merged: 88 commits, Sep 17, 2021
d7f00be
add registries
tchaton Aug 9, 2021
a07d305
simplify LightningCLI with defaults
tchaton Aug 9, 2021
ce39c47
cleanup
tchaton Aug 9, 2021
3081475
update
tchaton Aug 9, 2021
51f82d5
updates
tchaton Aug 10, 2021
7197d6e
cleanup
tchaton Aug 10, 2021
9a6e81e
update on comments
tchaton Aug 10, 2021
41f5d78
update
tchaton Aug 10, 2021
06e4999
cleanup
tchaton Aug 10, 2021
e91ea47
update on comments
tchaton Aug 10, 2021
3f35ecd
Merge branch 'master' into lightning_cli_registries
tchaton Aug 10, 2021
705c0bd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 10, 2021
78a4398
add docs
tchaton Aug 10, 2021
e96dc28
doc updates
tchaton Aug 10, 2021
631aa72
update
tchaton Aug 10, 2021
2fc3c0a
update
tchaton Aug 10, 2021
43dd8b4
resolve comments
tchaton Aug 10, 2021
c6ae669
comment
tchaton Aug 10, 2021
5c21b1c
add comment
tchaton Aug 10, 2021
b370deb
typo
tchaton Aug 10, 2021
f8e7ca7
update on comments
tchaton Aug 11, 2021
e428d2f
resolve bug
tchaton Aug 11, 2021
3e97905
typo
tchaton Aug 11, 2021
3b1bdb6
update
tchaton Aug 11, 2021
0d1db29
resolve comments
tchaton Aug 11, 2021
3d35c82
add unittesting
tchaton Aug 11, 2021
4c0f960
resolve tests
tchaton Aug 11, 2021
d3a62ca
resolve comments
tchaton Aug 12, 2021
39781a1
update on comments
tchaton Aug 13, 2021
68c03de
doc updates
tchaton Aug 13, 2021
b01828b
update
tchaton Aug 13, 2021
d213c73
Merge branch 'master' into lightning_cli_registries
tchaton Aug 13, 2021
5935ec4
update on comments
tchaton Aug 17, 2021
b6616f0
Merge branch 'lightning_cli_registries' of https://github.com/PyTorch…
tchaton Aug 17, 2021
0d89423
Merge branch 'master' into lightning_cli_registries
carmocca Aug 19, 2021
37fd679
Fix mypy
carmocca Aug 19, 2021
f16db3d
Revert unrelated change which had broken mypy
carmocca Aug 19, 2021
572488c
Convert to staticmethod
carmocca Aug 19, 2021
2fc4608
Replace context managers for functional static transformations
carmocca Aug 19, 2021
9f383dc
Split tests
carmocca Aug 19, 2021
2a7dfa8
Refactor optimizer tests
carmocca Aug 19, 2021
423ab7b
Cleaning tests
carmocca Aug 19, 2021
7c2e39e
Delete broken test
carmocca Aug 19, 2021
048e159
Docs improvements
carmocca Aug 19, 2021
86fce55
Docs improvements
carmocca Aug 19, 2021
624b0d8
Restructure docs
carmocca Aug 19, 2021
2cc0dc5
Docs for callbacks
carmocca Aug 19, 2021
f9b49fe
Add reload test when add_optimizer_args is added by the user
carmocca Aug 19, 2021
afcc4ba
Add failing config test - needs to be fixed
carmocca Aug 19, 2021
9f41b88
Merge branch 'master' into lightning_cli_registries
carmocca Aug 28, 2021
0ed4ae8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 28, 2021
4dd0732
Use property
carmocca Aug 19, 2021
e0fae4f
Fixes after merge
carmocca Aug 28, 2021
4f053bb
Merge branch 'master' into lightning_cli_registries
carmocca Sep 15, 2021
a22fdb3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 15, 2021
160b3f6
Update jsonargparse version
carmocca Sep 15, 2021
f185c2d
Use properties in registry
carmocca Sep 15, 2021
803385c
Keep hacks together
carmocca Sep 15, 2021
8eb8b05
Add FIXMEs
carmocca Sep 15, 2021
9d84127
add_class_choices
carmocca Sep 15, 2021
33ff2f4
Merge branch 'master' into lightning_cli_registries
carmocca Sep 15, 2021
cf82e1a
Remove contains registry. Avoid nested_key clash for optimizers and l…
carmocca Sep 15, 2021
b1cd083
Remove sanitize argv
carmocca Sep 15, 2021
95d31a7
Better support for new callback format
carmocca Sep 16, 2021
231e0ed
Avoid evaluating
carmocca Sep 16, 2021
2af596f
Minor cleaning
carmocca Sep 16, 2021
6add619
Mark argv as private
carmocca Sep 16, 2021
525358a
Fix mypy
carmocca Sep 16, 2021
84b8120
Fix mypy
carmocca Sep 16, 2021
7e48c0e
Fix mypy
carmocca Sep 16, 2021
40ce3c7
Merge branch 'master' into lightning_cli_registries
carmocca Sep 16, 2021
3e77e8e
Support shorthand notation to instantiate optimizers and learning rat…
carmocca Sep 16, 2021
1512a80
Update CHANGELOG
carmocca Sep 16, 2021
c6b86b1
Fix install
carmocca Sep 16, 2021
6f1600c
Fix install
carmocca Sep 16, 2021
a3a791f
Use release
carmocca Sep 16, 2021
f67a90f
Merge branch 'feat/cli-shorthand-optimizers' into lightning_cli_regis…
carmocca Sep 16, 2021
fedae46
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 16, 2021
ee7a068
Introduce set_choices
carmocca Sep 16, 2021
6e67617
Undo change
carmocca Sep 16, 2021
e7f6d61
Replace add_class_choices with set_choices
carmocca Sep 16, 2021
8e87359
Replace add_class_choices with set_choices
carmocca Sep 16, 2021
c74426b
Merge
carmocca Sep 16, 2021
66cdb52
Docstrings
carmocca Sep 16, 2021
9217304
Merge branch 'master' into lightning_cli_registries
carmocca Sep 17, 2021
7b50401
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 17, 2021
1406be9
Fix mypy
carmocca Sep 17, 2021
a000446
Undo change
carmocca Sep 17, 2021
2 changes: 1 addition & 1 deletion .github/workflows/code-checks.yml
@@ -19,4 +19,4 @@ jobs:
run: |
grep mypy requirements/test.txt | xargs -0 pip install
pip list
- run: mypy
- run: mypy --install-types --non-interactive
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -57,6 +57,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Automatically register all optimizers and learning rate schedulers ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
* Allow registering custom optimizers and learning rate schedulers without subclassing the CLI ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
* Support shorthand notation to instantiate optimizers and learning rate schedulers ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
* Support passing lists of callbacks via command line ([#8815](https://github.com/PyTorchLightning/pytorch-lightning/pull/8815))


- Fault-tolerant training:
62 changes: 58 additions & 4 deletions docs/source/common/lightning_cli.rst
@@ -5,7 +5,7 @@
from unittest import mock
from typing import List
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, LightningDataModule, Trainer
from pytorch_lightning import LightningModule, LightningDataModule, Trainer, Callback


class NoFitTrainer(Trainer):
@@ -371,6 +371,59 @@ Similar to the callbacks, any arguments in :class:`~pytorch_lightning.trainer.tr
:class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes that have as type hint a class can be configured
the same way using :code:`class_path` and :code:`init_args`.

For callbacks in particular, Lightning simplifies the command line so that only
the :class:`~pytorch_lightning.callbacks.Callback` name is required.
The order of the arguments matters: pass each callback name first, followed by that callback's arguments, in the following way.

.. code-block:: bash

$ python ... \
--trainer.callbacks={CALLBACK_1_NAME} \
--trainer.callbacks.{CALLBACK_1_ARGS_1}=... \
--trainer.callbacks.{CALLBACK_1_ARGS_2}=... \
...
--trainer.callbacks={CALLBACK_N_NAME} \
--trainer.callbacks.{CALLBACK_N_ARGS_1}=... \
...

Here is an example:

.. code-block:: bash

$ python ... \
--trainer.callbacks=EarlyStopping \
--trainer.callbacks.patience=5 \
--trainer.callbacks=LearningRateMonitor \
--trainer.callbacks.logging_interval=epoch
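
For comparison, the shorthand above should correspond to a list of `class_path`/`init_args` entries passed as a single argument, the form that `test_registries_resolution` in this PR builds with `json.dumps`. The sketch below only constructs that argument string. The import paths and the integer type of `patience` are assumptions made for illustration; the argv transformation in this PR derives the path from the registered class and keeps values as strings.

```python
import json

# Assumed import paths; the CLI resolves the real ones from its callback registry.
callbacks = [
    {
        "class_path": "pytorch_lightning.callbacks.EarlyStopping",
        "init_args": {"patience": 5},
    },
    {
        "class_path": "pytorch_lightning.callbacks.LearningRateMonitor",
        "init_args": {"logging_interval": "epoch"},
    },
]

# One long-form argument carrying the whole list of callbacks.
long_form_arg = f"--trainer.callbacks={json.dumps(callbacks)}"
print(long_form_arg)
```

The list order mirrors the order of the shorthand flags, which is why the order of the flags matters.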

Lightning provides a mechanism for you to add your own callbacks and benefit from the command line simplification
as described above:

.. code-block:: python

from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY


@CALLBACK_REGISTRY
class CustomCallback(Callback):
...


cli = LightningCLI(...)

.. code-block:: bash

$ python ... --trainer.callbacks=CustomCallback ...

This callback will be included in the generated config:

.. code-block:: yaml

trainer:
callbacks:
- class_path: your_class_path.CustomCallback
init_args:
...
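
The `@CALLBACK_REGISTRY` usage above works because the registry object is callable and returns the class it receives. A minimal sketch of that decorator pattern follows. `MiniRegistry` and the placeholder `Callback` base class are hypothetical names for this example and are not Lightning's `_Registry` implementation; only the `names` and `classes` attributes and the "already present" error mirror what the diff and tests in this PR show.

```python
from typing import List, Tuple, Type


class MiniRegistry(dict):
    """Hypothetical name-to-class registry that can be used as a class decorator."""

    def __call__(self, cls: Type, override: bool = False) -> Type:
        name = cls.__name__
        if name in self and not override:
            raise ValueError(f"'{name}' is already present in the registry")
        self[name] = cls
        # Returning the class unchanged is what makes `@REGISTRY` usable as a decorator.
        return cls

    @property
    def names(self) -> List[str]:
        return list(self.keys())

    @property
    def classes(self) -> Tuple[Type, ...]:
        return tuple(self.values())


class Callback:
    """Placeholder base class so the sketch runs without Lightning installed."""


REGISTRY = MiniRegistry()


@REGISTRY
class CustomCallback(Callback):
    pass


print(REGISTRY.names)
```

Keying the registry by class name is what allows a bare name such as `CustomCallback` on the command line to be mapped back to a class.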

Multiple models and/or datasets
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -517,9 +570,10 @@ instantiating the trainer class can be found in :code:`self.config['fit']['train
Configurable callbacks
^^^^^^^^^^^^^^^^^^^^^^

As explained previously, any callback can be added by including it in the config via :code:`class_path` and
:code:`init_args` entries. However, there are other cases in which a callback should always be present and be
configurable. This can be implemented as follows:
As explained previously, any Lightning callback can be added by passing it through the command line or
including it in the config via :code:`class_path` and :code:`init_args` entries.
However, there are other cases in which a callback should always be present and be configurable.
This can be implemented as follows:

.. testcode::

88 changes: 82 additions & 6 deletions pytorch_lightning/utilities/cli.py
@@ -20,8 +20,10 @@
from unittest import mock

import torch
import yaml
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
from pytorch_lightning.utilities import _JSONARGPARSE_AVAILABLE, rank_zero_warn, warnings
from pytorch_lightning.utilities.cloud_io import get_filesystem
@@ -83,12 +85,15 @@ def __str__(self) -> str:
LR_SCHEDULER_REGISTRY = _Registry()
LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler)

CALLBACK_REGISTRY = _Registry()
CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.callbacks.Callback)


class LightningArgumentParser(ArgumentParser):
"""Extension of jsonargparse's ArgumentParser for pytorch-lightning."""

# use class attribute because `parse_args` is only called on the main parser
_choices: Dict[str, Tuple[Type, ...]] = {}
_choices: Dict[str, Tuple[Tuple[Type, ...], bool]] = {}

def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None:
"""Initialize argument parser that supports configuration file input.
@@ -202,23 +207,35 @@ def add_lr_scheduler_args(

def parse_args(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
argv = sys.argv
for k, classes in self._choices.items():
for k, v in self._choices.items():
if not any(arg.startswith(f"--{k}") for arg in argv):
# the key wasn't passed - maybe defined in a config, maybe it's optional
continue
argv = self._convert_argv_issue_84(classes, k, argv)
classes, is_list = v
# knowing whether the argument is a list type automatically would be too complex
if is_list:
argv = self._convert_argv_issue_85(classes, k, argv)
else:
argv = self._convert_argv_issue_84(classes, k, argv)
self._choices.clear()
with mock.patch("sys.argv", argv):
return super().parse_args(*args, **kwargs)
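
The `parse_args` override above rewrites the argument list and then patches `sys.argv` only while the parent parser runs. The same pattern can be sketched with the standard library's `argparse`. The `rewrite` helper and the `--fast` alias are hypothetical and exist only to give the sketch something to transform; they are not part of Lightning or jsonargparse.

```python
import argparse
import sys
from unittest import mock


def rewrite(argv):
    """Hypothetical pre-processing step: expand ``--fast`` into ``--epochs=1``."""
    return ["--epochs=1" if arg == "--fast" else arg for arg in argv]


class RewritingParser(argparse.ArgumentParser):
    def parse_args(self, *args, **kwargs):
        argv = rewrite(sys.argv)
        # sys.argv is replaced only inside the context manager and restored afterwards
        with mock.patch("sys.argv", argv):
            return super().parse_args(*args, **kwargs)


parser = RewritingParser()
parser.add_argument("--epochs", type=int, default=10)

# Simulate a command line, as the tests in this PR do.
with mock.patch("sys.argv", ["any.py", "--fast"]):
    namespace = parser.parse_args()

print(namespace.epochs)
```

Patching instead of assigning to `sys.argv` keeps the rewrite local to the call, so other code that reads `sys.argv` later still sees the original arguments.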

def set_choices(self, nested_key: str, classes: Tuple[Type, ...]) -> None:
self._choices[nested_key] = classes
def set_choices(self, nested_key: str, classes: Tuple[Type, ...], is_list: bool = False) -> None:
"""Adds support for shorthand notation for a particular nested key.

Args:
nested_key: The key whose choices will be set.
classes: A tuple of classes to choose from.
is_list: Whether the argument is a ``List[object]`` type.
"""
self._choices[nested_key] = (classes, is_list)

@staticmethod
def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]:
"""Placeholder for https://github.com/omni-us/jsonargparse/issues/84.

This should be removed once implemented.
Adds support for shorthand notation for ``object`` arguments.
"""
passed_args, clean_argv = {}, []
argv_key = f"--{nested_key}"
@@ -259,6 +276,64 @@ def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: Lis
raise ValueError(f"Could not generate a config for {repr(argv_class)}")
return clean_argv + [argv_key, config]

@staticmethod
def _convert_argv_issue_85(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]:
"""Placeholder for https://github.com/omni-us/jsonargparse/issues/85.

Adds support for shorthand notation for ``List[object]`` arguments.
"""
passed_args, clean_argv = [], []
passed_configs = {}
argv_key = f"--{nested_key}"
# get the argv args for this nested key
i = 0
while i < len(argv):
arg = argv[i]
if arg.startswith(argv_key):
if "=" in arg:
key, value = arg.split("=")
else:
key = arg
i += 1
value = argv[i]
if "class_path" in value:
# the user passed a config as a dict
passed_configs[key] = yaml.safe_load(value)
else:
passed_args.append((key, value))
else:
clean_argv.append(arg)
i += 1
# generate the associated config file
config = []
i, n = 0, len(passed_args)
while i < n - 1:
ki, vi = passed_args[i]
# convert class name to class path
for cls in classes:
if cls.__name__ == vi:
cls_type = cls
break
else:
raise ValueError(f"Could not generate a config for {repr(vi)}")
config.append(_global_add_class_path(cls_type))
# get any init args
j = i + 1 # in case the j-loop doesn't run
for j in range(i + 1, n):
kj, vj = passed_args[j]
if ki == kj:
break
if kj.startswith(ki):
init_arg_name = kj.split(".")[-1]
config[-1]["init_args"][init_arg_name] = vj
i = j
# update at the end to preserve the order
for k, v in passed_configs.items():
config.extend(v)
if not config:
return clean_argv
return clean_argv + [argv_key, str(config)]
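
The grouping idea behind `_convert_argv_issue_85` can be illustrated in isolation: each `--{nested_key}=Name` starts a new entry, and the `--{nested_key}.arg=value` flags that follow attach to the most recent entry. The sketch below is a simplified stand-in, not the function from this diff. `group_shorthand` is a hypothetical name, it handles only the `key=value` spelling, it keeps the class name rather than resolving a `class_path`, and it ignores pre-built config dicts.

```python
from typing import Any, Dict, List, Tuple


def group_shorthand(nested_key: str, argv: List[str]) -> Tuple[List[str], List[Dict[str, Any]]]:
    """Separate ``--{nested_key}=Name`` groups from the rest of ``argv``."""
    flag = f"--{nested_key}"
    rest: List[str] = []
    entries: List[Dict[str, Any]] = []
    for arg in argv:
        key, sep, value = arg.partition("=")
        if sep and key == flag:
            # the bare key starts a new entry
            entries.append({"name": value, "init_args": {}})
        elif sep and key.startswith(flag + "."):
            if not entries:
                raise ValueError(f"{key} was passed before any {flag}=<name>")
            # everything after the prefix is treated as the init argument name
            entries[-1]["init_args"][key[len(flag) + 1 :]] = value
        else:
            rest.append(arg)
    return rest, entries


argv = [
    "any.py",
    "--trainer.max_epochs=1",
    "--trainer.callbacks=ModelCheckpoint",
    "--trainer.callbacks.monitor=val_loss",
    "--trainer.callbacks=ModelCheckpoint",
    "--trainer.callbacks.monitor=val_acc",
]
rest, entries = group_shorthand("trainer.callbacks", argv)
print(rest)
print(entries)
```

Because values are never converted, every init argument stays a string here, as it does in the diff; type conversion is left to the parser that consumes the generated config.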


class SaveConfigCallback(Callback):
"""Saves a LightningCLI config to the log_dir when training starts.
@@ -430,6 +505,7 @@ def add_default_arguments_to_parser(self, parser: LightningArgumentParser) -> No
def add_core_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
"""Adds arguments from the core classes to the parser."""
parser.add_lightning_class_args(self.trainer_class, "trainer")
parser.set_choices("trainer.callbacks", CALLBACK_REGISTRY.classes, is_list=True)
trainer_defaults = {"trainer." + k: v for k, v in self.trainer_defaults.items() if k != "callbacks"}
parser.set_defaults(trainer_defaults)
parser.add_lightning_class_args(self.model_class, "model", subclass_mode=self.subclass_mode_model)
100 changes: 99 additions & 1 deletion tests/utilities/test_cli.py
@@ -34,6 +34,7 @@
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.cli import (
CALLBACK_REGISTRY,
instantiate_class,
LightningArgumentParser,
LightningCLI,
@@ -861,6 +862,11 @@ class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR):
pass


@CALLBACK_REGISTRY
class CustomCallback(Callback):
pass


def test_registries(tmpdir):
assert "SGD" in OPTIMIZER_REGISTRY.names
assert "RMSprop" in OPTIMIZER_REGISTRY.names
@@ -870,23 +876,41 @@ def test_registries(tmpdir):
assert "CosineAnnealingWarmRestarts" in LR_SCHEDULER_REGISTRY.names
assert "CustomCosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names

assert "EarlyStopping" in CALLBACK_REGISTRY.names
assert "CustomCallback" in CALLBACK_REGISTRY.names

with pytest.raises(MisconfigurationException, match="is already present in the registry"):
OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer)
OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer, override=True)


def test_registries_resolution():
@pytest.mark.parametrize("use_class_path_callbacks", [False, True])
def test_registries_resolution(use_class_path_callbacks):
"""This test validates that registries are used when the simplified command line is used."""
cli_args = [
"--optimizer",
"Adam",
"--optimizer.lr",
"0.0001",
"--trainer.callbacks=LearningRateMonitor",
"--trainer.callbacks.logging_interval=epoch",
"--trainer.callbacks.log_momentum=True",
"--trainer.callbacks=ModelCheckpoint",
"--trainer.callbacks.monitor=loss",
"--lr_scheduler",
"StepLR",
"--lr_scheduler.step_size=50",
]

extras = []
if use_class_path_callbacks:
callbacks = [
{"class_path": "pytorch_lightning.callbacks.Callback"},
{"class_path": "pytorch_lightning.callbacks.Callback", "init_args": {}},
]
cli_args += [f"--trainer.callbacks={json.dumps(callbacks)}"]
extras = [Callback, Callback]

with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = LightningCLI(BoringModel, run=False)

@@ -895,6 +919,80 @@ def test_registries_resolution():
assert optimizers[0].param_groups[0]["lr"] == 0.0001
assert lr_scheduler[0].step_size == 50

callback_types = [type(c) for c in cli.trainer.callbacks]
expected = [LearningRateMonitor, SaveConfigCallback, ModelCheckpoint] + extras
assert all(t in callback_types for t in expected)


def test_argv_transformation_noop():
base = ["any.py", "--trainer.max_epochs=1"]
argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", base)
assert argv == base


def test_argv_transformation_single_callback():
base = ["any.py", "--trainer.max_epochs=1"]
input = base + ["--trainer.callbacks=ModelCheckpoint", "--trainer.callbacks.monitor=val_loss"]
callbacks = [
{
"class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
"init_args": {"monitor": "val_loss"},
}
]
expected = base + ["--trainer.callbacks", str(callbacks)]
argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", input)
assert argv == expected


def test_argv_transformation_multiple_callbacks():
base = ["any.py", "--trainer.max_epochs=1"]
input = base + [
"--trainer.callbacks=ModelCheckpoint",
"--trainer.callbacks.monitor=val_loss",
"--trainer.callbacks=ModelCheckpoint",
"--trainer.callbacks.monitor=val_acc",
]
callbacks = [
{
"class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
"init_args": {"monitor": "val_loss"},
},
{
"class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
"init_args": {"monitor": "val_acc"},
},
]
expected = base + ["--trainer.callbacks", str(callbacks)]
argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", input)
assert argv == expected


def test_argv_transformation_multiple_callbacks_with_config():
base = ["any.py", "--trainer.max_epochs=1"]
nested_key = "trainer.callbacks"
input = base + [
f"--{nested_key}=ModelCheckpoint",
f"--{nested_key}.monitor=val_loss",
f"--{nested_key}=ModelCheckpoint",
f"--{nested_key}.monitor=val_acc",
f"--{nested_key}=[{{'class_path': 'pytorch_lightning.callbacks.Callback'}}]",
]
callbacks = [
{
"class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
"init_args": {"monitor": "val_loss"},
},
{
"class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
"init_args": {"monitor": "val_acc"},
},
{"class_path": "pytorch_lightning.callbacks.Callback"},
]
expected = base + ["--trainer.callbacks", str(callbacks)]
nested_key = "trainer.callbacks"
argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, nested_key, input)
assert argv == expected
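
The transformation tests above compare against `str(callbacks)`, the single-quoted Python repr of the list, whereas `test_registries_resolution` passes `json.dumps(callbacks)`. The two spellings describe the same data but are not interchangeable for every parser. The sketch below shows the difference using only the standard library; it makes no claim about how jsonargparse handles either form (the diff itself loads such values with `yaml.safe_load`).

```python
import ast
import json

callbacks = [{"class_path": "pytorch_lightning.callbacks.Callback"}]

as_repr = str(callbacks)         # "[{'class_path': 'pytorch_lightning.callbacks.Callback'}]"
as_json = json.dumps(callbacks)  # '[{"class_path": "pytorch_lightning.callbacks.Callback"}]'

# Both spellings describe the same list ...
assert ast.literal_eval(as_repr) == json.loads(as_json) == callbacks

# ... but the single-quoted repr is not valid JSON.
try:
    json.loads(as_repr)
    repr_is_json = True
except json.JSONDecodeError:
    repr_is_json = False

print(repr_is_json)
```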


@pytest.mark.parametrize(
["args", "expected", "nested_key", "registry"],