Added args parameter to LightningCLI to ease running from within Python #14596

Merged 14 commits on Sep 19, 2022
58 changes: 58 additions & 0 deletions docs/source-pytorch/cli/lightning_cli_advanced_3.rst
@@ -354,3 +354,61 @@ You can also pass the class path directly, for example, if the optimizer hasn't
--optimizer1.lr=0.01 \
--optimizer2=torch.optim.AdamW \
--optimizer2.lr=0.0001


Run from Python
^^^^^^^^^^^^^^^

Even though the :class:`~pytorch_lightning.cli.LightningCLI` class is designed to help in the implementation of command
line tools, for some use cases it is desirable to run it directly from Python. To allow this there is the ``args``
parameter. An example could be to first implement a normal CLI script, but add an ``args`` parameter with default
``None`` to the main function as follows:

.. code:: python

from pytorch_lightning.cli import ArgsType, LightningCLI


def cli_main(args: ArgsType = None):
cli = LightningCLI(MyModel, ..., args=args)
...


if __name__ == "__main__":
cli_main()

Then it is possible to import the ``cli_main`` function and run it programmatically. Executing ``my_cli.py
--trainer.max_epochs=100 --model.encoder_layers=24`` in a shell would be equivalent to:

.. code:: python

from my_module.my_cli import cli_main

cli_main(["--trainer.max_epochs=100", "--model.encoder_layers=24"])

All the features that are supported from the command line can be used when giving ``args`` as a list of strings. It is
also possible to provide a ``dict`` or a `jsonargparse.Namespace
<https://jsonargparse.readthedocs.io/en/stable/#jsonargparse.Namespace>`__. For example, in a Jupyter notebook someone
might do:

.. code:: python

args = {
"trainer": {
"max_epochs": 100,
},
"model": {},
}

args["model"]["encoder_layers"] = 8
cli_main(args)
args["model"]["encoder_layers"] = 12
cli_main(args)
args["trainer"]["max_epochs"] = 200
cli_main(args)
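For illustration only (this helper is not part of Lightning), the nested ``dict`` above maps directly onto the flat
``--key.subkey=value`` options used on the command line. A minimal standalone sketch of that correspondence:

```python
def dict_to_cli_args(cfg, prefix=""):
    """Flatten a nested config dict into CLI-style ``--key.subkey=value`` strings."""
    flat = []
    for key, value in cfg.items():
        dotted = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            flat.extend(dict_to_cli_args(value, dotted))
        else:
            flat.append(f"--{dotted}={value}")
    return flat


args = {"trainer": {"max_epochs": 100}, "model": {"encoder_layers": 8}}
print(dict_to_cli_args(args))  # ['--trainer.max_epochs=100', '--model.encoder_layers=8']
```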

.. note::

    The ``args`` parameter must be ``None`` when running from the command line so that ``sys.argv`` is used as the
    arguments. Also note that the purpose of ``trainer_defaults`` is different from that of ``args``: it is okay to use
    ``trainer_defaults`` in the ``cli_main`` function to modify the defaults of some trainer parameters.
11 changes: 9 additions & 2 deletions docs/source-pytorch/cli/lightning_cli_intermediate.rst
@@ -87,8 +87,15 @@ The simplest way to control a model with the CLI is to wrap it in the LightningC
# simple demo classes for your convenience
from pytorch_lightning.demos.boring_classes import DemoModel, BoringDataModule

cli = LightningCLI(DemoModel, BoringDataModule)
# note: don't call fit!!

def cli_main():
cli = LightningCLI(DemoModel, BoringDataModule)
# note: don't call fit!!


if __name__ == "__main__":
cli_main()
# note: it is good practice to implement the CLI in a function and call it in the main if block

Now your model can be managed via the CLI. To see the available commands type:

3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
[#14620](https://github.com/Lightning-AI/lightning/issues/14620))


- Added `args` parameter to `LightningCLI` to ease running from within Python ([#14596](https://github.com/PyTorchLightning/pytorch-lightning/pull/14596))


- Added `WandbLogger.download_artifact` and `WandbLogger.use_artifact` for managing artifacts with Weights and Biases ([#14551](https://github.com/Lightning-AI/lightning/issues/14551))


23 changes: 20 additions & 3 deletions src/pytorch_lightning/cli.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from functools import partial, update_wrapper
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
@@ -48,6 +49,9 @@
locals()["Namespace"] = object


ArgsType = Optional[Union[List[str], Dict[str, Any], Namespace]]


class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None:
super().__init__(optimizer, *args, **kwargs)
@@ -256,6 +260,7 @@ def __init__(
parser_kwargs: Optional[Union[Dict[str, Any], Dict[str, Dict[str, Any]]]] = None,
subclass_mode_model: bool = False,
subclass_mode_data: bool = False,
args: ArgsType = None,
run: bool = True,
auto_registry: bool = False,
) -> None:
@@ -300,6 +305,9 @@ def __init__(
subclass_mode_data: Whether datamodule can be any `subclass
<https://jsonargparse.readthedocs.io/en/stable/#class-type-and-sub-classes>`_
of the given class.
args: Arguments to parse. If ``None`` the arguments are taken from ``sys.argv``. Command line style
arguments can be given in a ``list``. Alternatively, structured config options can be given in a
``dict`` or ``jsonargparse.Namespace``.
run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer`
method. If set to ``False``, the trainer and model classes will be instantiated only.
auto_registry: Whether to automatically fill up the registries with all defined subclasses.
@@ -338,7 +346,7 @@ def __init__(
{"description": description, "env_prefix": env_prefix, "default_env": env_parse},
)
self.setup_parser(run, main_kwargs, subparser_kwargs)
self.parse_arguments(self.parser)
self.parse_arguments(self.parser, args)

self.subcommand = self.config["subcommand"] if run else None

@@ -474,9 +482,18 @@ def link_optimizers_and_lr_schedulers(parser: LightningArgumentParser) -> None:
add_class_path = _add_class_path_generator(class_type)
parser.link_arguments(key, link_to, compute_fn=add_class_path)

def parse_arguments(self, parser: LightningArgumentParser) -> None:
def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> None:
"""Parses command line arguments and stores it in ``self.config``."""
self.config = parser.parse_args()
if args is not None and len(sys.argv) > 1:
raise ValueError(
"LightningCLI's args parameter is intended to run from within Python like if it were from the command "
"line. To prevent mistakes it is not allowed to provide both args and command line arguments, got: "
f"sys.argv[1:]={sys.argv[1:]}, args={args}."
)
if isinstance(args, (dict, Namespace)):
self.config = parser.parse_object(args)
else:
self.config = parser.parse_args(args)
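A standalone sketch of the guard and dispatch logic above, using hypothetical function names rather than the Lightning
implementation: programmatic ``args`` and real command line arguments must not be mixed, otherwise whichever source is
given is used.

```python
import sys


def resolve_args(args, argv=None):
    """Refuse to mix programmatic ``args`` with real command line
    arguments; otherwise return whichever argument source was given."""
    argv = sys.argv if argv is None else argv
    if args is not None and len(argv) > 1:
        raise ValueError(
            f"Provide either args or command line arguments, not both: "
            f"argv[1:]={argv[1:]}, args={args}"
        )
    return argv[1:] if args is None else args


print(resolve_args(["--x=1"], argv=["prog"]))  # ['--x=1']
print(resolve_args(None, argv=["prog", "--y=2"]))  # ['--y=2']
```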

def before_instantiate_classes(self) -> None:
"""Implement to run some code before instantiating the classes."""
27 changes: 26 additions & 1 deletion tests/tests_pytorch/test_cli.py
@@ -49,7 +49,9 @@
from tests_pytorch.helpers.utils import no_warning_call

if _JSONARGPARSE_SIGNATURES_AVAILABLE:
from jsonargparse import lazy_instance
from jsonargparse import lazy_instance, Namespace
else:
from argparse import Namespace


@contextmanager
@@ -1403,3 +1405,26 @@ def test_pytorch_profiler_init_args():
init["record_shapes"] = unresolved.pop("record_shapes") # Test move to init_args
assert {k: cli.config.trainer.profiler.init_args[k] for k in init} == init
assert cli.config.trainer.profiler.dict_kwargs == unresolved


@pytest.mark.parametrize(
["args"],
[
(["--trainer.logger=False", "--model.foo=456"],),
({"trainer": {"logger": False}, "model": {"foo": 456}},),
(Namespace(trainer=Namespace(logger=False), model=Namespace(foo=456)),),
],
)
def test_lightning_cli_with_args_given(args):
with mock.patch("sys.argv", [""]):
cli = LightningCLI(TestModel, run=False, args=args)
assert isinstance(cli.model, TestModel)
assert cli.config.trainer.logger is False
assert cli.model.foo == 456


def test_lightning_cli_args_and_sys_argv_exception():
with mock.patch("sys.argv", ["", "--model.foo=456"]), pytest.raises(
ValueError, match="LightningCLI's args parameter "
):
LightningCLI(TestModel, run=False, args=["--model.foo=789"])