From 05b36c0b7b82d7143df664f5b70c596fbc628046 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 19 Oct 2021 04:13:31 +0200 Subject: [PATCH 1/4] [CLI] Shorthand notation to instantiate datamodules --- CHANGELOG.md | 2 + docs/source/common/lightning_cli.rst | 13 ++++-- pytorch_lightning/utilities/cli.py | 25 +++++++++--- tests/utilities/test_cli.py | 60 ++++++++++++++++++++++++++++ 4 files changed, 90 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 888d22a520f75..662374429eaf4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -64,6 +64,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Allow registering custom optimizers and learning rate schedulers without subclassing the CLI ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565)) * Support shorthand notation to instantiate optimizers and learning rate schedulers ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565)) * Support passing lists of callbacks via command line ([#8815](https://github.com/PyTorchLightning/pytorch-lightning/pull/8815)) + * Support shorthand notation to instantiate models ([#9588](https://github.com/PyTorchLightning/pytorch-lightning/pull/9588)) + * Support shorthand notation to instantiate datamodules ([#10004](https://github.com/PyTorchLightning/pytorch-lightning/pull/10004)) - Fault-tolerant training: diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index 8977f8274eb6f..ef8a3109d0742 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -436,13 +436,13 @@ In the previous examples :class:`~pytorch_lightning.utilities.cli.LightningCLI` datamodule class. However, there are many cases in which the objective is to easily be able to run many experiments for multiple models and datasets. -The model argument can be left unset if a model has been registered first, this is particularly interesting for library -authors who want to provide their users a range of models to choose from: +The model and datamodule arguments can be left unset if a class has been registered first, +this is particularly interesting for library authors who want to provide their users a range of models to choose from: .. code-block:: python import flash.image - from pytorch_lightning.utilities.cli import MODEL_REGISTRY + from pytorch_lightning.utilities.cli import MODEL_REGISTRY, DATAMODULE_REGISTRY @MODEL_REGISTRY @@ -450,6 +450,11 @@ authors who want to provide their users a range of models to choose from: ... + @DATAMODULE_REGISTRY + class MyData(LightningDataModule): + ... + + # register all `LightningModule` subclasses from a package MODEL_REGISTRY.register_classes(flash.image, LightningModule) # print(MODEL_REGISTRY) @@ -459,7 +464,7 @@ authors who want to provide their users a range of models to choose from: .. code-block:: bash - $ python trainer.py fit --model=MyModel --model.feat_dim=64 + $ python trainer.py fit --model=MyModel --model.feat_dim=64 --data=MyData .. note:: diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 8f16aabcb8aae..adfb146e0424f 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -91,6 +91,8 @@ def __str__(self) -> str: MODEL_REGISTRY = _Registry() +DATAMODULE_REGISTRY = _Registry() + class LightningArgumentParser(ArgumentParser): """Extension of jsonargparse's ArgumentParser for pytorch-lightning.""" @@ -129,6 +131,7 @@ def add_lightning_class_args( ], nested_key: str, subclass_mode: bool = False, + required: bool = True, ) -> List[str]: """Adds arguments from a lightning class to a nested key of the parser. @@ -136,6 +139,7 @@ def add_lightning_class_args( lightning_class: A callable or any subclass of {Trainer, LightningModule, LightningDataModule, Callback}. nested_key: Name of the nested namespace to store arguments. subclass_mode: Whether allow any subclass of the given class. + required: Whether the argument group is required. Returns: A list with the names of the class arguments added. @@ -149,7 +153,7 @@ def add_lightning_class_args( if issubclass(lightning_class, Callback): self.callback_keys.append(nested_key) if subclass_mode: - return self.add_subclass_arguments(lightning_class, nested_key, fail_untyped=False, required=True) + return self.add_subclass_arguments(lightning_class, nested_key, fail_untyped=False, required=required) return self.add_class_arguments( lightning_class, nested_key, fail_untyped=False, instantiate=not issubclass(lightning_class, Trainer) ) @@ -432,7 +436,7 @@ def __init__( called. If ``None``, you can pass a registered model with ``--model=MyModel``. datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class or a callable which returns a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` instance when - called. + called. If ``None``, you can pass a registered model with ``--datamodule=MyDataModule``. save_config_callback: A callback class to save the training config. save_config_filename: Filename for the config file. save_config_overwrite: Whether to overwrite an existing config file. @@ -455,7 +459,6 @@ def __init__( run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer` method. If set to ``False``, the trainer and model classes will be instantiated only. """ - self.datamodule_class = datamodule_class self.save_config_callback = save_config_callback self.save_config_filename = save_config_filename self.save_config_overwrite = save_config_overwrite @@ -463,13 +466,17 @@ def __init__( self.trainer_class = trainer_class self.trainer_defaults = trainer_defaults or {} self.seed_everything_default = seed_everything_default - self.subclass_mode_data = subclass_mode_data self.model_class = model_class # used to differentiate between the original value and the processed value self._model_class = model_class or LightningModule self.subclass_mode_model = (model_class is None) or subclass_mode_model + self.datamodule_class = datamodule_class + # used to differentiate between the original value and the processed value + self._datamodule_class = datamodule_class or LightningDataModule + self.subclass_mode_data = (datamodule_class is None) or subclass_mode_data + main_kwargs, subparser_kwargs = self._setup_parser_kwargs( parser_kwargs or {}, # type: ignore # github.com/python/mypy/issues/6463 {"description": description, "env_prefix": env_prefix, "default_env": env_parse}, @@ -531,12 +538,18 @@ def add_core_arguments_to_parser(self, parser: LightningArgumentParser) -> None: parser.set_defaults(trainer_defaults) parser.add_lightning_class_args(self._model_class, "model", subclass_mode=self.subclass_mode_model) - if self.model_class is None and MODEL_REGISTRY: + if self.model_class is None and len(MODEL_REGISTRY): # did not pass a model and there are models registered parser.set_choices("model", MODEL_REGISTRY.classes) if self.datamodule_class is not None: - parser.add_lightning_class_args(self.datamodule_class, "data", subclass_mode=self.subclass_mode_data) + parser.add_lightning_class_args(self._datamodule_class, "data", subclass_mode=self.subclass_mode_data) + elif len(DATAMODULE_REGISTRY): + # this should not be required because the user might want to use the `LightningModule` dataloaders + parser.add_lightning_class_args( + self._datamodule_class, "data", subclass_mode=self.subclass_mode_data, required=False + ) + parser.set_choices("data", DATAMODULE_REGISTRY.classes) def _add_arguments(self, parser: LightningArgumentParser) -> None: # default + core + custom arguments diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index b7c443fb24d4a..e26f38356a84e 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -36,6 +36,7 @@ from pytorch_lightning.utilities import _TPU_AVAILABLE from pytorch_lightning.utilities.cli import ( CALLBACK_REGISTRY, + DATAMODULE_REGISTRY, instantiate_class, LightningArgumentParser, LightningCLI, @@ -915,6 +916,65 @@ def test_lightning_cli_model_choices(): assert cli.model.bar == 5 +@DATAMODULE_REGISTRY +class MyDataModule(BoringDataModule): + def __init__(self, foo, bar=5): + super().__init__() + self.foo = foo + self.bar = bar + + +DATAMODULE_REGISTRY(cls=BoringDataModule) + + +def test_lightning_cli_datamodule_choices(): + # with set model + with mock.patch("sys.argv", ["any.py", "fit", "--data=BoringDataModule"]), mock.patch( + "pytorch_lightning.Trainer._fit_impl" + ) as run: + cli = LightningCLI(BoringModel, trainer_defaults={"fast_dev_run": 1}) + assert isinstance(cli.datamodule, BoringDataModule) + run.assert_called_once_with(ANY, ANY, ANY, cli.datamodule) + + with mock.patch("sys.argv", ["any.py", "--data=MyDataModule", "--data.foo", "123"]): + cli = LightningCLI(BoringModel, run=False) + assert isinstance(cli.datamodule, MyDataModule) + assert cli.datamodule.foo == 123 + assert cli.datamodule.bar == 5 + + # with configurable model + with mock.patch("sys.argv", ["any.py", "fit", "--model", "BoringModel", "--data=BoringDataModule"]), mock.patch( + "pytorch_lightning.Trainer._fit_impl" + ) as run: + cli = LightningCLI(trainer_defaults={"fast_dev_run": 1}) + assert isinstance(cli.model, BoringModel) + assert isinstance(cli.datamodule, BoringDataModule) + run.assert_called_once_with(cli.model, ANY, ANY, cli.datamodule) + + with mock.patch("sys.argv", ["any.py", "--model", "BoringModel", "--data=MyDataModule"]): + cli = LightningCLI(run=False) + assert isinstance(cli.model, BoringModel) + assert isinstance(cli.datamodule, MyDataModule) + + assert len(DATAMODULE_REGISTRY) # needs a value initially added + with mock.patch("sys.argv", ["any.py"]): + cli = LightningCLI(BoringModel, run=False) + # data was not passed but we are adding it automatically because there are datamodules registered + assert "data" in cli.parser.groups + assert not hasattr(cli.parser.groups["data"], "group_class") + + with mock.patch("sys.argv", ["any.py"]), mock.patch.dict(DATAMODULE_REGISTRY, clear=True): + cli = LightningCLI(BoringModel, run=False) + # no registered classes so not added automatically + assert "data" not in cli.parser.groups + assert len(DATAMODULE_REGISTRY) # check state was not modified + + with mock.patch("sys.argv", ["any.py"]): + cli = LightningCLI(BoringModel, BoringDataModule, run=False) + # since we are passing the DataModule, that's whats added to the parser + assert cli.parser.groups["data"].group_class is BoringDataModule + + @pytest.mark.parametrize("use_class_path_callbacks", [False, True]) def test_registries_resolution(use_class_path_callbacks): """This test validates registries are used when simplified command line are being used.""" From 105686bf63fb26dae2ca26851fa3e03d292fbacc Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 19 Oct 2021 04:15:01 +0200 Subject: [PATCH 2/4] Fix CHANGELOG num --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 662374429eaf4..1a7a9ce1a2952 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -65,7 +65,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Support shorthand notation to instantiate optimizers and learning rate schedulers ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565)) * Support passing lists of callbacks via command line ([#8815](https://github.com/PyTorchLightning/pytorch-lightning/pull/8815)) * Support shorthand notation to instantiate models ([#9588](https://github.com/PyTorchLightning/pytorch-lightning/pull/9588)) - * Support shorthand notation to instantiate datamodules ([#10004](https://github.com/PyTorchLightning/pytorch-lightning/pull/10004)) + * Support shorthand notation to instantiate datamodules ([#10011](https://github.com/PyTorchLightning/pytorch-lightning/pull/10011)) - Fault-tolerant training: From 11bb89c0178c81afe84cbb5d94af437d3436e701 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 19 Oct 2021 04:16:50 +0200 Subject: [PATCH 3/4] Docs --- docs/source/common/lightning_cli.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst index ef8a3109d0742..2f7b2bae599e4 100644 --- a/docs/source/common/lightning_cli.rst +++ b/docs/source/common/lightning_cli.rst @@ -436,8 +436,8 @@ In the previous examples :class:`~pytorch_lightning.utilities.cli.LightningCLI` datamodule class. However, there are many cases in which the objective is to easily be able to run many experiments for multiple models and datasets. -The model and datamodule arguments can be left unset if a class has been registered first, -this is particularly interesting for library authors who want to provide their users a range of models to choose from: +The model and datamodule arguments can be left unset if a class has been registered first. +This is particularly interesting for library authors who want to provide their users a range of models to choose from: .. code-block:: python From c4c619d387170915749c3dc82d55054bcdf28130 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Tue, 19 Oct 2021 15:05:53 +0200 Subject: [PATCH 4/4] Update pytorch_lightning/utilities/cli.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pytorch_lightning/utilities/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index adfb146e0424f..b6c3b22d7bfb4 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -436,7 +436,7 @@ def __init__( called. If ``None``, you can pass a registered model with ``--model=MyModel``. datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class or a callable which returns a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` instance when - called. If ``None``, you can pass a registered model with ``--datamodule=MyDataModule``. + called. If ``None``, you can pass a registered datamodule with ``--data=MyDataModule``. save_config_callback: A callback class to save the training config. save_config_filename: Filename for the config file. save_config_overwrite: Whether to overwrite an existing config file.