Support CLI shorthand natively #12614

Merged
10 commits, merged May 3, 2022
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -40,7 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Marked `swa_lrs` argument in `StochasticWeightAveraging` callback as required ([#12556](https://github.com/PyTorchLightning/pytorch-lightning/pull/12556))


-
- `LightningCLI`'s shorthand notation changed to use jsonargparse native feature ([#12614](https://github.com/PyTorchLightning/pytorch-lightning/pull/12614))


-
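For context, the entry above means a bare class name plus dotted arguments now stands in for the full `class_path`/`init_args` form, with jsonargparse doing the resolution. A minimal sketch of the syntax, mirroring the flags exercised in the updated tests below (`MyModel` is a toy stand-in, not from this PR; whether it runs as-is depends on the Lightning/jsonargparse versions):

```python
# Minimal sketch of the CLI shorthand; MyModel is a toy stand-in model.
from unittest import mock

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.cli import LightningCLI


class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def train_dataloader(self):
        return torch.utils.data.DataLoader(torch.randn(8, 4), batch_size=4)


argv = [
    "any.py", "fit",
    "--optimizer=Adam",              # instead of --optimizer.class_path=torch.optim.Adam
    "--optimizer.lr=0.01",           # instead of --optimizer.init_args.lr=0.01
    "--lr_scheduler=ExponentialLR",
    "--lr_scheduler.gamma=0.2",
]
with mock.patch("sys.argv", argv):
    LightningCLI(MyModel, trainer_defaults={"fast_dev_run": 1})
```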
78 changes: 3 additions & 75 deletions pytorch_lightning/utilities/cli.py
@@ -221,7 +221,6 @@ def add_optimizer_args(
kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}}
if isinstance(optimizer_class, tuple):
self.add_subclass_arguments(optimizer_class, nested_key, **kwargs)
self.set_choices(nested_key, optimizer_class)
else:
self.add_class_arguments(optimizer_class, nested_key, sub_configs=True, **kwargs)
self._optimizers[nested_key] = (optimizer_class, link_to)
@@ -246,7 +245,6 @@ def add_lr_scheduler_args(
kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}}
if isinstance(lr_scheduler_class, tuple):
self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs)
self.set_choices(nested_key, lr_scheduler_class)
else:
self.add_class_arguments(lr_scheduler_class, nested_key, sub_configs=True, **kwargs)
self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to)
@@ -261,8 +259,6 @@ def parse_args(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
# knowing whether the argument is a list type automatically would be too complex
if is_list:
argv = self._convert_argv_issue_85(classes, k, argv)
else:
argv = self._convert_argv_issue_84(classes, k, argv)
self._choices.clear()
with mock.patch("sys.argv", argv):
return super().parse_args(*args, **kwargs)
@@ -277,69 +273,6 @@ def set_choices(self, nested_key: str, classes: Tuple[Type, ...], is_list: bool
"""
self._choices[nested_key] = (classes, is_list)

@staticmethod
def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]:
"""Placeholder for https://github.com/omni-us/jsonargparse/issues/84.

Adds support for shorthand notation for ``object`` arguments.
"""
passed_args, clean_argv = {}, []
argv_key = f"--{nested_key}"
# get the argv args for this nested key
i = 0
while i < len(argv):
arg = argv[i]
if arg.startswith(argv_key):
if "=" in arg:
key, value = arg.split("=")
else:
key = arg
i += 1
value = argv[i]
passed_args[key] = value
else:
clean_argv.append(arg)
i += 1

# the user requested a help message
help_key = argv_key + ".help"
if help_key in passed_args:
argv_class = passed_args[help_key]
if "." in argv_class:
# user passed the class path directly
class_path = argv_class
else:
# convert shorthand format to the classpath
for cls in classes:
if cls.__name__ == argv_class:
class_path = _class_path_from_class(cls)
break
else:
raise ValueError(f"Could not generate get the class_path for {repr(argv_class)}")
return clean_argv + [help_key, class_path]

# generate the associated config file
argv_class = passed_args.pop(argv_key, "")
if not argv_class:
# the user passed a config as a str
class_path = passed_args[f"{argv_key}.class_path"]
init_args_key = f"{argv_key}.init_args"
init_args = {k[len(init_args_key) + 1 :]: v for k, v in passed_args.items() if k.startswith(init_args_key)}
config = str({"class_path": class_path, "init_args": init_args})
elif argv_class.startswith("{") or argv_class in ("None", "True", "False"):
# the user passed a config as a dict
config = argv_class
else:
# the user passed the shorthand format
init_args = {k[len(argv_key) + 1 :]: v for k, v in passed_args.items()} # +1 to account for the period
for cls in classes:
if cls.__name__ == argv_class:
config = str(_global_add_class_path(cls, init_args))
break
else:
raise ValueError(f"Could not generate a config for {repr(argv_class)}")
return clean_argv + [argv_key, config]

@staticmethod
def _convert_argv_issue_85(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]:
"""Placeholder for https://github.com/omni-us/jsonargparse/issues/85.
@@ -602,23 +535,18 @@ def add_core_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
"""Adds arguments from the core classes to the parser."""
parser.add_lightning_class_args(self.trainer_class, "trainer")
parser.set_choices("trainer.callbacks", CALLBACK_REGISTRY.classes, is_list=True)
parser.set_choices("trainer.logger", LOGGER_REGISTRY.classes)
trainer_defaults = {"trainer." + k: v for k, v in self.trainer_defaults.items() if k != "callbacks"}
parser.set_defaults(trainer_defaults)

parser.add_lightning_class_args(self._model_class, "model", subclass_mode=self.subclass_mode_model)
if self.model_class is None and len(MODEL_REGISTRY):
# did not pass a model and there are models registered
parser.set_choices("model", MODEL_REGISTRY.classes)

if self.datamodule_class is not None:
parser.add_lightning_class_args(self._datamodule_class, "data", subclass_mode=self.subclass_mode_data)
elif len(DATAMODULE_REGISTRY):
else:
# this should not be required because the user might want to use the `LightningModule` dataloaders
parser.add_lightning_class_args(
self._datamodule_class, "data", subclass_mode=self.subclass_mode_data, required=False
)
parser.set_choices("data", DATAMODULE_REGISTRY.classes)

def _add_arguments(self, parser: LightningArgumentParser) -> None:
# default + core + custom arguments
@@ -627,9 +555,9 @@ def _add_arguments(self, parser: LightningArgumentParser) -> None:
self.add_arguments_to_parser(parser)
# add default optimizer args if necessary
if not parser._optimizers: # already added by the user in `add_arguments_to_parser`
parser.add_optimizer_args(OPTIMIZER_REGISTRY.classes)
parser.add_optimizer_args((Optimizer,))
if not parser._lr_schedulers: # already added by the user in `add_arguments_to_parser`
parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.classes)
parser.add_lr_scheduler_args(LRSchedulerTypeTuple)
self.link_optimizers_and_lr_schedulers(parser)

def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
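Worth noting about the change above: replacing `OPTIMIZER_REGISTRY.classes` and `LR_SCHEDULER_REGISTRY.classes` with the generic `(Optimizer,)` and `LRSchedulerTypeTuple` bases is what delegates shorthand resolution to jsonargparse, which matches a bare class name against subclasses of the declared base. A rough standalone sketch of that jsonargparse behaviour, assuming a recent jsonargparse version (resolution requires the subclass to be imported or importable):

```python
# Standalone sketch of the jsonargparse feature the new code relies on.
from jsonargparse import ArgumentParser
from torch.optim import Optimizer

parser = ArgumentParser()
# Declaring the generic base is enough; no per-class registry is needed.
# The kwargs mirror those used by add_optimizer_args in the diff above.
parser.add_subclass_arguments((Optimizer,), "optimizer", fail_untyped=False, skip={"params"})

# "Adam" is matched against subclasses of Optimizer and expanded to a full config.
cfg = parser.parse_args(["--optimizer=Adam", "--optimizer.lr=0.01"])
print(cfg.optimizer.class_path)    # e.g. "torch.optim.adam.Adam" (module path may vary)
print(cfg.optimizer.init_args.lr)  # 0.01
```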
96 changes: 30 additions & 66 deletions tests/utilities/test_cli.py
@@ -18,7 +18,7 @@
import pickle
import sys
from argparse import Namespace
from contextlib import redirect_stdout
from contextlib import contextmanager, ExitStack, redirect_stdout
from io import StringIO
from typing import List, Optional, Union
from unittest import mock
@@ -46,6 +46,7 @@
LightningCLI,
LOGGER_REGISTRY,
LR_SCHEDULER_REGISTRY,
LRSchedulerTypeTuple,
MODEL_REGISTRY,
OPTIMIZER_REGISTRY,
SaveConfigCallback,
@@ -61,6 +62,17 @@
torchvision_version = version.parse(__import__("torchvision").__version__)


@contextmanager
def mock_subclasses(baseclass, *subclasses):
"""Mocks baseclass so that it only has the given child subclasses."""
with ExitStack() as stack:
mgr = mock.patch.object(baseclass, "__subclasses__", return_value=[*subclasses])
stack.enter_context(mgr)
for mgr in [mock.patch.object(s, "__subclasses__", return_value=[]) for s in subclasses]:
stack.enter_context(mgr)
yield None
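
A quick illustration of what this helper does, with made-up toy classes (not part of the test suite). It matters because jsonargparse's shorthand resolution walks `__subclasses__()` recursively, so the tests patch that tree to a known set:

```python
# Toy illustration of mock_subclasses; Base/ChildA/ChildB are made up.
class Base: ...
class ChildA(Base): ...
class ChildB(Base): ...

with mock_subclasses(Base, ChildA):
    assert Base.__subclasses__() == [ChildA]  # ChildB is hidden
    assert ChildA.__subclasses__() == []      # leaves report no further children
# Outside the context manager the real subclass tree is restored.
```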


@mock.patch("argparse.ArgumentParser.parse_args")
def test_default_args(mock_argparse):
"""Tests default argument parser for Trainer."""
@@ -725,18 +737,18 @@ def add_arguments_to_parser(self, parser):
assert cli.trainer.lr_scheduler_configs[0].scheduler.step_size == 50


@pytest.mark.parametrize("use_registries", [False, True])
def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(use_registries, tmpdir):
@pytest.mark.parametrize("use_generic_base_class", [False, True])
def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(use_generic_base_class, tmpdir):
class MyLightningCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.add_optimizer_args(
OPTIMIZER_REGISTRY.classes if use_registries else torch.optim.Adam,
(torch.optim.Optimizer,) if use_generic_base_class else torch.optim.Adam,
nested_key="optim1",
link_to="model.optim1",
)
parser.add_optimizer_args((torch.optim.ASGD, torch.optim.SGD), nested_key="optim2", link_to="model.optim2")
parser.add_lr_scheduler_args(
LR_SCHEDULER_REGISTRY.classes if use_registries else torch.optim.lr_scheduler.ExponentialLR,
LRSchedulerTypeTuple if use_generic_base_class else torch.optim.lr_scheduler.ExponentialLR,
link_to="model.scheduler",
)

@@ -748,7 +760,7 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict):
self.scheduler = instantiate_class(self.optim1, scheduler)

cli_args = ["fit", f"--trainer.default_root_dir={tmpdir}", "--trainer.max_epochs=1"]
if use_registries:
if use_generic_base_class:
cli_args += [
"--optim1",
"Adam",
Expand All @@ -759,7 +771,7 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict):
"--lr_scheduler=ExponentialLR",
]
else:
cli_args += ["--optim2.class_path=torch.optim.SGD", "--optim2.init_args.lr=0.01"]
cli_args += ["--optim2=SGD", "--optim2.lr=0.01"]
cli_args += ["--lr_scheduler.gamma=0.2"]

with mock.patch("sys.argv", ["any.py"] + cli_args):
@@ -964,18 +976,17 @@ def __init__(self, foo, bar=5):
self.bar = bar


def test_lightning_cli_model_choices():
MODEL_REGISTRY(cls=TestModel)
MODEL_REGISTRY(cls=BoringModel)

def test_lightning_cli_model_short_arguments():
with mock.patch("sys.argv", ["any.py", "fit", "--model=BoringModel"]), mock.patch(
"pytorch_lightning.Trainer._fit_impl"
) as run:
) as run, mock_subclasses(LightningModule, BoringModel, TestModel):
cli = LightningCLI(trainer_defaults={"fast_dev_run": 1})
assert isinstance(cli.model, BoringModel)
run.assert_called_once_with(cli.model, ANY, ANY, ANY, ANY)

with mock.patch("sys.argv", ["any.py", "--model=TestModel", "--model.foo", "123"]):
with mock.patch("sys.argv", ["any.py", "--model=TestModel", "--model.foo", "123"]), mock_subclasses(
LightningModule, BoringModel, TestModel
):
cli = LightningCLI(run=False)
assert isinstance(cli.model, TestModel)
assert cli.model.foo == 123
@@ -989,11 +1000,7 @@ def __init__(self, foo, bar=5):
self.bar = bar


def test_lightning_cli_datamodule_choices():
MODEL_REGISTRY(cls=BoringModel)
DATAMODULE_REGISTRY(cls=MyDataModule)
DATAMODULE_REGISTRY(cls=BoringDataModule)

def test_lightning_cli_datamodule_short_arguments():
# with set model
with mock.patch("sys.argv", ["any.py", "fit", "--data=BoringDataModule"]), mock.patch(
"pytorch_lightning.Trainer._fit_impl"
@@ -1011,30 +1018,25 @@ def test_lightning_cli_datamodule_choices():
# with configurable model
with mock.patch("sys.argv", ["any.py", "fit", "--model", "BoringModel", "--data=BoringDataModule"]), mock.patch(
"pytorch_lightning.Trainer._fit_impl"
) as run:
) as run, mock_subclasses(LightningModule, BoringModel):
cli = LightningCLI(trainer_defaults={"fast_dev_run": 1})
assert isinstance(cli.model, BoringModel)
assert isinstance(cli.datamodule, BoringDataModule)
run.assert_called_once_with(cli.model, ANY, ANY, cli.datamodule, ANY)

with mock.patch("sys.argv", ["any.py", "--model", "BoringModel", "--data=MyDataModule"]):
with mock.patch("sys.argv", ["any.py", "--model", "BoringModel", "--data=MyDataModule"]), mock_subclasses(
LightningModule, BoringModel
):
cli = LightningCLI(run=False)
assert isinstance(cli.model, BoringModel)
assert isinstance(cli.datamodule, MyDataModule)

assert len(DATAMODULE_REGISTRY) # needs a value initially added
with mock.patch("sys.argv", ["any.py"]):
cli = LightningCLI(BoringModel, run=False)
# data was not passed but we are adding it automatically because there are datamodules registered
assert "data" in cli.parser.groups
assert not hasattr(cli.parser.groups["data"], "group_class")

with mock.patch("sys.argv", ["any.py"]), mock.patch.dict(DATAMODULE_REGISTRY, clear=True):
cli = LightningCLI(BoringModel, run=False, auto_registry=False)
# no registered classes so not added automatically
assert "data" not in cli.parser.groups
assert len(DATAMODULE_REGISTRY) # check state was not modified

with mock.patch("sys.argv", ["any.py"]):
cli = LightningCLI(BoringModel, BoringDataModule, run=False)
# since we are passing the DataModule, that's what's added to the parser
@@ -1043,7 +1045,6 @@

@pytest.mark.parametrize("use_class_path_callbacks", [False, True])
def test_registries_resolution(use_class_path_callbacks):
MODEL_REGISTRY(cls=BoringModel)

"""This test validates registries are used when simplified command line are being used."""
cli_args = [
@@ -1071,7 +1072,7 @@ def test_registries_resolution(use_class_path_callbacks):
cli_args += [f"--trainer.callbacks={json.dumps(callbacks)}"]
extras = [Callback, Callback]

with mock.patch("sys.argv", ["any.py"] + cli_args):
with mock.patch("sys.argv", ["any.py"] + cli_args), mock_subclasses(LightningModule, BoringModel):
cli = LightningCLI(run=False)

assert isinstance(cli.model, BoringModel)
@@ -1158,43 +1159,6 @@ def test_argv_transformation_multiple_callbacks_with_config():
assert argv == expected


@pytest.mark.parametrize(
["args", "expected", "nested_key", "registry"],
[
(
["--optimizer", "Adadelta"],
{"class_path": "torch.optim.adadelta.Adadelta", "init_args": {}},
"optimizer",
OPTIMIZER_REGISTRY,
),
(
["--optimizer", "Adadelta", "--optimizer.lr", "10"],
{"class_path": "torch.optim.adadelta.Adadelta", "init_args": {"lr": "10"}},
"optimizer",
OPTIMIZER_REGISTRY,
),
(
["--lr_scheduler", "OneCycleLR"],
{"class_path": "torch.optim.lr_scheduler.OneCycleLR", "init_args": {}},
"lr_scheduler",
LR_SCHEDULER_REGISTRY,
),
(
["--lr_scheduler", "OneCycleLR", "--lr_scheduler.anneal_strategy=linear"],
{"class_path": "torch.optim.lr_scheduler.OneCycleLR", "init_args": {"anneal_strategy": "linear"}},
"lr_scheduler",
LR_SCHEDULER_REGISTRY,
),
],
)
def test_argv_transformations_with_optimizers_and_lr_schedulers(args, expected, nested_key, registry):
base = ["any.py", "--trainer.max_epochs=1"]
argv = base + args
_populate_registries(False)
new_argv = LightningArgumentParser._convert_argv_issue_84(registry.classes, nested_key, argv)
assert new_argv == base + [f"--{nested_key}", str(expected)]


def test_optimizers_and_lr_schedulers_reload(tmpdir):
base = ["any.py", "--trainer.max_epochs=1"]
input = base + [