Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LightningCLI natively support callback list append #13129

Merged
merged 9 commits into from
Jun 2, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed `pytorch_lightning.core.lightning` to `pytorch_lightning.core.module` ([#12740](https://github.com/PyTorchLightning/pytorch-lightning/pull/12740))


- `LightningCLI` changed to use jsonargparse native support for list append ([#13129](https://github.com/PyTorchLightning/pytorch-lightning/pull/13129))


-

### Deprecated
Expand Down
14 changes: 7 additions & 7 deletions docs/source/cli/lightning_cli_advanced_3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,11 @@ The argument's order matters and the user needs to pass the arguments in the fol
.. code-block:: bash

$ python ... \
--trainer.callbacks={CALLBACK_1_NAME} \
--trainer.callbacks+={CALLBACK_1_NAME} \
mauvilsa marked this conversation as resolved.
Show resolved Hide resolved
--trainer.callbacks.{CALLBACK_1_ARGS_1}=... \
--trainer.callbacks.{CALLBACK_1_ARGS_2}=... \
...
--trainer.callbacks={CALLBACK_N_NAME} \
--trainer.callbacks+={CALLBACK_N_NAME} \
--trainer.callbacks.{CALLBACK_N_ARGS_1}=... \
...

Expand All @@ -132,9 +132,9 @@ Here is an example:
.. code-block:: bash

$ python ... \
--trainer.callbacks=EarlyStopping \
--trainer.callbacks+=EarlyStopping \
--trainer.callbacks.patience=5 \
--trainer.callbacks=LearningRateMonitor \
--trainer.callbacks+=LearningRateMonitor \
--trainer.callbacks.logging_interval=epoch

Lightning provides a mechanism for you to add your own callbacks and benefit from the command line simplification
Expand All @@ -154,12 +154,12 @@ as described above:

.. code-block:: bash

$ python ... --trainer.callbacks=CustomCallback ...
$ python ... --trainer.callbacks+=CustomCallback ...

.. note::

This shorthand notation is only supported in the shell and not inside a configuration file. The configuration file
generated by calling the previous command with ``--print_config`` will have the ``class_path`` notation.
This shorthand notation is also supported inside a configuration file. The configuration file
generated by calling the previous command with ``--print_config`` will have the full ``class_path`` notation.

.. code-block:: yaml

Expand Down
74 changes: 1 addition & 73 deletions pytorch_lightning/utilities/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,11 @@

import inspect
import os
import sys
from functools import partial, update_wrapper
from types import MethodType, ModuleType
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Type, Union
from unittest import mock

import torch
import yaml
from torch.optim import Optimizer

import pytorch_lightning as pl
Expand All @@ -35,13 +32,11 @@
from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.seed import _select_seed_randomly

_JSONARGPARSE_SIGNATURES_AVAILABLE = _RequirementAvailable("jsonargparse[signatures]>=4.7.1")
_JSONARGPARSE_SIGNATURES_AVAILABLE = _RequirementAvailable("jsonargparse[signatures]>=4.8.0.dev1")

if _JSONARGPARSE_SIGNATURES_AVAILABLE:
import docstring_parser
from jsonargparse import ActionConfigFile, ArgumentParser, class_from_function, Namespace, set_config_read_mode
from jsonargparse.typehints import get_all_subclass_paths
from jsonargparse.util import import_object

set_config_read_mode(fsspec_enabled=True)
else:
Expand Down Expand Up @@ -255,73 +250,6 @@ def add_lr_scheduler_args(
self.add_class_arguments(lr_scheduler_class, nested_key, sub_configs=True, **kwargs)
self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to)

def parse_args(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
argv = sys.argv
nested_key = "trainer.callbacks"
if any(arg.startswith(f"--{nested_key}") for arg in argv):
classes = tuple(import_object(x) for x in get_all_subclass_paths(Callback))
argv = self._convert_argv_issue_85(classes, nested_key, argv)
with mock.patch("sys.argv", argv):
return super().parse_args(*args, **kwargs)

@staticmethod
def _convert_argv_issue_85(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]:
"""Placeholder for https://github.com/omni-us/jsonargparse/issues/85.

Adds support for shorthand notation for ``List[object]`` arguments.
"""
passed_args, clean_argv = [], []
passed_configs = {}
argv_key = f"--{nested_key}"
# get the argv args for this nested key
i = 0
while i < len(argv):
arg = argv[i]
if arg.startswith(argv_key):
if "=" in arg:
key, value = arg.split("=")
else:
key = arg
i += 1
value = argv[i]
if "class_path" in value:
# the user passed a config as a dict
passed_configs[key] = yaml.safe_load(value)
else:
passed_args.append((key, value))
else:
clean_argv.append(arg)
i += 1
# generate the associated config file
config = []
i, n = 0, len(passed_args)
while i < n - 1:
ki, vi = passed_args[i]
# convert class name to class path
for cls in classes:
if cls.__name__ == vi:
cls_type = cls
break
else:
raise ValueError(f"Could not generate a config for {repr(vi)}")
config.append(_global_add_class_path(cls_type))
# get any init args
j = i + 1 # in case the j-loop doesn't run
for j in range(i + 1, n):
kj, vj = passed_args[j]
if ki == kj:
break
if kj.startswith(ki):
init_arg_name = kj.split(".")[-1]
config[-1]["init_args"][init_arg_name] = vj
i = j
# update at the end to preserve the order
for k, v in passed_configs.items():
config.extend(v)
if not config:
return clean_argv
return clean_argv + [argv_key, str(config)]


class SaveConfigCallback(Callback):
"""Saves a LightningCLI config to the log_dir when training starts.
Expand Down
2 changes: 1 addition & 1 deletion requirements/extra.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ matplotlib>3.1, <=3.5.1
torchtext>=0.9.*, <=0.12.0
omegaconf>=2.0.5, <=2.1.*
hydra-core>=1.0.5, <=1.1.*
jsonargparse[signatures]>=4.7.1, <=4.7.1
jsonargparse[signatures]>=4.8.0.dev1
gcsfs>=2021.5.0, <=2022.2.0
rich>=10.2.2,!=10.15.*, <=12.0.0
82 changes: 5 additions & 77 deletions tests/utilities/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,19 +1046,20 @@ def test_lightning_cli_datamodule_short_arguments():


@pytest.mark.parametrize("use_class_path_callbacks", [False, True])
def test_registries_resolution(use_class_path_callbacks):
def test_callbacks_append(use_class_path_callbacks):

"""This test validates registries are used when simplified command line are being used."""
cli_args = [
"--optimizer",
"Adam",
"--optimizer.lr",
"0.0001",
"--trainer.callbacks=LearningRateMonitor",
"--trainer.callbacks+=LearningRateMonitor",
"--trainer.callbacks.logging_interval=epoch",
"--trainer.callbacks.log_momentum=True",
"--model=BoringModel",
"--trainer.callbacks=ModelCheckpoint",
"--trainer.callbacks+",
"ModelCheckpoint",
"--trainer.callbacks.monitor=loss",
"--lr_scheduler",
"StepLR",
Expand All @@ -1071,7 +1072,7 @@ def test_registries_resolution(use_class_path_callbacks):
{"class_path": "pytorch_lightning.callbacks.Callback"},
{"class_path": "pytorch_lightning.callbacks.Callback", "init_args": {}},
]
cli_args += [f"--trainer.callbacks={json.dumps(callbacks)}"]
cli_args += [f"--trainer.callbacks+={json.dumps(callbacks)}"]
extras = [Callback, Callback]

with mock.patch("sys.argv", ["any.py"] + cli_args), mock_subclasses(LightningModule, BoringModel):
Expand All @@ -1088,79 +1089,6 @@ def test_registries_resolution(use_class_path_callbacks):
assert all(t in callback_types for t in expected)


def test_argv_transformation_noop():
base = ["any.py", "--trainer.max_epochs=1"]
argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", base)
assert argv == base


def test_argv_transformation_single_callback():
base = ["any.py", "--trainer.max_epochs=1"]
input = base + ["--trainer.callbacks=ModelCheckpoint", "--trainer.callbacks.monitor=val_loss"]
callbacks = [
{
"class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
"init_args": {"monitor": "val_loss"},
}
]
expected = base + ["--trainer.callbacks", str(callbacks)]
_populate_registries(False)
argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", input)
assert argv == expected


def test_argv_transformation_multiple_callbacks():
base = ["any.py", "--trainer.max_epochs=1"]
input = base + [
"--trainer.callbacks=ModelCheckpoint",
"--trainer.callbacks.monitor=val_loss",
"--trainer.callbacks=ModelCheckpoint",
"--trainer.callbacks.monitor=val_acc",
]
callbacks = [
{
"class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
"init_args": {"monitor": "val_loss"},
},
{
"class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
"init_args": {"monitor": "val_acc"},
},
]
expected = base + ["--trainer.callbacks", str(callbacks)]
_populate_registries(False)
argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, "trainer.callbacks", input)
assert argv == expected


def test_argv_transformation_multiple_callbacks_with_config():
base = ["any.py", "--trainer.max_epochs=1"]
nested_key = "trainer.callbacks"
input = base + [
f"--{nested_key}=ModelCheckpoint",
f"--{nested_key}.monitor=val_loss",
f"--{nested_key}=ModelCheckpoint",
f"--{nested_key}.monitor=val_acc",
f"--{nested_key}=[{{'class_path': 'pytorch_lightning.callbacks.Callback'}}]",
]
callbacks = [
{
"class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
"init_args": {"monitor": "val_loss"},
},
{
"class_path": "pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint",
"init_args": {"monitor": "val_acc"},
},
{"class_path": "pytorch_lightning.callbacks.Callback"},
]
expected = base + ["--trainer.callbacks", str(callbacks)]
nested_key = "trainer.callbacks"
_populate_registries(False)
argv = LightningArgumentParser._convert_argv_issue_85(CALLBACK_REGISTRY.classes, nested_key, input)
assert argv == expected


def test_optimizers_and_lr_schedulers_reload(tmpdir):
base = ["any.py", "--trainer.max_epochs=1"]
input = base + [
Expand Down