diff --git a/CHANGELOG.md b/CHANGELOG.md
index cb05805de3c16..3055b15011a2f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added `log_graph` argument for `watch` method of `WandbLogger` ([#8662](https://github.com/PyTorchLightning/pytorch-lightning/pull/8662))

+- Added `LightningCLI(run=False|True)` to choose whether to run a `Trainer` subcommand ([#8751](https://github.com/PyTorchLightning/pytorch-lightning/pull/8751))
+
+
 - Fault-tolerant training:
     * Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366))

diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst
index 61153c02f3113..a8d65ce0ec853 100644
--- a/docs/source/common/lightning_cli.rst
+++ b/docs/source/common/lightning_cli.rst
@@ -4,12 +4,13 @@

     import torch
     from unittest import mock
     from typing import List
-    from pytorch_lightning.core.lightning import LightningModule
-    from pytorch_lightning.core.datamodule import LightningDataModule
+    from pytorch_lightning import LightningModule, LightningDataModule, Trainer
     from pytorch_lightning.utilities.cli import LightningCLI

-    original_fit = LightningCLI.fit
-    LightningCLI.fit = lambda self: None
+    cli_fit = LightningCLI.fit
+    LightningCLI.fit = lambda *_, **__: None
+    trainer_fit = Trainer.fit
+    Trainer.fit = lambda *_, **__: None

     class MyModel(LightningModule):
@@ -47,7 +48,8 @@

 .. testcleanup:: *

-    LightningCLI.fit = original_fit
+    LightningCLI.fit = cli_fit
+    Trainer.fit = trainer_fit
     mock_argv.stop()

@@ -260,17 +262,30 @@ file. Loading a defaults file :code:`my_cli_defaults.yaml` in the current working

 .. testcode::

-    cli = LightningCLI(
-        MyModel,
-        MyDataModule,
-        parser_kwargs={"default_config_files": ["my_cli_defaults.yaml"]},
-    )
+    cli = LightningCLI(MyModel, MyDataModule, parser_kwargs={"default_config_files": ["my_cli_defaults.yaml"]})

 To load a file in the user's home directory would be just changing to :code:`~/.my_cli_defaults.yaml`. Note that
 this setting is given through :code:`parser_kwargs`. More parameters are supported. For details see the `ArgumentParser
 API `_ documentation.

+Instantiation only mode
+^^^^^^^^^^^^^^^^^^^^^^^
+
+The CLI is designed to start fitting with minimal code changes. On class instantiation, the CLI will automatically
+call ``trainer.fit(...)`` internally so you don't have to do it. To avoid this, you can set the following argument:
+
+.. testcode::
+
+    cli = LightningCLI(MyModel, run=False)  # True by default
+    # you'll have to call fit yourself:
+    cli.trainer.fit(cli.model)
+
+
+This can be useful to implement custom logic without having to subclass the CLI, while still using the CLI's
+instantiation and argument parsing capabilities.
+
+
 Trainer Callbacks and arguments with class type
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
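The new "Instantiation only mode" docs section above shows only the minimal call. Below is a self-contained sketch of how a script might use ``run=False``, assuming this PR is applied; ``DemoModel`` and its ``hidden_dim`` argument are hypothetical stand-ins rather than names from the diff:

```python
# Sketch of LightningCLI(run=False): parse args and instantiate everything,
# but leave calling trainer.fit() to the script. DemoModel is a hypothetical
# stand-in model, not part of this PR.
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.cli import LightningCLI


class DemoModel(LightningModule):
    def __init__(self, hidden_dim: int = 4):
        super().__init__()
        self.layer = torch.nn.Linear(8, hidden_dim)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(16, 8)), batch_size=4)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


if __name__ == "__main__":
    # run=False: the CLI stops after instantiating the trainer and model,
    # so custom logic can run here before (or instead of) fitting.
    cli = LightningCLI(DemoModel, run=False)
    cli.trainer.fit(cli.model)
```

Invoked as, e.g., ``python demo.py --trainer.fast_dev_run=true --model.hidden_dim=8``: argument parsing and class instantiation still happen with ``run=False``; only the fitting step is left to the script.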
diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py
index 752d0092d1baf..6cbb12340dc54 100644
--- a/pytorch_lightning/utilities/cli.py
+++ b/pytorch_lightning/utilities/cli.py
@@ -211,6 +211,7 @@ def __init__(
         parser_kwargs: Optional[Dict[str, Any]] = None,
         subclass_mode_model: bool = False,
         subclass_mode_data: bool = False,
+        run: bool = True,
     ) -> None:
         """
         Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which are
@@ -258,6 +259,8 @@ def __init__(
             subclass_mode_data: Whether datamodule can be any `subclass `_
                 of the given class.
+            run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer`
+                method. If set to ``False``, the trainer and model classes will only be instantiated.
         """
         self.model_class = model_class
         self.datamodule_class = datamodule_class
@@ -284,10 +287,11 @@ def __init__(
         self.instantiate_classes()
         self.add_configure_optimizers_method_to_model()

-        self.prepare_fit_kwargs()
-        self.before_fit()
-        self.fit()
-        self.after_fit()
+        if run:
+            self.prepare_fit_kwargs()
+            self.before_fit()
+            self.fit()
+            self.after_fit()

     def init_parser(self, **kwargs: Any) -> LightningArgumentParser:
         """Method that instantiates the argument parser."""
diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py
index cc636aa9a17ed..e1d8bda010e88 100644
--- a/tests/utilities/test_cli.py
+++ b/tests/utilities/test_cli.py
@@ -687,3 +687,12 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict):
     assert isinstance(cli.model.optim1, torch.optim.Adam)
     assert isinstance(cli.model.optim2, torch.optim.SGD)
     assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR)
+
+
+@pytest.mark.parametrize("run", (False, True))
+def test_lightning_cli_disabled_run(run):
+    with mock.patch("sys.argv", ["any.py"]), mock.patch("pytorch_lightning.Trainer.fit") as fit_mock:
+        cli = LightningCLI(BoringModel, run=run)
+        assert fit_mock.call_count == run
+        assert isinstance(cli.trainer, Trainer)
+        assert isinstance(cli.model, LightningModule)
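On the implementation side, the ``if run:`` gate above skips four hooks rather than changing them, so a caller who wants the pre-existing behavior can still drive them by hand. The sketch below is an assumption-laden illustration: it takes the hook names from the hunk above and imports ``BoringModel`` via the path the repo's test suite appears to use; treat both as internals that may change:

```python
# Hedged sketch: manually driving the hooks that `if run:` now gates.
# Hook names come from the LightningCLI.__init__ hunk above; the
# BoringModel import path mirrors the test suite (an assumption).
from unittest import mock

from pytorch_lightning.utilities.cli import LightningCLI
from tests.helpers.boring_model import BoringModel

with mock.patch("sys.argv", ["any.py", "--trainer.fast_dev_run=true"]):
    cli = LightningCLI(BoringModel, run=False)

# Roughly equivalent to what __init__ does when run=True:
cli.prepare_fit_kwargs()
cli.before_fit()
cli.fit()  # internally starts trainer fitting with the prepared kwargs
cli.after_fit()
```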