Added args parameter to LightningCLI to ease running from within Python #14596

Merged · 14 commits · Sep 19, 2022
Changes from 3 commits
30 changes: 30 additions & 0 deletions docs/source-pytorch/cli/lightning_cli_advanced_3.rst
@@ -354,3 +354,33 @@ You can also pass the class path directly, for example, if the optimizer hasn't
--optimizer1.lr=0.01 \
--optimizer2=torch.optim.AdamW \
--optimizer2.lr=0.0001


Run from Python
^^^^^^^^^^^^^^^

Even though the :class:`~pytorch_lightning.cli.LightningCLI` class is designed to help in the implementation of command
line tools, for some use cases it is desirable to run it directly from Python. To support these use cases, the ``args``
parameter can be used, for example:

.. testcode::

cli = LightningCLI(MyModel, args=["--trainer.max_epochs=100", "--model.encoder_layers=24"])

All the features supported from the command line can be used when giving ``args`` as a list of strings. It is
also possible to give ``args`` a ``dict`` or `Namespace
<https://jsonargparse.readthedocs.io/en/stable/#jsonargparse.Namespace>`__, as the two examples below show. First,
with a ``dict``:

.. testcode::

cli = LightningCLI(
MyModel,
args={
"trainer": {
"max_epochs": 100,
},
"model": {
"encoder_layers": 24,
},
},
)
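The same configuration can also be given as a ``Namespace``. A minimal sketch mirroring the ``dict`` example above,
using the nested-``Namespace`` pattern from this PR's tests (note that this is ``jsonargparse.Namespace``, not
``argparse.Namespace``):

.. testcode::

    from jsonargparse import Namespace

    cli = LightningCLI(
        MyModel,
        args=Namespace(
            trainer=Namespace(max_epochs=100),
            model=Namespace(encoder_layers=24),
        ),
    )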
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -32,6 +32,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
[#14575](https://github.com/Lightning-AI/lightning/issues/14575))


- Added `args` parameter to `LightningCLI` to ease running from within Python ([#14596](https://github.com/PyTorchLightning/pytorch-lightning/pull/14596))


### Changed

- The `Trainer.{fit,validate,test,predict,tune}` methods now raise a useful error message if the input is not a `LightningModule` ([#13892](https://github.com/Lightning-AI/lightning/pull/13892))
15 changes: 12 additions & 3 deletions src/pytorch_lightning/cli.py
@@ -256,6 +256,7 @@ def __init__(
parser_kwargs: Optional[Union[Dict[str, Any], Dict[str, Dict[str, Any]]]] = None,
subclass_mode_model: bool = False,
subclass_mode_data: bool = False,
args: Optional[Union[List[str], Dict[str, Any], Namespace]] = None,
run: bool = True,
auto_registry: bool = False,
) -> None:
@@ -300,6 +301,9 @@ def __init__(
subclass_mode_data: Whether datamodule can be any `subclass
<https://jsonargparse.readthedocs.io/en/stable/#class-type-and-sub-classes>`_
of the given class.
args: Arguments to parse. If ``None``, the arguments are taken from ``sys.argv``. Command line style
arguments can be given in a ``list``. Alternatively, structured config options can be given in a
``dict`` or ``jsonargparse.Namespace``.
run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer`
method. If set to ``False``, the trainer and model classes will be instantiated only.
auto_registry: Whether to automatically fill up the registries with all defined subclasses.
@@ -338,7 +342,7 @@ def __init__(
{"description": description, "env_prefix": env_prefix, "default_env": env_parse},
)
self.setup_parser(run, main_kwargs, subparser_kwargs)
self.parse_arguments(self.parser)
self.parse_arguments(self.parser, args)

self.subcommand = self.config["subcommand"] if run else None

@@ -472,9 +476,14 @@ def link_optimizers_and_lr_schedulers(parser: LightningArgumentParser) -> None:
add_class_path = _add_class_path_generator(class_type)
parser.link_arguments(key, link_to, compute_fn=add_class_path)

def parse_arguments(self, parser: LightningArgumentParser) -> None:
def parse_arguments(
self, parser: LightningArgumentParser, args: Optional[Union[List[str], Dict[str, Any], Namespace]]
) -> None:
"""Parses command line arguments and stores it in ``self.config``."""
self.config = parser.parse_args()
if isinstance(args, (dict, Namespace)):
self.config = parser.parse_object(args)
else:
self.config = parser.parse_args(args)

def before_instantiate_classes(self) -> None:
"""Implement to run some code before instantiating the classes."""
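For reference, a minimal standalone sketch of the dispatch that ``parse_arguments`` performs, written against plain
``jsonargparse``. The two-option parser here is illustrative only; ``LightningCLI`` assembles a far richer parser
internally:

    from typing import Any, Dict, List, Optional, Union

    from jsonargparse import ArgumentParser, Namespace

    # Illustrative stand-in for the parser that LightningCLI builds.
    parser = ArgumentParser()
    parser.add_argument("--trainer.max_epochs", type=int, default=1)
    parser.add_argument("--model.encoder_layers", type=int, default=12)


    def parse(args: Optional[Union[List[str], Dict[str, Any], Namespace]]) -> Namespace:
        # Structured configs (dict or Namespace) go through parse_object;
        # command line style lists (or None, meaning sys.argv) go through parse_args.
        if isinstance(args, (dict, Namespace)):
            return parser.parse_object(args)
        return parser.parse_args(args)


    cfg_list = parse(["--trainer.max_epochs=100"])
    cfg_dict = parse({"trainer": {"max_epochs": 100}})
    print(cfg_list.trainer.max_epochs, cfg_dict.trainer.max_epochs)  # 100 100

Both calls yield equivalent configs, which is exactly what the new test at the bottom of this PR verifies through
``LightningCLI`` itself.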
18 changes: 16 additions & 2 deletions tests/tests_pytorch/test_cli.py
@@ -16,7 +16,6 @@
import os
import pickle
import sys
from argparse import Namespace
from contextlib import contextmanager, ExitStack, redirect_stdout
from io import StringIO
from typing import Callable, List, Optional, Union
Expand Down Expand Up @@ -53,7 +52,7 @@
from tests_pytorch.helpers.utils import no_warning_call

if _JSONARGPARSE_SIGNATURES_AVAILABLE:
from jsonargparse import lazy_instance
from jsonargparse import lazy_instance, Namespace


@contextmanager
Expand Down Expand Up @@ -1558,3 +1557,18 @@ def test_pytorch_profiler_init_args():
init["record_shapes"] = unresolved.pop("record_shapes") # Test move to init_args
assert {k: cli.config.trainer.profiler.init_args[k] for k in init} == init
assert cli.config.trainer.profiler.dict_kwargs == unresolved


@pytest.mark.parametrize(
["args"],
[
(["--trainer.logger=False", "--model.foo=456"],),
({"trainer": {"logger": False}, "model": {"foo": 456}},),
(Namespace(trainer=Namespace(logger=False), model=Namespace(foo=456)),),
],
)
def test_lightning_cli_with_args_given(args):
cli = LightningCLI(TestModel, run=False, args=args)
assert isinstance(cli.model, TestModel)
assert cli.config.trainer.logger is False
assert cli.model.foo == 456