diff --git a/CHANGELOG.md b/CHANGELOG.md index f984b542b7c01..ccb85862152f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -244,6 +244,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed issue where the CLI could not pass a `Profiler` to the `Trainer` ([#13084](https://github.com/PyTorchLightning/pytorch-lightning/pull/13084)) +- Fixed issue where the CLI fails with certain torch objects ([#13153](https://github.com/PyTorchLightning/pytorch-lightning/pull/13153)) + + - Fixed logging's step values when multiple dataloaders are used during evaluation ([#12184](https://github.com/PyTorchLightning/pytorch-lightning/pull/12184)) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index be0771ee40efb..ee4b769634b2e 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -34,14 +34,22 @@ from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_deprecation, rank_zero_warn -_JSONARGPARSE_SIGNATURES_AVAILABLE = _RequirementAvailable("jsonargparse[signatures]>=4.7.1") +_JSONARGPARSE_SIGNATURES_AVAILABLE = _RequirementAvailable("jsonargparse[signatures]>=4.8.0") if _JSONARGPARSE_SIGNATURES_AVAILABLE: import docstring_parser - from jsonargparse import ActionConfigFile, ArgumentParser, class_from_function, Namespace, set_config_read_mode + from jsonargparse import ( + ActionConfigFile, + ArgumentParser, + class_from_function, + Namespace, + register_unresolvable_import_paths, + set_config_read_mode, + ) from jsonargparse.typehints import get_all_subclass_paths from jsonargparse.util import import_object + register_unresolvable_import_paths(torch) # Required until fix https://github.com/pytorch/pytorch/issues/74483 set_config_read_mode(fsspec_enabled=True) else: locals()["ArgumentParser"] = object diff --git a/requirements/extra.txt b/requirements/extra.txt index cef58c6c21221..c8038c841f7a2 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -4,6 +4,6 @@ matplotlib>3.1, <3.5.3 torchtext>=0.9.*, <=0.12.0 omegaconf>=2.0.5, <=2.1.* hydra-core>=1.0.5, <=1.1.* -jsonargparse[signatures]>=4.7.1, <4.7.4 +jsonargparse[signatures]>=4.8.0, <=4.8.0 gcsfs>=2021.5.0, <=2022.2.0 rich>=10.2.2,!=10.15.*, <=12.0.0 diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index ab8fa6e7ed93b..8f136229a9c14 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -20,7 +20,7 @@ from argparse import Namespace from contextlib import contextmanager, ExitStack, redirect_stdout from io import StringIO -from typing import List, Optional, Union +from typing import Callable, List, Optional, Union from unittest import mock from unittest.mock import ANY @@ -1561,3 +1561,16 @@ def test_cli_auto_seeding(): cli = LightningCLI(TestModel, run=False) assert cli.seed_everything_default is True assert cli.config["seed_everything"] == 123 # the original seed is kept + + +def test_unresolvable_import_paths(): + class TestModel(BoringModel): + def __init__(self, a_func: Callable = torch.softmax): + super().__init__() + self.a_func = a_func + + out = StringIO() + with mock.patch("sys.argv", ["any.py", "--print_config"]), redirect_stdout(out), pytest.raises(SystemExit): + LightningCLI(TestModel, run=False) + + assert "a_func: torch.softmax" in out.getvalue()