From 1bb5fccb710ebd45025a041a52cd8b9f34b03445 Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Fri, 17 Sep 2021 18:54:06 +0100
Subject: [PATCH] [CLI] Shorthand notation to instantiate callbacks [3/3] (#8815)

Co-authored-by: Carlos Mocholi
---
 .github/workflows/code-checks.yml    |   2 +-
 CHANGELOG.md                         |   1 +
 docs/source/common/lightning_cli.rst |  62 +++++++++++++++--
 pytorch_lightning/utilities/cli.py   |  88 +++++++++++++++++++++--
 tests/utilities/test_cli.py          | 100 ++++++++++++++++++++++++++-
 5 files changed, 241 insertions(+), 12 deletions(-)

diff --git a/.github/workflows/code-checks.yml b/.github/workflows/code-checks.yml
index d666f60597786..8bede9ea9ddda 100644
--- a/.github/workflows/code-checks.yml
+++ b/.github/workflows/code-checks.yml
@@ -19,4 +19,4 @@ jobs:
         run: |
           grep mypy requirements/test.txt | xargs -0 pip install
           pip list
-      - run: mypy
+      - run: mypy --install-types --non-interactive

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 660f059ca33a2..3d2d725a1c0ee 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -57,6 +57,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
   * Automatically register all optimizers and learning rate schedulers ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
   * Allow registering custom optimizers and learning rate schedulers without subclassing the CLI ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
   * Support shorthand notation to instantiate optimizers and learning rate schedulers ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
+  * Support passing lists of callbacks via command line ([#8815](https://github.com/PyTorchLightning/pytorch-lightning/pull/8815))

 - Fault-tolerant training:

diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst
index e664e36aa0a63..56451cbccf71f 100644
--- a/docs/source/common/lightning_cli.rst
+++ b/docs/source/common/lightning_cli.rst
@@ -5,7 +5,7 @@
     from unittest import mock
     from typing import List
     import pytorch_lightning as pl
-    from pytorch_lightning import LightningModule, LightningDataModule, Trainer
+    from pytorch_lightning import LightningModule, LightningDataModule, Trainer, Callback


     class NoFitTrainer(Trainer):
@@ -371,6 +371,59 @@ Similar to the callbacks, any arguments in :class:`~pytorch_lightning.trainer.tr
 :class:`~pytorch_lightning.core.datamodule.LightningDataModule` classes that have as type hint a class can be
 configured the same way using :code:`class_path` and :code:`init_args`.

+For callbacks in particular, Lightning simplifies the command line so that only
+the :class:`~pytorch_lightning.callbacks.Callback` name is required.
+The order of the arguments matters, and they need to be passed in the following way:
+
+.. code-block:: bash
+
+    $ python ... \
+        --trainer.callbacks={CALLBACK_1_NAME} \
+        --trainer.callbacks.{CALLBACK_1_ARGS_1}=... \
+        --trainer.callbacks.{CALLBACK_1_ARGS_2}=... \
+        ...
+        --trainer.callbacks={CALLBACK_N_NAME} \
+        --trainer.callbacks.{CALLBACK_N_ARGS_1}=... \
+        ...
+
+Here is an example:
+
+.. code-block:: bash
+
+    $ python ... \
+        --trainer.callbacks=EarlyStopping \
+        --trainer.callbacks.patience=5 \
+        --trainer.callbacks=LearningRateMonitor \
+        --trainer.callbacks.logging_interval=epoch
+
+Lightning provides a mechanism for you to add your own callbacks and benefit from the command line simplification
+as described above:
+
+.. code-block:: python
+
+    from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY
+
+
+    @CALLBACK_REGISTRY
+    class CustomCallback(Callback):
+        ...
+
+
+    cli = LightningCLI(...)
+
+.. code-block:: bash
+
+    $ python ... --trainer.callbacks=CustomCallback ...
+
+This callback will be included in the generated config:
+
+.. code-block:: yaml
+
+    trainer:
+      callbacks:
+        - class_path: your_class_path.CustomCallback
+          init_args:
+            ...
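Passing init args to a registered custom callback uses the same ordered shorthand as the built-in callbacks. A minimal sketch, where the ``patience`` argument is hypothetical and assumes ``CustomCallback`` defines it:

.. code-block:: bash

    $ python ... \
        --trainer.callbacks=CustomCallback \
        --trainer.callbacks.patience=5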

 Multiple models and/or datasets
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

@@ -517,9 +570,10 @@ instantiating the trainer class can be found in :code:`self.config['fit']['train
 Configurable callbacks
 ^^^^^^^^^^^^^^^^^^^^^^

-As explained previously, any callback can be added by including it in the config via :code:`class_path` and
-:code:`init_args` entries. However, there are other cases in which a callback should always be present and be
-configurable. This can be implemented as follows:
+As explained previously, any Lightning callback can be added by passing it through the command line or by
+including it in the config via :code:`class_path` and :code:`init_args` entries.
+However, there are other cases in which a callback should always be present and be configurable.
+This can be implemented as follows:

 .. testcode::

diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py
index b27a1a12caaad..d97ef9ccddebb 100644
--- a/pytorch_lightning/utilities/cli.py
+++ b/pytorch_lightning/utilities/cli.py
@@ -20,8 +20,10 @@
 from unittest import mock

 import torch
+import yaml
 from torch.optim import Optimizer

+import pytorch_lightning as pl
 from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
 from pytorch_lightning.utilities import _JSONARGPARSE_AVAILABLE, rank_zero_warn, warnings
 from pytorch_lightning.utilities.cloud_io import get_filesystem
@@ -83,12 +85,15 @@ def __str__(self) -> str:
 LR_SCHEDULER_REGISTRY = _Registry()
 LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler)

+CALLBACK_REGISTRY = _Registry()
+CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.callbacks.Callback)
+

 class LightningArgumentParser(ArgumentParser):
     """Extension of jsonargparse's ArgumentParser for pytorch-lightning."""

     # use class attribute because `parse_args` is only called on the main parser
-    _choices: Dict[str, Tuple[Type, ...]] = {}
+    _choices: Dict[str, Tuple[Tuple[Type, ...], bool]] = {}

     def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None:
         """Initialize argument parser that supports configuration file input.
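The registry added above mirrors ``OPTIMIZER_REGISTRY`` and ``LR_SCHEDULER_REGISTRY``. A minimal sketch of how it is populated and queried, using only names this diff introduces (the ``@CALLBACK_REGISTRY`` decorator plus the ``names`` and ``classes`` accessors exercised in the tests below):

.. code-block:: python

    from pytorch_lightning.callbacks import Callback
    from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY


    @CALLBACK_REGISTRY  # registers the class under its name, like the optimizer/scheduler registries
    class MyCallback(Callback):
        ...


    # built-in callbacks are registered at import time via register_classes(pl.callbacks, ...)
    assert "EarlyStopping" in CALLBACK_REGISTRY.names
    assert "MyCallback" in CALLBACK_REGISTRY.names
    assert MyCallback in CALLBACK_REGISTRY.classes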
@@ -202,23 +207,35 @@ def add_lr_scheduler_args(

     def parse_args(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
         argv = sys.argv
-        for k, classes in self._choices.items():
+        for k, v in self._choices.items():
             if not any(arg.startswith(f"--{k}") for arg in argv):
                 # the key wasn't passed - maybe defined in a config, maybe it's optional
                 continue
-            argv = self._convert_argv_issue_84(classes, k, argv)
+            classes, is_list = v
+            # knowing whether the argument is a list type automatically would be too complex
+            if is_list:
+                argv = self._convert_argv_issue_85(classes, k, argv)
+            else:
+                argv = self._convert_argv_issue_84(classes, k, argv)
         self._choices.clear()
         with mock.patch("sys.argv", argv):
             return super().parse_args(*args, **kwargs)

-    def set_choices(self, nested_key: str, classes: Tuple[Type, ...]) -> None:
-        self._choices[nested_key] = classes
+    def set_choices(self, nested_key: str, classes: Tuple[Type, ...], is_list: bool = False) -> None:
+        """Adds support for shorthand notation for a particular nested key.
+
+        Args:
+            nested_key: The key whose choices will be set.
+            classes: A tuple of classes to choose from.
+            is_list: Whether the argument is a ``List[object]`` type.
+        """
+        self._choices[nested_key] = (classes, is_list)

     @staticmethod
     def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]:
         """Placeholder for https://github.com/omni-us/jsonargparse/issues/84.

-        This should be removed once implemented.
+        Adds support for shorthand notation for ``object`` arguments.
         """
         passed_args, clean_argv = {}, []
         argv_key = f"--{nested_key}"
@@ -259,6 +276,64 @@ def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: Lis
             raise ValueError(f"Could not generate a config for {repr(argv_class)}")
         return clean_argv + [argv_key, config]

+    @staticmethod
+    def _convert_argv_issue_85(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]:
+        """Placeholder for https://github.com/omni-us/jsonargparse/issues/85.
+
+        Adds support for shorthand notation for ``List[object]`` arguments.
+        """
+        passed_args, clean_argv = [], []
+        passed_configs = {}
+        argv_key = f"--{nested_key}"
+        # get the argv args for this nested key
+        i = 0
+        while i < len(argv):
+            arg = argv[i]
+            if arg.startswith(argv_key):
+                if "=" in arg:
+                    key, value = arg.split("=")
+                else:
+                    key = arg
+                    i += 1
+                    value = argv[i]
+                if "class_path" in value:
+                    # the user passed a config as a dict
+                    passed_configs[key] = yaml.safe_load(value)
+                else:
+                    passed_args.append((key, value))
+            else:
+                clean_argv.append(arg)
+            i += 1
+        # generate the associated config file
+        config = []
+        i, n = 0, len(passed_args)
+        while i < n - 1:
+            ki, vi = passed_args[i]
+            # convert class name to class path
+            for cls in classes:
+                if cls.__name__ == vi:
+                    cls_type = cls
+                    break
+            else:
+                raise ValueError(f"Could not generate a config for {repr(vi)}")
+            config.append(_global_add_class_path(cls_type))
+            # get any init args
+            j = i + 1  # in case the j-loop doesn't run
+            for j in range(i + 1, n):
+                kj, vj = passed_args[j]
+                if ki == kj:
+                    break
+                if kj.startswith(ki):
+                    init_arg_name = kj.split(".")[-1]
+                    config[-1]["init_args"][init_arg_name] = vj
+            i = j
+        # update at the end to preserve the order
+        for k, v in passed_configs.items():
+            config.extend(v)
+        if not config:
+            return clean_argv
+        return clean_argv + [argv_key, str(config)]
+

 class SaveConfigCallback(Callback):
     """Saves a LightningCLI config to the log_dir when training starts.
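To make the rewrite above concrete: before jsonargparse sees ``sys.argv``, the shorthand callback flags are collapsed into a single ``--trainer.callbacks`` argument holding a string-serialized list of ``class_path``/``init_args`` dicts. A minimal sketch mirroring ``test_argv_transformation_single_callback`` below (``_convert_argv_issue_85`` is private, so calling it directly is for illustration only):

.. code-block:: python

    from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY, LightningArgumentParser

    argv = ["any.py", "--trainer.callbacks=ModelCheckpoint", "--trainer.callbacks.monitor=val_loss"]
    new_argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", argv)
    # the shorthand flags collapse into one argument that jsonargparse can parse
    assert new_argv == [
        "any.py",
        "--trainer.callbacks",
        str(
            [
                {
                    "class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
                    "init_args": {"monitor": "val_loss"},
                }
            ]
        ),
    ]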
@@ -430,6 +505,7 @@ def add_default_arguments_to_parser(self, parser: LightningArgumentParser) -> No
     def add_core_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
         """Adds arguments from the core classes to the parser."""
         parser.add_lightning_class_args(self.trainer_class, "trainer")
+        parser.set_choices("trainer.callbacks", CALLBACK_REGISTRY.classes, is_list=True)
         trainer_defaults = {"trainer." + k: v for k, v in self.trainer_defaults.items() if k != "callbacks"}
         parser.set_defaults(trainer_defaults)
         parser.add_lightning_class_args(self.model_class, "model", subclass_mode=self.subclass_mode_model)

diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py
index 1cd12a33b7217..bff5d7e9111e4 100644
--- a/tests/utilities/test_cli.py
+++ b/tests/utilities/test_cli.py
@@ -34,6 +34,7 @@
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TPU_AVAILABLE
 from pytorch_lightning.utilities.cli import (
+    CALLBACK_REGISTRY,
     instantiate_class,
     LightningArgumentParser,
     LightningCLI,
@@ -861,6 +862,11 @@ class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR):
     pass


+@CALLBACK_REGISTRY
+class CustomCallback(Callback):
+    pass
+
+
 def test_registries(tmpdir):
     assert "SGD" in OPTIMIZER_REGISTRY.names
     assert "RMSprop" in OPTIMIZER_REGISTRY.names
@@ -870,23 +876,41 @@ def test_registries(tmpdir):
     assert "CosineAnnealingWarmRestarts" in LR_SCHEDULER_REGISTRY.names
     assert "CustomCosineAnnealingLR" in LR_SCHEDULER_REGISTRY.names

+    assert "EarlyStopping" in CALLBACK_REGISTRY.names
+    assert "CustomCallback" in CALLBACK_REGISTRY.names
+
     with pytest.raises(MisconfigurationException, match="is already present in the registry"):
         OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer)
     OPTIMIZER_REGISTRY.register_classes(torch.optim, torch.optim.Optimizer, override=True)


-def test_registries_resolution():
+@pytest.mark.parametrize("use_class_path_callbacks", [False, True])
+def test_registries_resolution(use_class_path_callbacks):
     """This test validates that the registries are used when the simplified command line notation is used."""

     cli_args = [
         "--optimizer",
         "Adam",
         "--optimizer.lr",
         "0.0001",
+        "--trainer.callbacks=LearningRateMonitor",
+        "--trainer.callbacks.logging_interval=epoch",
+        "--trainer.callbacks.log_momentum=True",
+        "--trainer.callbacks=ModelCheckpoint",
+        "--trainer.callbacks.monitor=loss",
         "--lr_scheduler",
         "StepLR",
         "--lr_scheduler.step_size=50",
     ]

+    extras = []
+    if use_class_path_callbacks:
+        callbacks = [
+            {"class_path": "pytorch_lightning.callbacks.Callback"},
+            {"class_path": "pytorch_lightning.callbacks.Callback", "init_args": {}},
+        ]
+        cli_args += [f"--trainer.callbacks={json.dumps(callbacks)}"]
+        extras = [Callback, Callback]
+
     with mock.patch("sys.argv", ["any.py"] + cli_args):
         cli = LightningCLI(BoringModel, run=False)

@@ -895,6 +919,80 @@ def test_registries_resolution():
     assert optimizers[0].param_groups[0]["lr"] == 0.0001
     assert lr_scheduler[0].step_size == 50

+    callback_types = [type(c) for c in cli.trainer.callbacks]
+    expected = [LearningRateMonitor, SaveConfigCallback, ModelCheckpoint] + extras
+    assert all(t in callback_types for t in expected)
+
+
+def test_argv_transformation_noop():
+    base = ["any.py", "--trainer.max_epochs=1"]
+    argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", base)
+    assert argv == base
+
+
+def test_argv_transformation_single_callback():
+    base = ["any.py", "--trainer.max_epochs=1"]
+    input = base + ["--trainer.callbacks=ModelCheckpoint", "--trainer.callbacks.monitor=val_loss"]
["--trainer.callbacks=ModelCheckpoint", "--trainer.callbacks.monitor=val_loss"] + callbacks = [ + { + "class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint", + "init_args": {"monitor": "val_loss"}, + } + ] + expected = base + ["--trainer.callbacks", str(callbacks)] + argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", input) + assert argv == expected + + +def test_argv_transformation_multiple_callbacks(): + base = ["any.py", "--trainer.max_epochs=1"] + input = base + [ + "--trainer.callbacks=ModelCheckpoint", + "--trainer.callbacks.monitor=val_loss", + "--trainer.callbacks=ModelCheckpoint", + "--trainer.callbacks.monitor=val_acc", + ] + callbacks = [ + { + "class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint", + "init_args": {"monitor": "val_loss"}, + }, + { + "class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint", + "init_args": {"monitor": "val_acc"}, + }, + ] + expected = base + ["--trainer.callbacks", str(callbacks)] + argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", input) + assert argv == expected + + +def test_argv_transformation_multiple_callbacks_with_config(): + base = ["any.py", "--trainer.max_epochs=1"] + nested_key = "trainer.callbacks" + input = base + [ + f"--{nested_key}=ModelCheckpoint", + f"--{nested_key}.monitor=val_loss", + f"--{nested_key}=ModelCheckpoint", + f"--{nested_key}.monitor=val_acc", + f"--{nested_key}=[{{'class_path': 'pytorch_lightning.callbacks.Callback'}}]", + ] + callbacks = [ + { + "class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint", + "init_args": {"monitor": "val_loss"}, + }, + { + "class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint", + "init_args": {"monitor": "val_acc"}, + }, + {"class_path": "pytorch_lightning.callbacks.Callback"}, + ] + expected = base + ["--trainer.callbacks", str(callbacks)] + nested_key = "trainer.callbacks" + argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, nested_key, input) + assert argv == expected + @pytest.mark.parametrize( ["args", "expected", "nested_key", "registry"],