diff --git a/CHANGELOG.md b/CHANGELOG.md
index a83ef6a55d515..463a5f41c1e70 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -238,6 +238,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed issue where the CLI fails with certain torch objects ([#13153](https://github.com/PyTorchLightning/pytorch-lightning/pull/13153))
 
+- Fixed ``LightningCLI`` signature parameter resolving for some lightning classes ([#13283](https://github.com/PyTorchLightning/pytorch-lightning/pull/13283))
+
+
 - Fixed `estimated_stepping_batches` requiring distributed comms in `configure_optimizers` for the `DeepSpeedStrategy` ([#13350](https://github.com/PyTorchLightning/pytorch-lightning/pull/13350))
 
 
diff --git a/requirements/pytorch/extra.txt b/requirements/pytorch/extra.txt
index e7d54903c2932..9da03791d6429 100644
--- a/requirements/pytorch/extra.txt
+++ b/requirements/pytorch/extra.txt
@@ -3,6 +3,6 @@ matplotlib>3.1, <3.5.3
 torchtext>=0.10.*, <=0.12.0
 omegaconf>=2.0.5, <=2.1.*
 hydra-core>=1.0.5, <=1.1.*
-jsonargparse[signatures]>=4.9.0, <=4.9.0
+jsonargparse[signatures]>=4.10.0, <=4.10.0
 gcsfs>=2021.5.0, <=2022.2.0
 rich>=10.2.2, !=10.15.0.a, <13.0.0
diff --git a/src/pytorch_lightning/utilities/cli.py b/src/pytorch_lightning/utilities/cli.py
index d9386c4dd7d25..a5d5fe3b66960 100644
--- a/src/pytorch_lightning/utilities/cli.py
+++ b/src/pytorch_lightning/utilities/cli.py
@@ -31,7 +31,7 @@
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_deprecation, rank_zero_warn
 
-_JSONARGPARSE_SIGNATURES_AVAILABLE = _RequirementAvailable("jsonargparse[signatures]>=4.9.0")
+_JSONARGPARSE_SIGNATURES_AVAILABLE = _RequirementAvailable("jsonargparse[signatures]>=4.10.0")
 
 if _JSONARGPARSE_SIGNATURES_AVAILABLE:
     import docstring_parser
diff --git a/tests/tests_pytorch/utilities/test_cli.py b/tests/tests_pytorch/utilities/test_cli.py
index 1dfab764842a4..f499acf41e6d0 100644
--- a/tests/tests_pytorch/utilities/test_cli.py
+++ b/tests/tests_pytorch/utilities/test_cli.py
@@ -34,11 +34,14 @@
 from pytorch_lightning import __version__, Callback, LightningDataModule, LightningModule, seed_everything, Trainer
 from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
 from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel
-from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.loggers import _COMET_AVAILABLE, _NEPTUNE_AVAILABLE, _WANDB_AVAILABLE, TensorBoardLogger
 from pytorch_lightning.plugins.environments import SLURMEnvironment
+from pytorch_lightning.profiler import PyTorchProfiler
+from pytorch_lightning.strategies import DDPStrategy
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TPU_AVAILABLE
 from pytorch_lightning.utilities.cli import (
+    _JSONARGPARSE_SIGNATURES_AVAILABLE,
     instantiate_class,
     LightningArgumentParser,
     LightningCLI,
@@ -54,6 +57,9 @@
 if _TORCHVISION_AVAILABLE:
     torchvision_version = version.parse(__import__("torchvision").__version__)
 
+if _JSONARGPARSE_SIGNATURES_AVAILABLE:
+    from jsonargparse import lazy_instance
+
 
 @contextmanager
 def mock_subclasses(baseclass, *subclasses):
@@ -1350,8 +1356,6 @@ def configure_optimizers(self, optimizer, lr_scheduler=None):
 
 
 def test_cli_parameter_with_lazy_instance_default():
-    from jsonargparse import lazy_instance
-
     class TestModel(BoringModel):
         def __init__(self, activation: torch.nn.Module = lazy_instance(torch.nn.LeakyReLU, negative_slope=0.05)):
             super().__init__()
@@ -1367,6 +1371,21 @@ def __init__(self, activation: torch.nn.Module = lazy_instance(torch.nn.LeakyReL
         assert cli.model.activation is not model.activation
 
 
+def test_ddpstrategy_instantiation_and_find_unused_parameters():
+    strategy_default = lazy_instance(DDPStrategy, find_unused_parameters=True)
+    with mock.patch("sys.argv", ["any.py", "--trainer.strategy.process_group_backend=group"]):
+        cli = LightningCLI(
+            BoringModel,
+            run=False,
+            trainer_defaults={"strategy": strategy_default},
+        )
+
+    assert cli.config.trainer.strategy.init_args.find_unused_parameters is True
+    assert isinstance(cli.config_init.trainer.strategy, DDPStrategy)
+    assert cli.config_init.trainer.strategy.process_group_backend == "group"
+    assert strategy_default is not cli.config_init.trainer.strategy
+
+
 def test_cli_logger_shorthand():
     with mock.patch("sys.argv", ["any.py"]):
         cli = LightningCLI(TestModel, run=False, trainer_defaults={"logger": False})
@@ -1381,6 +1400,41 @@ def test_cli_logger_shorthand():
     assert cli.trainer.logger is None
 
 
+def _test_logger_init_args(logger_name, init, unresolved={}):
+    cli_args = [f"--trainer.logger={logger_name}"]
+    cli_args += [f"--trainer.logger.{k}={v}" for k, v in init.items()]
+    cli_args += [f"--trainer.logger.dict_kwargs.{k}={v}" for k, v in unresolved.items()]
+    cli_args.append("--print_config")
+
+    out = StringIO()
+    with mock.patch("sys.argv", ["any.py"] + cli_args), redirect_stdout(out), pytest.raises(SystemExit):
+        LightningCLI(TestModel, run=False)
+
+    data = yaml.safe_load(out.getvalue())["trainer"]["logger"]
+    assert {k: data["init_args"][k] for k in init} == init
+    if unresolved:
+        assert data["dict_kwargs"] == unresolved
+
+
+@pytest.mark.skipif(not _COMET_AVAILABLE, reason="comet-ml is required")
+def test_comet_logger_init_args():
+    _test_logger_init_args("CometLogger", {"save_dir": "comet", "workspace": "comet"})
+
+
+@pytest.mark.skipif(not _NEPTUNE_AVAILABLE, reason="neptune-client is required")
+def test_neptune_logger_init_args():
+    _test_logger_init_args("NeptuneLogger", {"name": "neptune"}, {"description": "neptune"})
+
+
+def test_tensorboard_logger_init_args():
+    _test_logger_init_args("TensorBoardLogger", {"save_dir": "tb", "name": "tb"})
+
+
+@pytest.mark.skipif(not _WANDB_AVAILABLE, reason="wandb is required")
+def test_wandb_logger_init_args():
+    _test_logger_init_args("WandbLogger", {"save_dir": "wandb", "notes": "wandb"})
+
+
 def test_cli_auto_seeding():
     with mock.patch("sys.argv", ["any.py"]):
         cli = LightningCLI(TestModel, run=False, seed_everything_default=False)
@@ -1440,3 +1494,25 @@ def __init__(self, a_func: Callable = torch.softmax):
         LightningCLI(TestModel, run=False)
 
     assert "a_func: torch.softmax" in out.getvalue()
+
+
+def test_pytorch_profiler_init_args():
+    init = {
+        "dirpath": "profiler",
+        "row_limit": 10,
+        "group_by_input_shapes": True,
+    }
+    unresolved = {
+        "profile_memory": True,
+        "record_shapes": True,
+    }
+    cli_args = ["--trainer.profiler=PyTorchProfiler"]
+    cli_args += [f"--trainer.profiler.{k}={v}" for k, v in init.items()]
+    cli_args += [f"--trainer.profiler.dict_kwargs.{k}={v}" for k, v in unresolved.items()]
+
+    with mock.patch("sys.argv", ["any.py"] + cli_args):
+        cli = LightningCLI(TestModel, run=False)
+
+    assert isinstance(cli.config_init.trainer.profiler, PyTorchProfiler)
+    assert {k: cli.config.trainer.profiler.init_args[k] for k in init} == init
+    assert cli.config.trainer.profiler.dict_kwargs == unresolved
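
For context, a minimal stand-alone sketch of the behavior the new test_ddpstrategy_instantiation_and_find_unused_parameters test exercises. This is not part of the diff: the script name and the "gloo" backend value are illustrative, and it assumes this branch with jsonargparse[signatures]>=4.10.0 installed.

# cli_sketch.py -- hypothetical example mirroring the new test above.
from jsonargparse import lazy_instance

from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.utilities.cli import LightningCLI

if __name__ == "__main__":
    # lazy_instance defers DDPStrategy construction, so with the improved
    # signature resolution individual init args can still be overridden
    # from the command line, e.g.:
    #     python cli_sketch.py --trainer.strategy.process_group_backend=gloo
    cli = LightningCLI(
        BoringModel,
        run=False,
        trainer_defaults={"strategy": lazy_instance(DDPStrategy, find_unused_parameters=True)},
    )
    print(type(cli.config_init.trainer.strategy))  # expected: DDPStrategy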