
Add LightningCLI(run=False|True) #8751

Merged · 8 commits · Aug 10, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -27,6 +27,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `log_graph` argument for `watch` method of `WandbLogger` ([#8662](https://github.com/PyTorchLightning/pytorch-lightning/pull/8662))
 
 
+- Added `LightningCLI(run=False|True)` to choose whether to run a `Trainer` subcommand ([#8751](https://github.com/PyTorchLightning/pytorch-lightning/pull/8751))
+
+
 - Fault-tolerant training:
     * Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366))
 
35 changes: 25 additions & 10 deletions docs/source/common/lightning_cli.rst
@@ -4,12 +4,13 @@
     import torch
     from unittest import mock
     from typing import List
-    from pytorch_lightning.core.lightning import LightningModule
-    from pytorch_lightning.core.datamodule import LightningDataModule
+    from pytorch_lightning import LightningModule, LightningDataModule, Trainer
     from pytorch_lightning.utilities.cli import LightningCLI
 
-    original_fit = LightningCLI.fit
-    LightningCLI.fit = lambda self: None
+    cli_fit = LightningCLI.fit
+    LightningCLI.fit = lambda *_, **__: None
+    trainer_fit = Trainer.fit
+    Trainer.fit = lambda *_, **__: None
 
 
     class MyModel(LightningModule):
@@ -47,7 +48,8 @@

 .. testcleanup:: *
 
-    LightningCLI.fit = original_fit
+    LightningCLI.fit = cli_fit
+    Trainer.fit = trainer_fit
     mock_argv.stop()


@@ -260,17 +262,30 @@ file. Loading a defaults file :code:`my_cli_defaults.yaml` in the current workin

 .. testcode::
 
-    cli = LightningCLI(
-        MyModel,
-        MyDataModule,
-        parser_kwargs={"default_config_files": ["my_cli_defaults.yaml"]},
-    )
+    cli = LightningCLI(MyModel, MyDataModule, parser_kwargs={"default_config_files": ["my_cli_defaults.yaml"]})
 
 To load a file in the user's home directory, simply change the path to :code:`~/.my_cli_defaults.yaml`. Note that
 this setting is given through :code:`parser_kwargs`. More parameters are supported. For details see the
 `ArgumentParser API <https://jsonargparse.readthedocs.io/en/stable/#jsonargparse.core.ArgumentParser.__init__>`_ documentation.
 
 
+Instantiation only mode
+^^^^^^^^^^^^^^^^^^^^^^^
+
+The CLI is designed to start fitting with minimal code changes. On class instantiation, the CLI will automatically
+call ``trainer.fit(...)`` internally so you don't have to do it. To avoid this, you can set the following argument:
+
+.. testcode::
+
+    cli = LightningCLI(MyModel, run=False)  # True by default
+    # you'll have to call fit yourself:
+    cli.trainer.fit(cli.model)
+
+
+This can be useful to implement custom logic without having to subclass the CLI, while still using the CLI's
+instantiation and argument parsing capabilities.
+
+
 Trainer Callbacks and arguments with class type
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

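The two docs snippets above pair naturally: `parser_kwargs` configures parsing defaults, while `run=False` defers fitting to the caller. A minimal sketch combining them, where `MyModel` and `MyDataModule` are placeholders for any `LightningModule`/`LightningDataModule` (as in the docs' testsetup):

```python
# Sketch only: MyModel / MyDataModule are stand-ins, and my_cli_defaults.yaml
# is the defaults file from the docs section above.
from pytorch_lightning.utilities.cli import LightningCLI

cli = LightningCLI(
    MyModel,
    MyDataModule,
    run=False,  # parse args and instantiate classes, but do not fit
    parser_kwargs={"default_config_files": ["my_cli_defaults.yaml"]},
)
print(cli.config)  # the parsed configuration is already available here
cli.trainer.fit(cli.model, datamodule=cli.datamodule)  # start training manually
```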
12 changes: 8 additions & 4 deletions pytorch_lightning/utilities/cli.py
@@ -211,6 +211,7 @@ def __init__(
         parser_kwargs: Optional[Dict[str, Any]] = None,
         subclass_mode_model: bool = False,
         subclass_mode_data: bool = False,
+        run: bool = True,
     ) -> None:
         """
         Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which are
@@ -258,6 +259,8 @@ def __init__(
             subclass_mode_data: Whether datamodule can be any `subclass
                 <https://jsonargparse.readthedocs.io/en/stable/#class-type-and-sub-classes>`_
                 of the given class.
+            run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer`
+                method. If set to ``False``, the trainer and model classes will only be instantiated.
         """
         self.model_class = model_class
         self.datamodule_class = datamodule_class
@@ -284,10 +287,11 @@ def __init__(
         self.instantiate_classes()
         self.add_configure_optimizers_method_to_model()
 
-        self.prepare_fit_kwargs()
-        self.before_fit()
-        self.fit()
-        self.after_fit()
+        if run:
+            self.prepare_fit_kwargs()
+            self.before_fit()
+            self.fit()
+            self.after_fit()
 
     def init_parser(self, **kwargs: Any) -> LightningArgumentParser:
         """Method that instantiates the argument parser."""
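One consequence of the gating above: with `run=False` the `before_fit`/`after_fit` hooks are skipped together with `fit` itself, so a subclass relying on them has to invoke them explicitly. A hedged sketch (`MyCLI` and its print call are illustrative, not part of the library; `MyModel` is again a placeholder `LightningModule`):

```python
from pytorch_lightning.utilities.cli import LightningCLI

class MyCLI(LightningCLI):
    def before_fit(self):
        # hypothetical custom logic that would normally run just before fit()
        print("starting fit for", type(self.model).__name__)

cli = MyCLI(MyModel, run=False)
cli.before_fit()            # run=False skipped the hook, so call it manually
cli.trainer.fit(cli.model)  # likewise, fitting is now the caller's job
```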
9 changes: 9 additions & 0 deletions tests/utilities/test_cli.py
@@ -687,3 +687,12 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict):
     assert isinstance(cli.model.optim1, torch.optim.Adam)
     assert isinstance(cli.model.optim2, torch.optim.SGD)
     assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR)
+
+
+@pytest.mark.parametrize("run", (False, True))
+def test_lightning_cli_disabled_run(run):
+    with mock.patch("sys.argv", ["any.py"]), mock.patch("pytorch_lightning.Trainer.fit") as fit_mock:
+        cli = LightningCLI(BoringModel, run=run)
+    assert fit_mock.call_count == run
+    assert isinstance(cli.trainer, Trainer)
+    assert isinstance(cli.model, LightningModule)
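
For a quick manual check of the new flag outside the test suite, something along these lines should work (a sketch; the `BoringModel` import path is an assumption based on this repository's test layout):

```python
from unittest import mock

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.utilities.cli import LightningCLI
from tests.helpers.boring_model import BoringModel  # assumed helper location

with mock.patch("sys.argv", ["any.py"]):  # keep real CLI args out of the parser
    cli = LightningCLI(BoringModel, run=False, trainer_defaults={"max_epochs": 1})

assert isinstance(cli.trainer, Trainer)
assert isinstance(cli.model, LightningModule)
cli.trainer.fit(cli.model)  # with run=False, nothing has been fitted yet
```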