CHANGELOG.md (3 additions, 0 deletions)
@@ -238,6 +238,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed issue where the CLI fails with certain torch objects ([#13153](https://github.com/PyTorchLightning/pytorch-lightning/pull/13153))


- Fixed `LightningCLI` signature parameter resolving for some lightning classes ([#13283](https://github.com/PyTorchLightning/pytorch-lightning/pull/13283))


- Fixed `estimated_stepping_batches` requiring distributed comms in `configure_optimizers` for the `DeepSpeedStrategy` ([#13350](https://github.com/PyTorchLightning/pytorch-lightning/pull/13350))


requirements/pytorch/extra.txt (1 addition, 1 deletion)
@@ -3,6 +3,6 @@
matplotlib>3.1, <3.5.3
torchtext>=0.10.*, <=0.12.0
omegaconf>=2.0.5, <=2.1.*
hydra-core>=1.0.5, <=1.1.*
-jsonargparse[signatures]>=4.9.0, <=4.9.0
+jsonargparse[signatures]>=4.10.0, <=4.10.0
gcsfs>=2021.5.0, <=2022.2.0
rich>=10.2.2, !=10.15.0.a, <13.0.0
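
The bounds `>=4.10.0, <=4.10.0` pin jsonargparse to exactly 4.10.0, matching the requirement string checked in `cli.py` below. A minimal sanity check of an environment against that pin, as an illustrative sketch (not part of this PR):

```python
# Illustrative only: assert that the exact jsonargparse pin is installed.
from importlib.metadata import version  # available on Python 3.8+

assert version("jsonargparse") == "4.10.0"
```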
src/pytorch_lightning/utilities/cli.py (1 addition, 1 deletion)
@@ -31,7 +31,7 @@
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_deprecation, rank_zero_warn

_JSONARGPARSE_SIGNATURES_AVAILABLE = _RequirementAvailable("jsonargparse[signatures]>=4.9.0")
_JSONARGPARSE_SIGNATURES_AVAILABLE = _RequirementAvailable("jsonargparse[signatures]>=4.10.0")

if _JSONARGPARSE_SIGNATURES_AVAILABLE:
    import docstring_parser
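
`_RequirementAvailable` evaluates the requirement string against the installed distribution before the optional import is attempted. A rough sketch of the same idea, assuming `pkg_resources` semantics (this is not Lightning's actual implementation, which differs in details):

```python
# Sketch of a requirement-availability guard; not Lightning's implementation.
import pkg_resources

def requirement_available(requirement: str) -> bool:
    try:
        # require() raises if the distribution is missing or its installed
        # version does not satisfy the specifier.
        pkg_resources.require(requirement)
        return True
    except (pkg_resources.DistributionNotFound, pkg_resources.VersionConflict):
        return False

if requirement_available("jsonargparse[signatures]>=4.10.0"):
    import docstring_parser  # noqa: F401
```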
tests/tests_pytorch/utilities/test_cli.py (79 additions, 3 deletions)
@@ -34,11 +34,14 @@
from pytorch_lightning import __version__, Callback, LightningDataModule, LightningModule, seed_everything, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel
-from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.loggers import _COMET_AVAILABLE, _NEPTUNE_AVAILABLE, _WANDB_AVAILABLE, TensorBoardLogger
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.profiler import PyTorchProfiler
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.cli import (
    _JSONARGPARSE_SIGNATURES_AVAILABLE,
    instantiate_class,
    LightningArgumentParser,
    LightningCLI,
@@ -54,6 +57,9 @@
if _TORCHVISION_AVAILABLE:
    torchvision_version = version.parse(__import__("torchvision").__version__)

if _JSONARGPARSE_SIGNATURES_AVAILABLE:
    from jsonargparse import lazy_instance


@contextmanager
def mock_subclasses(baseclass, *subclasses):
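
`lazy_instance` lets a class-typed parameter carry an instance default that jsonargparse can still serialize, override from the command line, and re-instantiate. A minimal sketch of that behavior outside Lightning, assuming the jsonargparse 4.10 API:

```python
# Sketch: a class-typed default via lazy_instance (assumes jsonargparse>=4.10).
import torch
from jsonargparse import ArgumentParser, lazy_instance

parser = ArgumentParser()
parser.add_subclass_arguments(torch.nn.Module, "activation")
parser.set_defaults({"activation": lazy_instance(torch.nn.LeakyReLU, negative_slope=0.05)})

cfg = parser.parse_args([])
init = parser.instantiate_classes(cfg)
assert isinstance(init.activation, torch.nn.LeakyReLU)
```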
@@ -1350,8 +1356,6 @@ def configure_optimizers(self, optimizer, lr_scheduler=None):


def test_cli_parameter_with_lazy_instance_default():
-    from jsonargparse import lazy_instance

    class TestModel(BoringModel):
        def __init__(self, activation: torch.nn.Module = lazy_instance(torch.nn.LeakyReLU, negative_slope=0.05)):
            super().__init__()
@@ -1367,6 +1371,21 @@ def __init__(self, activation: torch.nn.Module = lazy_instance(torch.nn.LeakyReLU, negative_slope=0.05)):
    assert cli.model.activation is not model.activation


def test_ddpstrategy_instantiation_and_find_unused_parameters():
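    # A lazy_instance strategy default should merge with command-line overrides:
    # find_unused_parameters keeps its default while process_group_backend comes
    # from argv, and instantiation yields a fresh object rather than the default.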
    strategy_default = lazy_instance(DDPStrategy, find_unused_parameters=True)
    with mock.patch("sys.argv", ["any.py", "--trainer.strategy.process_group_backend=group"]):
        cli = LightningCLI(
            BoringModel,
            run=False,
            trainer_defaults={"strategy": strategy_default},
        )

    assert cli.config.trainer.strategy.init_args.find_unused_parameters is True
    assert isinstance(cli.config_init.trainer.strategy, DDPStrategy)
    assert cli.config_init.trainer.strategy.process_group_backend == "group"
    assert strategy_default is not cli.config_init.trainer.strategy


def test_cli_logger_shorthand():
    with mock.patch("sys.argv", ["any.py"]):
        cli = LightningCLI(TestModel, run=False, trainer_defaults={"logger": False})
@@ -1381,6 +1400,41 @@ def test_cli_logger_shorthand():
    assert cli.trainer.logger is None


def _test_logger_init_args(logger_name, init, unresolved=None):
    unresolved = unresolved or {}
    cli_args = [f"--trainer.logger={logger_name}"]
cli_args += [f"--trainer.logger.{k}={v}" for k, v in init.items()]
cli_args += [f"--trainer.logger.dict_kwargs.{k}={v}" for k, v in unresolved.items()]
    cli_args.append("--print_config")

    out = StringIO()
    with mock.patch("sys.argv", ["any.py"] + cli_args), redirect_stdout(out), pytest.raises(SystemExit):
        LightningCLI(TestModel, run=False)

    data = yaml.safe_load(out.getvalue())["trainer"]["logger"]
    assert {k: data["init_args"][k] for k in init} == init
    if unresolved:
        assert data["dict_kwargs"] == unresolved


@pytest.mark.skipif(not _COMET_AVAILABLE, reason="comet-ml is required")
def test_comet_logger_init_args():
_test_logger_init_args("CometLogger", {"save_dir": "comet", "workspace": "comet"})


@pytest.mark.skipif(not _NEPTUNE_AVAILABLE, reason="neptune-client is required")
def test_neptune_logger_init_args():
_test_logger_init_args("NeptuneLogger", {"name": "neptune"}, {"description": "neptune"})


def test_tensorboard_logger_init_args():
_test_logger_init_args("TensorBoardLogger", {"save_dir": "tb", "name": "tb"})


@pytest.mark.skipif(not _WANDB_AVAILABLE, reason="wandb is required")
def test_wandb_logger_init_args():
_test_logger_init_args("WandbLogger", {"save_dir": "wandb", "notes": "wandb"})


def test_cli_auto_seeding():
    with mock.patch("sys.argv", ["any.py"]):
        cli = LightningCLI(TestModel, run=False, seed_everything_default=False)
@@ -1440,3 +1494,25 @@ def __init__(self, a_func: Callable = torch.softmax):
        LightningCLI(TestModel, run=False)

    assert "a_func: torch.softmax" in out.getvalue()


def test_pytorch_profiler_init_args():
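    # dirpath, row_limit and group_by_input_shapes are in PyTorchProfiler's own
    # signature, so they resolve into init_args; profile_memory and record_shapes
    # are forwarded to torch's profiler through the **kwargs parameter and can
    # therefore only be passed via dict_kwargs.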
    init = {
        "dirpath": "profiler",
        "row_limit": 10,
        "group_by_input_shapes": True,
    }
    unresolved = {
        "profile_memory": True,
        "record_shapes": True,
    }
    cli_args = ["--trainer.profiler=PyTorchProfiler"]
    cli_args += [f"--trainer.profiler.{k}={v}" for k, v in init.items()]
    cli_args += [f"--trainer.profiler.dict_kwargs.{k}={v}" for k, v in unresolved.items()]

    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = LightningCLI(TestModel, run=False)

    assert isinstance(cli.config_init.trainer.profiler, PyTorchProfiler)
    assert {k: cli.config.trainer.profiler.init_args[k] for k in init} == init
    assert cli.config.trainer.profiler.dict_kwargs == unresolved