Restructure parsing flow in the LightningCLI (Lightning-AI#8721)
carmocca authored and four4fish committed Aug 16, 2021
1 parent 78bcadc commit 87cad8f
Showing 4 changed files with 91 additions and 69 deletions.
9 changes: 8 additions & 1 deletion CHANGELOG.md
@@ -42,7 +42,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- The `trainer.lightning_module` reference is now properly set at the very beginning of the run ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536))


- - Load ckpt path when model provided in validate/test/predict ([#8352](https://github.com/PyTorchLightning/pytorch-lightning/pull/8352)))
+ - Load ckpt path when model provided in validate/test/predict ([#8352](https://github.com/PyTorchLightning/pytorch-lightning/pull/8352))


- The `Trainer` functions `reset_{train,val,test,predict}_dataloader`, `reset_train_val_dataloaders`, and `request_dataloader` `model` argument is now optional ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536))
@@ -54,6 +54,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Improved string conversion for `ResultCollection` ([#8622](https://github.com/PyTorchLightning/pytorch-lightning/pull/8622))


+ - `LightningCLI` changes:
+   * `LightningCLI.init_parser` now returns the parser instance. ([#8721](https://github.com/PyTorchLightning/pytorch-lightning/pull/8721))
+   * `LightningCLI.add_core_arguments_to_parser`, `LightningCLI.parse_arguments` now take a `parser` argument. ([#8721](https://github.com/PyTorchLightning/pytorch-lightning/pull/8721))
+   * `LightningCLI.instantiate_trainer` now takes a config and a list of callbacks. ([#8721](https://github.com/PyTorchLightning/pytorch-lightning/pull/8721))
+   * Split `LightningCLI.add_core_arguments_to_parser` into `LightningCLI.add_default_arguments_to_parser` + `LightningCLI.add_core_arguments_to_parser`. ([#8721](https://github.com/PyTorchLightning/pytorch-lightning/pull/8721))


- The accelerator and training type plugin `setup` hooks no longer have a `model` argument ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536))

- Removed restrictions in the trainer that loggers can only log from rank 0. Existing logger behavior has not changed. ([#8608](https://github.com/PyTorchLightning/pytorch-lightning/pull/8608))
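As a reading aid for the `LightningCLI` changes listed above, here is a minimal sketch (not part of the commit) of how a subclass interacts with the restructured hooks; the `MyCLI` name and the `--my_flag` argument are purely illustrative:

```python
from typing import Any

from pytorch_lightning.utilities.cli import LightningArgumentParser, LightningCLI


class MyCLI(LightningCLI):
    def init_parser(self, **kwargs: Any) -> LightningArgumentParser:
        # `init_parser` now returns the parser instance instead of storing it on `self`
        parser = super().init_parser(**kwargs)
        parser.add_argument("--my_flag", type=bool, default=False)  # illustrative extra flag
        return parser

    def add_core_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        # the parser is now passed in explicitly; the default arguments live in the
        # new `add_default_arguments_to_parser`, which runs before this hook
        super().add_core_arguments_to_parser(parser)
```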
10 changes: 5 additions & 5 deletions pytorch_lightning/trainer/trainer.py
@@ -632,8 +632,8 @@ def test(
test_dataloaders=None, # TODO: remove with 1.6
) -> _EVALUATE_OUTPUT:
r"""
- Perform one evaluation epoch over the test set. It's separated from
- fit to make sure you never run on your test set until you want to.
+ Perform one evaluation epoch over the test set.
+ It's separated from fit to make sure you never run on your test set until you want to.
Args:
model: The model to test.
@@ -710,9 +710,9 @@ def predict(
ckpt_path: Optional[str] = None,
) -> Optional[_PREDICT_OUTPUT]:
r"""
- Separates from fit to make sure you never run on your predictions set until you want to.
- This will call the model forward function to compute predictions.
+ Run inference on your data.
+ This will call the model forward function to compute predictions. Useful to perform distributed
+ and batched predictions. Logging is disabled in the predict hooks.
Args:
model: The model to predict with.
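A hedged usage sketch for the reworded `test` / `predict` docstrings above; the tiny module and the random data are assumptions, not part of the commit:

```python
import torch
from torch.utils.data import DataLoader

from pytorch_lightning import LightningModule, Trainer


class TinyModel(LightningModule):
    """Illustrative stand-in model."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def test_step(self, batch, batch_idx):
        # logging works as usual in the test hooks
        self.log("test_loss", self(batch).sum())


loader = DataLoader(torch.randn(64, 32), batch_size=16)
trainer = Trainer(fast_dev_run=True)
# one evaluation epoch over the test set, kept separate from `fit` on purpose
trainer.test(TinyModel(), dataloaders=loader)
# `predict` runs the model's forward to compute predictions; logging is disabled here
predictions = trainer.predict(TinyModel(), dataloaders=loader)
```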
127 changes: 71 additions & 56 deletions pytorch_lightning/utilities/cli.py
@@ -13,22 +13,17 @@
# limitations under the License.
import inspect
import os
- import warnings
from argparse import Namespace
from types import MethodType
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union

from torch.optim import Optimizer

- from pytorch_lightning.callbacks import Callback
- from pytorch_lightning.core.datamodule import LightningDataModule
- from pytorch_lightning.core.lightning import LightningModule
- from pytorch_lightning.trainer.trainer import Trainer
+ from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
+ from pytorch_lightning.utilities import _JSONARGPARSE_AVAILABLE, warnings
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
- from pytorch_lightning.utilities.imports import _JSONARGPARSE_AVAILABLE
from pytorch_lightning.utilities.model_helpers import is_overridden
- from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.types import LRSchedulerType, LRSchedulerTypeTuple

if _JSONARGPARSE_AVAILABLE:
@@ -79,6 +74,9 @@ def add_lightning_class_args(
lightning_class: A callable or any subclass of {Trainer, LightningModule, LightningDataModule, Callback}.
nested_key: Name of the nested namespace to store arguments.
subclass_mode: Whether allow any subclass of the given class.
+ Returns:
+     A list with the names of the class arguments added.
"""
if callable(lightning_class) and not inspect.isclass(lightning_class):
lightning_class = class_from_function(lightning_class)
@@ -191,7 +189,7 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[st

def __reduce__(self) -> Tuple[Type["SaveConfigCallback"], Tuple, Dict]:
# `ArgumentParser` is un-pickleable. Drop it
- return (self.__class__, (None, self.config, self.config_filename), {})
+ return self.__class__, (None, self.config, self.config_filename), {}


class LightningCLI:
@@ -205,21 +203,22 @@ def __init__(
save_config_filename: str = "config.yaml",
save_config_overwrite: bool = False,
trainer_class: Union[Type[Trainer], Callable[..., Trainer]] = Trainer,
- trainer_defaults: Dict[str, Any] = None,
- seed_everything_default: int = None,
+ trainer_defaults: Optional[Dict[str, Any]] = None,
+ seed_everything_default: Optional[int] = None,
description: str = "pytorch-lightning trainer command line tool",
env_prefix: str = "PL",
env_parse: bool = False,
- parser_kwargs: Dict[str, Any] = None,
+ parser_kwargs: Optional[Dict[str, Any]] = None,
subclass_mode_model: bool = False,
subclass_mode_data: bool = False,
) -> None:
"""
Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which are
- called / instantiated using a parsed configuration file and / or command line args and then runs trainer.fit.
- Parsing of configuration from environment variables can be enabled by setting ``env_parse=True``. A full
- configuration yaml would be parsed from ``PL_CONFIG`` if set. Individual settings are so parsed from variables
- named for example ``PL_TRAINER__MAX_EPOCHS``.
+ called / instantiated using a parsed configuration file and / or command line args.
+ Parsing of configuration from environment variables can be enabled by setting ``env_parse=True``.
+ A full configuration yaml would be parsed from ``PL_CONFIG`` if set.
+ Individual settings are so parsed from variables named for example ``PL_TRAINER__MAX_EPOCHS``.
Example, first implement the ``trainer.py`` tool as::
@@ -266,56 +265,73 @@ def __init__(
self.save_config_filename = save_config_filename
self.save_config_overwrite = save_config_overwrite
self.trainer_class = trainer_class
- self.trainer_defaults = {} if trainer_defaults is None else trainer_defaults
+ self.trainer_defaults = trainer_defaults or {}
self.seed_everything_default = seed_everything_default
self.subclass_mode_model = subclass_mode_model
self.subclass_mode_data = subclass_mode_data
- self.parser_kwargs = {} if parser_kwargs is None else parser_kwargs
- self.parser_kwargs.update({"description": description, "env_prefix": env_prefix, "default_env": env_parse})

- self.init_parser()
- self.add_core_arguments_to_parser()
- self.add_arguments_to_parser(self.parser)
+ parser_kwargs = parser_kwargs or {}
+ parser_kwargs.update({"description": description, "env_prefix": env_prefix, "default_env": env_parse})
+ self.setup_parser(**parser_kwargs)
self.link_optimizers_and_lr_schedulers()
- self.parse_arguments()
- if self.config["seed_everything"] is not None:
- seed_everything(self.config["seed_everything"], workers=True)
+ self.parse_arguments(self.parser)

+ seed = self.config.get("seed_everything")
+ if seed is not None:
+ seed_everything(seed, workers=True)

self.before_instantiate_classes()
self.instantiate_classes()
self.add_configure_optimizers_method_to_model()

self.prepare_fit_kwargs()
self.before_fit()
self.fit()
self.after_fit()

- def init_parser(self) -> None:
- """Method that instantiates the argument parser"""
- self.parser = LightningArgumentParser(**self.parser_kwargs)
+ def init_parser(self, **kwargs: Any) -> LightningArgumentParser:
+ """Method that instantiates the argument parser."""
+ return LightningArgumentParser(**kwargs)

+ def setup_parser(self, **kwargs: Any) -> None:
+ """Initialize and setup the parser, and arguments."""
+ self.parser = self.init_parser(**kwargs)
+ self._add_arguments(self.parser)

- def add_core_arguments_to_parser(self) -> None:
- """Adds arguments from the core classes to the parser"""
- self.parser.add_argument(
+ def add_default_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
+ """Adds default arguments to the parser."""
+ parser.add_argument(
"--seed_everything",
type=Optional[int],
default=self.seed_everything_default,
help="Set to an int to run seed_everything with this value before classes instantiation",
)
- self.parser.add_lightning_class_args(self.trainer_class, "trainer")

+ def add_core_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
+ """Adds arguments from the core classes to the parser."""
+ parser.add_lightning_class_args(self.trainer_class, "trainer")
trainer_defaults = {"trainer." + k: v for k, v in self.trainer_defaults.items() if k != "callbacks"}
- self.parser.set_defaults(trainer_defaults)
- self.parser.add_lightning_class_args(self.model_class, "model", subclass_mode=self.subclass_mode_model)
+ parser.set_defaults(trainer_defaults)
+ parser.add_lightning_class_args(self.model_class, "model", subclass_mode=self.subclass_mode_model)
if self.datamodule_class is not None:
- self.parser.add_lightning_class_args(self.datamodule_class, "data", subclass_mode=self.subclass_mode_data)
+ parser.add_lightning_class_args(self.datamodule_class, "data", subclass_mode=self.subclass_mode_data)

+ def _add_arguments(self, parser: LightningArgumentParser) -> None:
+ # default + core + custom arguments
+ self.add_default_arguments_to_parser(parser)
+ self.add_core_arguments_to_parser(parser)
+ self.add_arguments_to_parser(parser)

def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
"""Implement to add extra arguments to parser or link arguments
"""
Implement to add extra arguments to the parser or link arguments.
Args:
- parser: The argument parser object to which arguments can be added
+ parser: The parser object to which arguments can be added
"""

def link_optimizers_and_lr_schedulers(self) -> None:
"""Creates argument links for optimizers and lr_schedulers that specified a link_to"""
"""Creates argument links for optimizers and learning rate schedulers that specified a ``link_to``."""
for key, (class_type, link_to) in self.parser.optimizers_and_lr_schedulers.items():
if link_to == "AUTOMATIC":
continue
@@ -325,41 +341,40 @@ def link_optimizers_and_lr_schedulers(self) -> None:
add_class_path = _add_class_path_generator(class_type)
self.parser.link_arguments(key, link_to, compute_fn=add_class_path)

- def parse_arguments(self) -> None:
- """Parses command line arguments and stores it in self.config"""
- self.config = self.parser.parse_args()
+ def parse_arguments(self, parser: LightningArgumentParser) -> None:
+ """Parses command line arguments and stores it in ``self.config``."""
+ self.config = parser.parse_args()

def before_instantiate_classes(self) -> None:
"""Implement to run some code before instantiating the classes"""
"""Implement to run some code before instantiating the classes."""

def instantiate_classes(self) -> None:
"""Instantiates the classes using settings from self.config"""
"""Instantiates the classes and sets their attributes."""
self.config_init = self.parser.instantiate_classes(self.config)
self.datamodule = self.config_init.get("data")
self.model = self.config_init["model"]
- self.instantiate_trainer()

- def instantiate_trainer(self) -> None:
- """Instantiates the trainer using self.config_init['trainer']"""
- if self.config_init["trainer"].get("callbacks") is None:
- self.config_init["trainer"]["callbacks"] = []
callbacks = [self.config_init[c] for c in self.parser.callback_keys]
self.config_init["trainer"]["callbacks"].extend(callbacks)
self.trainer = self.instantiate_trainer(self.config_init["trainer"], callbacks)

+ def instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]) -> Trainer:
+ """Instantiates the trainer."""
+ config["callbacks"] = config["callbacks"] or []
+ config["callbacks"].extend(callbacks)
if "callbacks" in self.trainer_defaults:
if isinstance(self.trainer_defaults["callbacks"], list):
self.config_init["trainer"]["callbacks"].extend(self.trainer_defaults["callbacks"])
config["callbacks"].extend(self.trainer_defaults["callbacks"])
else:
self.config_init["trainer"]["callbacks"].append(self.trainer_defaults["callbacks"])
if self.save_config_callback and not self.config_init["trainer"]["fast_dev_run"]:
config["callbacks"].append(self.trainer_defaults["callbacks"])
if self.save_config_callback and not config["fast_dev_run"]:
config_callback = self.save_config_callback(
self.parser, self.config, self.save_config_filename, overwrite=self.save_config_overwrite
)
self.config_init["trainer"]["callbacks"].append(config_callback)
self.trainer = self.trainer_class(**self.config_init["trainer"])
config["callbacks"].append(config_callback)
return self.trainer_class(**config)

def add_configure_optimizers_method_to_model(self) -> None:
"""
- Adds to the model an automatically generated configure_optimizers method
+ Adds to the model an automatically generated ``configure_optimizers`` method.
If a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC',
then a `configure_optimizers` method is automatically implemented in the model class.
Expand Down Expand Up @@ -390,7 +405,7 @@ def get_automatic(class_type: Union[Type, Tuple[Type, ...]]) -> List[str]:
)

if is_overridden("configure_optimizers", self.model):
- warnings.warn(
+ warnings._warn(
f"`{self.model.__class__.__name__}.configure_optimizers` will be overridden by "
f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model`."
)
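The new `instantiate_trainer(config, callbacks)` signature shown above makes the trainer configuration overridable before instantiation; a minimal sketch, assuming the `MyCLI` name and the `EarlyStopping` callback with its `val_loss` monitor are illustrative choices, not from the commit:

```python
from typing import Any, Dict, List

from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.cli import LightningCLI


class MyCLI(LightningCLI):
    def instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]) -> Trainer:
        # the trainer config and the parsed callbacks are now passed in explicitly,
        # so they can be inspected or extended before the `Trainer` is created
        callbacks = callbacks + [EarlyStopping(monitor="val_loss")]  # illustrative extra callback
        return super().instantiate_trainer(config, callbacks)
```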
14 changes: 7 additions & 7 deletions tests/utilities/test_cli.py
@@ -402,17 +402,18 @@ def test_lightning_cli_help():
out = StringIO()
with mock.patch("sys.argv", cli_args), redirect_stdout(out), pytest.raises(SystemExit):
any_model_any_data_cli()
+ out = out.getvalue()

assert "--print_config" in out.getvalue()
assert "--config" in out.getvalue()
assert "--seed_everything" in out.getvalue()
assert "--model.help" in out.getvalue()
assert "--data.help" in out.getvalue()
assert "--print_config" in out
assert "--config" in out
assert "--seed_everything" in out
assert "--model.help" in out
assert "--data.help" in out

skip_params = {"self"}
for param in inspect.signature(Trainer.__init__).parameters.keys():
if param not in skip_params:
assert f"--trainer.{param}" in out.getvalue()
assert f"--trainer.{param}" in out

cli_args = ["any.py", "--data.help=tests.helpers.BoringDataModule"]
out = StringIO()
@@ -423,7 +424,6 @@ def test_lightning_cli_help():


def test_lightning_cli_print_config():

cli_args = [
"any.py",
"--seed_everything=1234",
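Finally, a hedged sketch of the environment-variable parsing described in the `LightningCLI` docstring above; the `DemoModel` stand-in and its random data are assumptions, not part of the commit:

```python
import os

import torch
from torch.utils.data import DataLoader

from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.cli import LightningCLI


class DemoModel(LightningModule):
    """Tiny stand-in model, for illustration only."""

    def __init__(self, hidden_dim: int = 8):
        super().__init__()
        self.layer = torch.nn.Linear(32, hidden_dim)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

    def train_dataloader(self):
        return DataLoader(torch.randn(64, 32), batch_size=16)


if __name__ == "__main__":
    # with env_parse=True, a full config would be read from PL_CONFIG if set;
    # individual settings come from variables such as PL_TRAINER__MAX_EPOCHS
    os.environ["PL_TRAINER__MAX_EPOCHS"] = "1"
    LightningCLI(DemoModel, env_parse=True)
```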
