Skip to content

Commit

Permalink
[CLI] Shorthand notation to instantiate callbacks [3/3] (#8815)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
  • Loading branch information
2 people authored and SeanNaren committed Sep 22, 2021
1 parent 16b2559 commit dc3441c
Show file tree
Hide file tree
Showing 5 changed files with 241 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/code-checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ jobs:
run: |
grep mypy requirements/test.txt | xargs -0 pip install
pip list
- run: mypy
- run: mypy --install-types --non-interactive
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Automatically register all optimizers and learning rate schedulers ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
* Allow registering custom optimizers and learning rate schedulers without subclassing the CLI ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
* Support shorthand notation to instantiate optimizers and learning rate schedulers ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
* Support passing lists of callbacks via command line ([#8815](https://github.com/PyTorchLightning/pytorch-lightning/pull/8815))


- Fault-tolerant training:
Expand Down
62 changes: 58 additions & 4 deletions docs/source/common/lightning_cli.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from unittest import mock
from typing import List
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, LightningDataModule, Trainer
from pytorch_lightning import LightningModule, LightningDataModule, Trainer, Callback


class NoFitTrainer(Trainer):
Expand Down Expand Up @@ -371,6 +371,59 @@ Similar to the callbacks, any arguments in :class:`~pytorch_lightning.trainer.tr
:class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes that have as type hint a class can be configured
the same way using :code:`class_path` and :code:`init_args`.
For callbacks in particular, Lightning simplifies the command line so that only
the :class:`~pytorch_lightning.callbacks.Callback` name is required.
The argument's order matters and the user needs to pass the arguments in the following way.
.. code-block:: bash
$ python ... \
--trainer.callbacks={CALLBACK_1_NAME} \
--trainer.callbacks.{CALLBACK_1_ARGS_1}=... \
--trainer.callbacks.{CALLBACK_1_ARGS_2}=... \
...
--trainer.callbacks={CALLBACK_N_NAME} \
--trainer.callbacks.{CALLBACK_N_ARGS_1}=... \
...
Here is an example:
.. code-block:: bash
$ python ... \
--trainer.callbacks=EarlyStopping \
--trainer.callbacks.patience=5 \
--trainer.callbacks=LearningRateMonitor \
--trainer.callbacks.logging_interval=epoch
Lightning provides a mechanism for you to add your own callbacks and benefit from the command line simplification
as described above:
.. code-block:: python
from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY
@CALLBACK_REGISTRY
class CustomCallback(Callback):
...
cli = LightningCLI(...)
.. code-block:: bash
$ python ... --trainer.callbacks=CustomCallback ...
This callback will be included in the generated config:
.. code-block:: yaml
trainer:
callbacks:
- class_path: your_class_path.CustomCallback
init_args:
...
Multiple models and/or datasets
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down Expand Up @@ -517,9 +570,10 @@ instantiating the trainer class can be found in :code:`self.config['fit']['train
Configurable callbacks
^^^^^^^^^^^^^^^^^^^^^^
As explained previously, any callback can be added by including it in the config via :code:`class_path` and
:code:`init_args` entries. However, there are other cases in which a callback should always be present and be
configurable. This can be implemented as follows:
As explained previously, any Lightning callback can be added by passing it through command line or
including it in the config via :code:`class_path` and :code:`init_args` entries.
However, there are other cases in which a callback should always be present and be configurable.
This can be implemented as follows:
.. testcode::
Expand Down
88 changes: 82 additions & 6 deletions pytorch_lightning/utilities/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
from unittest import mock

import torch
import yaml
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
from pytorch_lightning.utilities import _JSONARGPARSE_AVAILABLE, rank_zero_warn, warnings
from pytorch_lightning.utilities.cloud_io import get_filesystem
Expand Down Expand Up @@ -83,12 +85,15 @@ def __str__(self) -> str:
LR_SCHEDULER_REGISTRY = _Registry()
LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler)

CALLBACK_REGISTRY = _Registry()
CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.callbacks.Callback)


class LightningArgumentParser(ArgumentParser):
"""Extension of jsonargparse's ArgumentParser for pytorch-lightning."""

# use class attribute because `parse_args` is only called on the main parser
_choices: Dict[str, Tuple[Type, ...]] = {}
_choices: Dict[str, Tuple[Tuple[Type, ...], bool]] = {}

def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None:
"""Initialize argument parser that supports configuration file input.
Expand Down Expand Up @@ -202,23 +207,35 @@ def add_lr_scheduler_args(

def parse_args(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
argv = sys.argv
for k, classes in self._choices.items():
for k, v in self._choices.items():
if not any(arg.startswith(f"--{k}") for arg in argv):
# the key wasn't passed - maybe defined in a config, maybe it's optional
continue
argv = self._convert_argv_issue_84(classes, k, argv)
classes, is_list = v
# knowing whether the argument is a list type automatically would be too complex
if is_list:
argv = self._convert_argv_issue_85(classes, k, argv)
else:
argv = self._convert_argv_issue_84(classes, k, argv)
self._choices.clear()
with mock.patch("sys.argv", argv):
return super().parse_args(*args, **kwargs)

def set_choices(self, nested_key: str, classes: Tuple[Type, ...]) -> None:
self._choices[nested_key] = classes
def set_choices(self, nested_key: str, classes: Tuple[Type, ...], is_list: bool = False) -> None:
"""Adds support for shorthand notation for a particular nested key.
Args:
nested_key: The key whose choices will be set.
classes: A tuple of classes to choose from.
is_list: Whether the argument is a ``List[object]`` type.
"""
self._choices[nested_key] = (classes, is_list)

@staticmethod
def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]:
"""Placeholder for https://github.com/omni-us/jsonargparse/issues/84.
This should be removed once implemented.
Adds support for shorthand notation for ``object`` arguments.
"""
passed_args, clean_argv = {}, []
argv_key = f"--{nested_key}"
Expand Down Expand Up @@ -259,6 +276,64 @@ def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: Lis
raise ValueError(f"Could not generate a config for {repr(argv_class)}")
return clean_argv + [argv_key, config]

@staticmethod
def _convert_argv_issue_85(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]:
"""Placeholder for https://github.com/omni-us/jsonargparse/issues/85.
Adds support for shorthand notation for ``List[object]`` arguments.
"""
passed_args, clean_argv = [], []
passed_configs = {}
argv_key = f"--{nested_key}"
# get the argv args for this nested key
i = 0
while i < len(argv):
arg = argv[i]
if arg.startswith(argv_key):
if "=" in arg:
key, value = arg.split("=")
else:
key = arg
i += 1
value = argv[i]
if "class_path" in value:
# the user passed a config as a dict
passed_configs[key] = yaml.safe_load(value)
else:
passed_args.append((key, value))
else:
clean_argv.append(arg)
i += 1
# generate the associated config file
config = []
i, n = 0, len(passed_args)
while i < n - 1:
ki, vi = passed_args[i]
# convert class name to class path
for cls in classes:
if cls.__name__ == vi:
cls_type = cls
break
else:
raise ValueError(f"Could not generate a config for {repr(vi)}")
config.append(_global_add_class_path(cls_type))
# get any init args
j = i + 1 # in case the j-loop doesn't run
for j in range(i + 1, n):
kj, vj = passed_args[j]
if ki == kj:
break
if kj.startswith(ki):
init_arg_name = kj.split(".")[-1]
config[-1]["init_args"][init_arg_name] = vj
i = j
# update at the end to preserve the order
for k, v in passed_configs.items():
config.extend(v)
if not config:
return clean_argv
return clean_argv + [argv_key, str(config)]


class SaveConfigCallback(Callback):
"""Saves a LightningCLI config to the log_dir when training starts.
Expand Down Expand Up @@ -430,6 +505,7 @@ def add_default_arguments_to_parser(self, parser: LightningArgumentParser) -> No
def add_core_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
"""Adds arguments from the core classes to the parser."""
parser.add_lightning_class_args(self.trainer_class, "trainer")
parser.set_choices("trainer.callbacks", CALLBACK_REGISTRY.classes, is_list=True)
trainer_defaults = {"trainer." + k: v for k, v in self.trainer_defaults.items() if k != "callbacks"}
parser.set_defaults(trainer_defaults)
parser.add_lightning_class_args(self.model_class, "model", subclass_mode=self.subclass_mode_model)
Expand Down
100 changes: 99 additions & 1 deletion tests/utilities/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.cli import (
CALLBACK_REGISTRY,
instantiate_class,
LightningArgumentParser,
LightningCLI,
Expand Down Expand Up @@ -861,6 +862,11 @@ class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR):
pass


@CALLBACK_REGISTRY
class CustomCallback(Callback):
pass


def test_registries(tmpdir):
assert "SGD" in OPTIMIZER_REGISTRY.names
assert "RMSprop" in OPTIMIZER_REGISTRY.names
Expand All @@ -870,23 +876,41 @@ def test_registries(tmpdir):
assert "CosineAnnealingWarmRestarts" in LR_SCHEDULER_REGISTRY.names
assert "CustomCosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names

assert "EarlyStopping" in CALLBACK_REGISTRY.names
assert "CustomCallback" in CALLBACK_REGISTRY.names

with pytest.raises(MisconfigurationException, match="is already present in the registry"):
OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer)
OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer, override=True)


def test_registries_resolution():
@pytest.mark.parametrize("use_class_path_callbacks", [False, True])
def test_registries_resolution(use_class_path_callbacks):
"""This test validates registries are used when simplified command line are being used."""
cli_args = [
"--optimizer",
"Adam",
"--optimizer.lr",
"0.0001",
"--trainer.callbacks=LearningRateMonitor",
"--trainer.callbacks.logging_interval=epoch",
"--trainer.callbacks.log_momentum=True",
"--trainer.callbacks=ModelCheckpoint",
"--trainer.callbacks.monitor=loss",
"--lr_scheduler",
"StepLR",
"--lr_scheduler.step_size=50",
]

extras = []
if use_class_path_callbacks:
callbacks = [
{"class_path": "pytorch_lightning.callbacks.Callback"},
{"class_path": "pytorch_lightning.callbacks.Callback", "init_args": {}},
]
cli_args += [f"--trainer.callbacks={json.dumps(callbacks)}"]
extras = [Callback, Callback]

with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = LightningCLI(BoringModel, run=False)

Expand All @@ -895,6 +919,80 @@ def test_registries_resolution():
assert optimizers[0].param_groups[0]["lr"] == 0.0001
assert lr_scheduler[0].step_size == 50

callback_types = [type(c) for c in cli.trainer.callbacks]
expected = [LearningRateMonitor, SaveConfigCallback, ModelCheckpoint] + extras
assert all(t in callback_types for t in expected)


def test_argv_transformation_noop():
base = ["any.py", "--trainer.max_epochs=1"]
argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", base)
assert argv == base


def test_argv_transformation_single_callback():
base = ["any.py", "--trainer.max_epochs=1"]
input = base + ["--trainer.callbacks=ModelCheckpoint", "--trainer.callbacks.monitor=val_loss"]
callbacks = [
{
"class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
"init_args": {"monitor": "val_loss"},
}
]
expected = base + ["--trainer.callbacks", str(callbacks)]
argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", input)
assert argv == expected


def test_argv_transformation_multiple_callbacks():
base = ["any.py", "--trainer.max_epochs=1"]
input = base + [
"--trainer.callbacks=ModelCheckpoint",
"--trainer.callbacks.monitor=val_loss",
"--trainer.callbacks=ModelCheckpoint",
"--trainer.callbacks.monitor=val_acc",
]
callbacks = [
{
"class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
"init_args": {"monitor": "val_loss"},
},
{
"class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
"init_args": {"monitor": "val_acc"},
},
]
expected = base + ["--trainer.callbacks", str(callbacks)]
argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", input)
assert argv == expected


def test_argv_transformation_multiple_callbacks_with_config():
base = ["any.py", "--trainer.max_epochs=1"]
nested_key = "trainer.callbacks"
input = base + [
f"--{nested_key}=ModelCheckpoint",
f"--{nested_key}.monitor=val_loss",
f"--{nested_key}=ModelCheckpoint",
f"--{nested_key}.monitor=val_acc",
f"--{nested_key}=[{{'class_path': 'pytorch_lightning.callbacks.Callback'}}]",
]
callbacks = [
{
"class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
"init_args": {"monitor": "val_loss"},
},
{
"class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
"init_args": {"monitor": "val_acc"},
},
{"class_path": "pytorch_lightning.callbacks.Callback"},
]
expected = base + ["--trainer.callbacks", str(callbacks)]
nested_key = "trainer.callbacks"
argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, nested_key, input)
assert argv == expected


@pytest.mark.parametrize(
["args", "expected", "nested_key", "registry"],
Expand Down

0 comments on commit dc3441c

Please sign in to comment.