From 118a121c580f9081981cf0ca242fb157b1d5b678 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Thu, 8 Sep 2022 07:57:41 +0200 Subject: [PATCH 1/8] Added args parameter to LightningCLI to ease running from within Python. --- .../cli/lightning_cli_advanced_3.rst | 30 +++++++++++++++++++ src/pytorch_lightning/CHANGELOG.md | 3 ++ src/pytorch_lightning/cli.py | 13 ++++++-- tests/tests_pytorch/test_cli.py | 17 +++++++++-- 4 files changed, 58 insertions(+), 5 deletions(-) diff --git a/docs/source-pytorch/cli/lightning_cli_advanced_3.rst b/docs/source-pytorch/cli/lightning_cli_advanced_3.rst index df062061022c9..475adfa631a7f 100644 --- a/docs/source-pytorch/cli/lightning_cli_advanced_3.rst +++ b/docs/source-pytorch/cli/lightning_cli_advanced_3.rst @@ -354,3 +354,33 @@ You can also pass the class path directly, for example, if the optimizer hasn't --optimizer1.lr=0.01 \ --optimizer2=torch.optim.AdamW \ --optimizer2.lr=0.0001 + + +Run from Python +^^^^^^^^^^^^^^^ + +Even though the :class:`~pytorch_lightning.cli.LightningCLI` class is designed to help in the implementation of command +line tools, for some use cases it is desired to run directly from Python. To support these use cases, the ``args`` +parameter can be used, for example: + +.. testcode:: + + cli = LightningCLI(MyModel, args=["--trainer.max_epochs=100", "--model.encoder_layers=24"]) + +All the features that are supported from the command line can be used when giving ``args`` as a list of strings. It is +also possible to provide to ``args`` a ``dict`` or `Namespace +`__. For example: + +.. 
testcode:: + + cli = LightningCLI( + MyModel, + args={ + "trainer": { + "max_epochs": 100, + }, + "model": { + "encoder_layers": 24, + }, + }, + ) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index ecf1ce319aa13..37dbec442937f 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -32,6 +32,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). [#14575](https://github.com/Lightning-AI/lightning/issues/14575)) +- Added `args` parameter to `LightningCLI` to ease running from within Python ([#14596](https://github.com/PyTorchLightning/pytorch-lightning/pull/14596)) + + ### Changed - The `Trainer.{fit,validate,test,predict,tune}` methods now raise a useful error message if the input is not a `LightningModule` ([#13892](https://github.com/Lightning-AI/lightning/pull/13892)) diff --git a/src/pytorch_lightning/cli.py b/src/pytorch_lightning/cli.py index 82156c6b4ab90..b821764ae514d 100644 --- a/src/pytorch_lightning/cli.py +++ b/src/pytorch_lightning/cli.py @@ -256,6 +256,7 @@ def __init__( parser_kwargs: Optional[Union[Dict[str, Any], Dict[str, Dict[str, Any]]]] = None, subclass_mode_model: bool = False, subclass_mode_data: bool = False, + args: Union[List[str], Dict[str, Any], Namespace] = None, run: bool = True, auto_registry: bool = False, ) -> None: @@ -300,6 +301,9 @@ def __init__( subclass_mode_data: Whether datamodule can be any `subclass `_ of the given class. + args: Arguments to parse. If ``None`` the arguments are taken from ``sys.argv``. Command line style + arguments can be given in a ``list``. Alternatively a structured config options can be given in a + ``dict`` or ``jsonargparse.Namespace``. run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer` method. If set to ``False``, the trainer and model classes will be instantiated only. 
auto_registry: Whether to automatically fill up the registries with all defined subclasses. @@ -338,7 +342,7 @@ def __init__( {"description": description, "env_prefix": env_prefix, "default_env": env_parse}, ) self.setup_parser(run, main_kwargs, subparser_kwargs) - self.parse_arguments(self.parser) + self.parse_arguments(self.parser, args) self.subcommand = self.config["subcommand"] if run else None @@ -472,9 +476,12 @@ def link_optimizers_and_lr_schedulers(parser: LightningArgumentParser) -> None: add_class_path = _add_class_path_generator(class_type) parser.link_arguments(key, link_to, compute_fn=add_class_path) - def parse_arguments(self, parser: LightningArgumentParser) -> None: + def parse_arguments(self, parser: LightningArgumentParser, args: Union[List[str], Dict[str, Any], Namespace]) -> None: """Parses command line arguments and stores it in ``self.config``.""" - self.config = parser.parse_args() + if isinstance(args, (dict, Namespace)): + self.config = parser.parse_object(args) + else: + self.config = parser.parse_args(args) def before_instantiate_classes(self) -> None: """Implement to run some code before instantiating the classes.""" diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 46fc7e9b6217f..f9bf92f94705d 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -16,7 +16,6 @@ import os import pickle import sys -from argparse import Namespace from contextlib import contextmanager, ExitStack, redirect_stdout from io import StringIO from typing import Callable, List, Optional, Union @@ -53,7 +52,7 @@ from tests_pytorch.helpers.utils import no_warning_call if _JSONARGPARSE_SIGNATURES_AVAILABLE: - from jsonargparse import lazy_instance + from jsonargparse import lazy_instance, Namespace @contextmanager @@ -1558,3 +1557,17 @@ def test_pytorch_profiler_init_args(): init["record_shapes"] = unresolved.pop("record_shapes") # Test move to init_args assert {k: 
cli.config.trainer.profiler.init_args[k] for k in init} == init assert cli.config.trainer.profiler.dict_kwargs == unresolved + + +@pytest.mark.parametrize(["args"], + [ + (["--trainer.logger=False", "--model.foo=456"], ), + ({"trainer": {"logger": False}, "model": {"foo": 456}}, ), + (Namespace(trainer=Namespace(logger=False), model=Namespace(foo=456)), ), + ], +) +def test_lightning_cli_with_args_given(args): + cli = LightningCLI(TestModel, run=False, args=args) + assert isinstance(cli.model, TestModel) + assert cli.config.trainer.logger is False + assert cli.model.foo == 456 From a726368919a924194106cd9bbf6e3967c5deb23a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 Sep 2022 06:21:16 +0000 Subject: [PATCH 2/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/cli.py | 4 +++- tests/tests_pytorch/test_cli.py | 9 +++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/pytorch_lightning/cli.py b/src/pytorch_lightning/cli.py index b821764ae514d..957bdd188c3d2 100644 --- a/src/pytorch_lightning/cli.py +++ b/src/pytorch_lightning/cli.py @@ -476,7 +476,9 @@ def link_optimizers_and_lr_schedulers(parser: LightningArgumentParser) -> None: add_class_path = _add_class_path_generator(class_type) parser.link_arguments(key, link_to, compute_fn=add_class_path) - def parse_arguments(self, parser: LightningArgumentParser, args: Union[List[str], Dict[str, Any], Namespace]) -> None: + def parse_arguments( + self, parser: LightningArgumentParser, args: Union[List[str], Dict[str, Any], Namespace] + ) -> None: """Parses command line arguments and stores it in ``self.config``.""" if isinstance(args, (dict, Namespace)): self.config = parser.parse_object(args) diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index f9bf92f94705d..28fe7d1fb8d3f 100644 --- a/tests/tests_pytorch/test_cli.py +++ 
b/tests/tests_pytorch/test_cli.py @@ -1559,11 +1559,12 @@ def test_pytorch_profiler_init_args(): assert cli.config.trainer.profiler.dict_kwargs == unresolved -@pytest.mark.parametrize(["args"], +@pytest.mark.parametrize( + ["args"], [ - (["--trainer.logger=False", "--model.foo=456"], ), - ({"trainer": {"logger": False}, "model": {"foo": 456}}, ), - (Namespace(trainer=Namespace(logger=False), model=Namespace(foo=456)), ), + (["--trainer.logger=False", "--model.foo=456"],), + ({"trainer": {"logger": False}, "model": {"foo": 456}},), + (Namespace(trainer=Namespace(logger=False), model=Namespace(foo=456)),), ], ) def test_lightning_cli_with_args_given(args): From 6614c296002deafc5feca5e675702d958c2cafa2 Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 12 Sep 2022 21:27:31 +0200 Subject: [PATCH 3/8] fix --- tests/tests_pytorch/test_cli.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index c040bfaf59835..95891e214c733 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -14,8 +14,6 @@ import inspect import json import os -import pickle -import sys from contextlib import contextmanager, ExitStack, redirect_stdout from io import StringIO from typing import Callable, List, Optional, Union From d7cd656f787d43729a309bd73a6e7ac69c1d80ad Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Mon, 19 Sep 2022 07:05:49 +0200 Subject: [PATCH 4/8] Address comments in pull request. 
--- .../cli/lightning_cli_advanced_3.rst | 54 +++++++++++++------ .../cli/lightning_cli_intermediate.rst | 8 ++- src/pytorch_lightning/cli.py | 18 +++++-- tests/tests_pytorch/test_cli.py | 10 +++- 4 files changed, 66 insertions(+), 24 deletions(-) diff --git a/docs/source-pytorch/cli/lightning_cli_advanced_3.rst b/docs/source-pytorch/cli/lightning_cli_advanced_3.rst index 475adfa631a7f..7a89519928def 100644 --- a/docs/source-pytorch/cli/lightning_cli_advanced_3.rst +++ b/docs/source-pytorch/cli/lightning_cli_advanced_3.rst @@ -360,27 +360,49 @@ Run from Python ^^^^^^^^^^^^^^^ Even though the :class:`~pytorch_lightning.cli.LightningCLI` class is designed to help in the implementation of command -line tools, for some use cases it is desired to run directly from Python. To support these use cases, the ``args`` -parameter can be used, for example: +line tools, for some use cases it is desired to run directly from Python. To allow this there is the ``args`` parameter. +An example could be to first implement a normal CLI script, but adding an ``args`` parameter with default ``None`` to +the main function as follows: -.. testcode:: +.. code:: python + + from pytorch_lightning.cli import ArgsType, LightningCLI + + def cli_main(args: ArgsType = None): + cli = LightningCLI(MyModel, ..., args=args) + ... + + if __name__ == "__main__": + cli_main() + +Then it is possible to import the ``cli_main`` function to run it. Executing in a shell ``my_cli.py +--trainer.max_epochs=100 --model.encoder_layers=24`` would be equivalent to: + +.. code:: python - cli = LightningCLI(MyModel, args=["--trainer.max_epochs=100", "--model.encoder_layers=24"]) + from my_module.my_cli import cli_main + + cli_main(["--trainer.max_epochs=100", "--model.encoder_layers=24"]) All the features that are supported from the command line can be used when giving ``args`` as a list of strings. 
It is -also possible to provide to ``args`` a ``dict`` or `Namespace +also possible to provide a ``dict`` or `jsonargparse.Namespace `__. For example: -.. testcode:: +.. code:: python - cli = LightningCLI( - MyModel, - args={ - "trainer": { - "max_epochs": 100, - }, - "model": { - "encoder_layers": 24, - }, + args = { + "trainer": { + "max_epochs": 100, }, - ) + "model": {}, + } + + for encoder_layers in [8, 16, 24]: + args["model"]["encoder_layers"] = encoder_layers + cli_main(args) + +.. note:: + + The ``args`` parameter must be ``None`` when running from command line so that ``sys.argv`` is used as arguments. + Also, note that the purpose of ``trainer_defaults`` is different to ``args``. It is okay to use ``trainer_defaults`` + in the ``cli_main`` function to modify the defaults of some trainer parameters. diff --git a/docs/source-pytorch/cli/lightning_cli_intermediate.rst b/docs/source-pytorch/cli/lightning_cli_intermediate.rst index 6ed4921305c0d..df637715f925b 100644 --- a/docs/source-pytorch/cli/lightning_cli_intermediate.rst +++ b/docs/source-pytorch/cli/lightning_cli_intermediate.rst @@ -87,8 +87,12 @@ The simplest way to control a model with the CLI is to wrap it in the LightningC # simple demo classes for your convenience from pytorch_lightning.demos.boring_classes import DemoModel, BoringDataModule - cli = LightningCLI(DemoModel, BoringDataModule) - # note: don't call fit!! + def cli_main(): + cli = LightningCLI(DemoModel, BoringDataModule) + # note: don't call fit!! + + if __name__ == "__main__": + cli_main() Now your model can be managed via the CLI. To see the available commands type: diff --git a/src/pytorch_lightning/cli.py b/src/pytorch_lightning/cli.py index f2380d77dd838..eeedd196f8d38 100644 --- a/src/pytorch_lightning/cli.py +++ b/src/pytorch_lightning/cli.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os +import sys from functools import partial, update_wrapper from types import MethodType from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union @@ -48,6 +49,9 @@ locals()["Namespace"] = object +ArgsType = Optional[Union[List[str], Dict[str, Any], Namespace]] + + class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau): def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None: super().__init__(optimizer, *args, **kwargs) @@ -256,7 +260,7 @@ def __init__( parser_kwargs: Optional[Union[Dict[str, Any], Dict[str, Dict[str, Any]]]] = None, subclass_mode_model: bool = False, subclass_mode_data: bool = False, - args: Union[List[str], Dict[str, Any], Namespace] = None, + args: ArgsType = None, run: bool = True, auto_registry: bool = False, ) -> None: @@ -302,7 +306,7 @@ def __init__( `_ of the given class. args: Arguments to parse. If ``None`` the arguments are taken from ``sys.argv``. Command line style - arguments can be given in a ``list``. Alternatively a structured config options can be given in a + arguments can be given in a ``list``. Alternatively, structured config options can be given in a ``dict`` or ``jsonargparse.Namespace``. run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer` method. If set to ``False``, the trainer and model classes will be instantiated only. 
@@ -478,10 +482,14 @@ def link_optimizers_and_lr_schedulers(parser: LightningArgumentParser) -> None: add_class_path = _add_class_path_generator(class_type) parser.link_arguments(key, link_to, compute_fn=add_class_path) - def parse_arguments( - self, parser: LightningArgumentParser, args: Union[List[str], Dict[str, Any], Namespace] - ) -> None: + def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> None: """Parses command line arguments and stores it in ``self.config``.""" + if args is not None and len(sys.argv) > 1: + raise MisconfigurationException( + "LightningCLI's args parameter is intended to run from within Python like if it were from the command " + "line. To prevent mistakes it is not allowed to provide both args and command line arguments, got: " + f"sys.argv[1:]={sys.argv[1:]}, args={args}." + ) if isinstance(args, (dict, Namespace)): self.config = parser.parse_object(args) else: diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 95891e214c733..8cb5f5e23f6f7 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -1414,7 +1414,15 @@ def test_pytorch_profiler_init_args(): ], ) def test_lightning_cli_with_args_given(args): - cli = LightningCLI(TestModel, run=False, args=args) + with mock.patch("sys.argv", [""]): + cli = LightningCLI(TestModel, run=False, args=args) assert isinstance(cli.model, TestModel) assert cli.config.trainer.logger is False assert cli.model.foo == 456 + + +def test_lightning_cli_args_and_sys_argv_exception(): + with mock.patch("sys.argv", ["", "--model.foo=456"]), pytest.raises( + MisconfigurationException, match="LightningCLI's args parameter " + ): + LightningCLI(TestModel, run=False, args=["--model.foo=789"]) From b3210f3a39b9e2afb6a0bf58cfe86a005e437817 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 19 Sep 2022 05:20:57 +0000 Subject: [PATCH 5/8] [pre-commit.ci] auto fixes 
from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/source-pytorch/cli/lightning_cli_advanced_3.rst | 2 ++ docs/source-pytorch/cli/lightning_cli_intermediate.rst | 2 ++ 2 files changed, 4 insertions(+) diff --git a/docs/source-pytorch/cli/lightning_cli_advanced_3.rst b/docs/source-pytorch/cli/lightning_cli_advanced_3.rst index 7a89519928def..ecd92bc8e7356 100644 --- a/docs/source-pytorch/cli/lightning_cli_advanced_3.rst +++ b/docs/source-pytorch/cli/lightning_cli_advanced_3.rst @@ -368,10 +368,12 @@ the main function as follows: from pytorch_lightning.cli import ArgsType, LightningCLI + def cli_main(args: ArgsType = None): cli = LightningCLI(MyModel, ..., args=args) ... + if __name__ == "__main__": cli_main() diff --git a/docs/source-pytorch/cli/lightning_cli_intermediate.rst b/docs/source-pytorch/cli/lightning_cli_intermediate.rst index df637715f925b..f89dccd8b14bf 100644 --- a/docs/source-pytorch/cli/lightning_cli_intermediate.rst +++ b/docs/source-pytorch/cli/lightning_cli_intermediate.rst @@ -87,10 +87,12 @@ The simplest way to control a model with the CLI is to wrap it in the LightningC # simple demo classes for your convenience from pytorch_lightning.demos.boring_classes import DemoModel, BoringDataModule + def cli_main(): cli = LightningCLI(DemoModel, BoringDataModule) # note: don't call fit!! + if __name__ == "__main__": cli_main() From 96db5c57152d7c2ddca28227b72cb2085c6db33f Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Mon, 19 Sep 2022 07:22:18 +0200 Subject: [PATCH 6/8] Fix Namespace import in cli tests. 
--- tests/tests_pytorch/test_cli.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 8cb5f5e23f6f7..cad783c2276f6 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -50,6 +50,8 @@ if _JSONARGPARSE_SIGNATURES_AVAILABLE: from jsonargparse import lazy_instance, Namespace +else: + from argparse import Namespace @contextmanager From 02ea76997a21ff76731bb1bda7cedbf14051444a Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Mon, 19 Sep 2022 17:33:10 +0200 Subject: [PATCH 7/8] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- src/pytorch_lightning/cli.py | 2 +- tests/tests_pytorch/test_cli.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/cli.py b/src/pytorch_lightning/cli.py index eeedd196f8d38..22d7d8b637d96 100644 --- a/src/pytorch_lightning/cli.py +++ b/src/pytorch_lightning/cli.py @@ -485,7 +485,7 @@ def link_optimizers_and_lr_schedulers(parser: LightningArgumentParser) -> None: def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> None: """Parses command line arguments and stores it in ``self.config``.""" if args is not None and len(sys.argv) > 1: - raise MisconfigurationException( + raise ValueError( "LightningCLI's args parameter is intended to run from within Python like if it were from the command " "line. To prevent mistakes it is not allowed to provide both args and command line arguments, got: " f"sys.argv[1:]={sys.argv[1:]}, args={args}." 
diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index cad783c2276f6..93cf0928d26fe 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -1425,6 +1425,6 @@ def test_lightning_cli_with_args_given(args): def test_lightning_cli_args_and_sys_argv_exception(): with mock.patch("sys.argv", ["", "--model.foo=456"]), pytest.raises( - MisconfigurationException, match="LightningCLI's args parameter " + ValueError, match="LightningCLI's args parameter " ): LightningCLI(TestModel, run=False, args=["--model.foo=789"]) From 9843c583afddc3608c5260c212c22208424b62c5 Mon Sep 17 00:00:00 2001 From: Mauricio Villegas Date: Mon, 19 Sep 2022 17:44:10 +0200 Subject: [PATCH 8/8] Address comments in pull request. --- docs/source-pytorch/cli/lightning_cli_advanced_3.rst | 12 ++++++++---- .../cli/lightning_cli_intermediate.rst | 1 + 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/docs/source-pytorch/cli/lightning_cli_advanced_3.rst b/docs/source-pytorch/cli/lightning_cli_advanced_3.rst index ecd92bc8e7356..7f5ed143869e3 100644 --- a/docs/source-pytorch/cli/lightning_cli_advanced_3.rst +++ b/docs/source-pytorch/cli/lightning_cli_advanced_3.rst @@ -388,7 +388,8 @@ Then it is possible to import the ``cli_main`` function to run it. Executing in All the features that are supported from the command line can be used when giving ``args`` as a list of strings. It is also possible to provide a ``dict`` or `jsonargparse.Namespace -`__. For example: +`__. For example in a jupyter notebook someone +might do: .. code:: python @@ -399,9 +400,12 @@ also possible to provide a ``dict`` or `jsonargparse.Namespace "model": {}, } - for encoder_layers in [8, 16, 24]: - args["model"]["encoder_layers"] = encoder_layers - cli_main(args) + args["model"]["encoder_layers"] = 8 + cli_main(args) + args["model"]["encoder_layers"] = 12 + cli_main(args) + args["trainer"]["max_epochs"] = 200 + cli_main(args) .. 
note:: diff --git a/docs/source-pytorch/cli/lightning_cli_intermediate.rst b/docs/source-pytorch/cli/lightning_cli_intermediate.rst index f89dccd8b14bf..db8b6cf4c77ec 100644 --- a/docs/source-pytorch/cli/lightning_cli_intermediate.rst +++ b/docs/source-pytorch/cli/lightning_cli_intermediate.rst @@ -95,6 +95,7 @@ The simplest way to control a model with the CLI is to wrap it in the LightningC if __name__ == "__main__": cli_main() + # note: it is good practice to implement the CLI in a function and call it in the main if block Now your model can be managed via the CLI. To see the available commands type: