diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8f66cb1539625c..3ac97333e6d90e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -42,7 +42,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - The `trainer.lightning_module` reference is now properly set at the very beginning of the run ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536))

-- Load ckpt path when model provided in validate/test/predict ([#8352](https://github.com/PyTorchLightning/pytorch-lightning/pull/8352)))
+- Load ckpt path when model provided in validate/test/predict ([#8352](https://github.com/PyTorchLightning/pytorch-lightning/pull/8352))

 - The `Trainer` functions `reset_{train,val,test,predict}_dataloader`, `reset_train_val_dataloaders`, and `request_dataloader` `model` argument is now optional ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536))

@@ -54,6 +54,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Improved string conversion for `ResultCollection` ([#8622](https://github.com/PyTorchLightning/pytorch-lightning/pull/8622))

+- `LightningCLI` changes:
+  * `LightningCLI.init_parser` now returns the parser instance. ([#8721](https://github.com/PyTorchLightning/pytorch-lightning/pull/8721))
+  * `LightningCLI.add_core_arguments_to_parser` and `LightningCLI.parse_arguments` now take a `parser` argument. ([#8721](https://github.com/PyTorchLightning/pytorch-lightning/pull/8721))
+  * `LightningCLI.instantiate_trainer` now takes a config and a list of callbacks. ([#8721](https://github.com/PyTorchLightning/pytorch-lightning/pull/8721))
+  * Split `LightningCLI.add_core_arguments_to_parser` into `LightningCLI.add_default_arguments_to_parser` + `LightningCLI.add_core_arguments_to_parser`. ([#8721](https://github.com/PyTorchLightning/pytorch-lightning/pull/8721))
+
+
 - The accelerator and training type plugin `setup` hooks no longer have a `model` argument ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536))

 - Removed restrictions in the trainer that loggers can only log from rank 0. Existing logger behavior has not changed. ([#8608](https://github.com/PyTorchLightning/pytorch-lightning/pull/8608))
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 07c8d45601b31c..4201724e08e4ca 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -632,8 +632,8 @@ def test(
         test_dataloaders=None,  # TODO: remove with 1.6
     ) -> _EVALUATE_OUTPUT:
         r"""
-        Perform one evaluation epoch over the test set. It's separated from
-        fit to make sure you never run on your test set until you want to.
+        Perform one evaluation epoch over the test set.
+        It's separated from fit to make sure you never run on your test set until you want to.

         Args:
             model: The model to test.
@@ -710,9 +710,9 @@ def predict(
         ckpt_path: Optional[str] = None,
     ) -> Optional[_PREDICT_OUTPUT]:
         r"""
-
-        Separates from fit to make sure you never run on your predictions set until you want to.
-        This will call the model forward function to compute predictions.
+        Run inference on your data.
+        This will call the model forward function to compute predictions. Useful to perform distributed
+        and batched predictions. Logging is disabled in the predict hooks.

         Args:
             model: The model to predict with.
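For context on the `test`/`predict` docstring changes above, a minimal usage sketch of these entry points. `TinyModel`, the loader, and the checkpoint path are hypothetical placeholders, not part of this diff:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer


class TinyModel(LightningModule):
    """Hypothetical model used only to illustrate the entry points."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def test_step(self, batch, batch_idx):
        self.log("test_loss", self(batch[0]).sum())

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        # `predict` calls the model forward function; logging is disabled here.
        return self(batch[0])

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


loader = DataLoader(TensorDataset(torch.randn(8, 32)), batch_size=4)
trainer = Trainer(fast_dev_run=True)

# Run inference; returns a list of per-batch predictions.
predictions = trainer.predict(TinyModel(), dataloaders=loader)

# With a model provided, `ckpt_path` restores weights before evaluation
# (see the #8352 changelog entry above); the path is a placeholder.
# trainer.test(TinyModel(), dataloaders=loader, ckpt_path="path/to/checkpoint.ckpt")
```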
diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py
index 93dbc57a193782..752d0092d1baf1 100644
--- a/pytorch_lightning/utilities/cli.py
+++ b/pytorch_lightning/utilities/cli.py
@@ -13,22 +13,17 @@
 # limitations under the License.
 import inspect
 import os
-import warnings
 from argparse import Namespace
 from types import MethodType
 from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union

 from torch.optim import Optimizer

-from pytorch_lightning.callbacks import Callback
-from pytorch_lightning.core.datamodule import LightningDataModule
-from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.trainer.trainer import Trainer
+from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
+from pytorch_lightning.utilities import _JSONARGPARSE_AVAILABLE, warnings
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _JSONARGPARSE_AVAILABLE
 from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.seed import seed_everything
 from pytorch_lightning.utilities.types import LRSchedulerType, LRSchedulerTypeTuple

 if _JSONARGPARSE_AVAILABLE:
@@ -79,6 +74,9 @@ def add_lightning_class_args(
             lightning_class: A callable or any subclass of {Trainer, LightningModule, LightningDataModule, Callback}.
             nested_key: Name of the nested namespace to store arguments.
             subclass_mode: Whether allow any subclass of the given class.
+
+        Returns:
+            A list with the names of the class arguments added.
         """
         if callable(lightning_class) and not inspect.isclass(lightning_class):
             lightning_class = class_from_function(lightning_class)
@@ -191,7 +189,7 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[st

     def __reduce__(self) -> Tuple[Type["SaveConfigCallback"], Tuple, Dict]:
         # `ArgumentParser` is un-pickleable. Drop it
-        return (self.__class__, (None, self.config, self.config_filename), {})
+        return self.__class__, (None, self.config, self.config_filename), {}


 class LightningCLI:
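A small sketch of the return value documented above for `add_lightning_class_args` (assuming `jsonargparse` is installed; the printed names are indicative only):

```python
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.cli import LightningArgumentParser

parser = LightningArgumentParser()
# Per the new `Returns` section: a list with the names of the class
# arguments that were added under the nested key.
added = parser.add_lightning_class_args(EarlyStopping, "early_stopping")
print(added)  # e.g. ['early_stopping.monitor', 'early_stopping.patience', ...]
```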
@@ -205,21 +203,22 @@ def __init__(
         save_config_filename: str = "config.yaml",
         save_config_overwrite: bool = False,
         trainer_class: Union[Type[Trainer], Callable[..., Trainer]] = Trainer,
-        trainer_defaults: Dict[str, Any] = None,
-        seed_everything_default: int = None,
+        trainer_defaults: Optional[Dict[str, Any]] = None,
+        seed_everything_default: Optional[int] = None,
         description: str = "pytorch-lightning trainer command line tool",
         env_prefix: str = "PL",
         env_parse: bool = False,
-        parser_kwargs: Dict[str, Any] = None,
+        parser_kwargs: Optional[Dict[str, Any]] = None,
         subclass_mode_model: bool = False,
         subclass_mode_data: bool = False,
     ) -> None:
         """
         Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which are
-        called / instantiated using a parsed configuration file and / or command line args and then runs trainer.fit.
-        Parsing of configuration from environment variables can be enabled by setting ``env_parse=True``. A full
-        configuration yaml would be parsed from ``PL_CONFIG`` if set. Individual settings are so parsed from variables
-        named for example ``PL_TRAINER__MAX_EPOCHS``.
+        called / instantiated using a parsed configuration file and / or command line args.
+
+        Parsing of configuration from environment variables can be enabled by setting ``env_parse=True``.
+        A full configuration yaml would be parsed from ``PL_CONFIG`` if set.
+        Individual settings are similarly parsed from variables named, for example, ``PL_TRAINER__MAX_EPOCHS``.

         Example, first implement the ``trainer.py`` tool as::

@@ -266,56 +265,73 @@ def __init__(
         self.save_config_filename = save_config_filename
         self.save_config_overwrite = save_config_overwrite
         self.trainer_class = trainer_class
-        self.trainer_defaults = {} if trainer_defaults is None else trainer_defaults
+        self.trainer_defaults = trainer_defaults or {}
         self.seed_everything_default = seed_everything_default
         self.subclass_mode_model = subclass_mode_model
         self.subclass_mode_data = subclass_mode_data
-        self.parser_kwargs = {} if parser_kwargs is None else parser_kwargs
-        self.parser_kwargs.update({"description": description, "env_prefix": env_prefix, "default_env": env_parse})
-        self.init_parser()
-        self.add_core_arguments_to_parser()
-        self.add_arguments_to_parser(self.parser)
+        parser_kwargs = parser_kwargs or {}
+        parser_kwargs.update({"description": description, "env_prefix": env_prefix, "default_env": env_parse})
+        self.setup_parser(**parser_kwargs)
         self.link_optimizers_and_lr_schedulers()
-        self.parse_arguments()
-        if self.config["seed_everything"] is not None:
-            seed_everything(self.config["seed_everything"], workers=True)
+        self.parse_arguments(self.parser)
+
+        seed = self.config.get("seed_everything")
+        if seed is not None:
+            seed_everything(seed, workers=True)
+
         self.before_instantiate_classes()
         self.instantiate_classes()
         self.add_configure_optimizers_method_to_model()
+        self.prepare_fit_kwargs()
         self.before_fit()
         self.fit()
         self.after_fit()

-    def init_parser(self) -> None:
-        """Method that instantiates the argument parser"""
-        self.parser = LightningArgumentParser(**self.parser_kwargs)
+    def init_parser(self, **kwargs: Any) -> LightningArgumentParser:
+        """Method that instantiates the argument parser."""
+        return LightningArgumentParser(**kwargs)
+
+    def setup_parser(self, **kwargs: Any) -> None:
+        """Initialize and set up the parser and its arguments."""
+        self.parser = self.init_parser(**kwargs)
+        self._add_arguments(self.parser)

-    def add_core_arguments_to_parser(self) -> None:
-        """Adds arguments from the core classes to the parser"""
-        self.parser.add_argument(
+    def add_default_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
+        """Adds default arguments to the parser."""
+        parser.add_argument(
             "--seed_everything",
             type=Optional[int],
             default=self.seed_everything_default,
             help="Set to an int to run seed_everything with this value before classes instantiation",
         )
-        self.parser.add_lightning_class_args(self.trainer_class, "trainer")
+
+    def add_core_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
+        """Adds arguments from the core classes to the parser."""
+        parser.add_lightning_class_args(self.trainer_class, "trainer")
         trainer_defaults = {"trainer." + k: v for k, v in self.trainer_defaults.items() if k != "callbacks"}
-        self.parser.set_defaults(trainer_defaults)
-        self.parser.add_lightning_class_args(self.model_class, "model", subclass_mode=self.subclass_mode_model)
+        parser.set_defaults(trainer_defaults)
+        parser.add_lightning_class_args(self.model_class, "model", subclass_mode=self.subclass_mode_model)
         if self.datamodule_class is not None:
-            self.parser.add_lightning_class_args(self.datamodule_class, "data", subclass_mode=self.subclass_mode_data)
+            parser.add_lightning_class_args(self.datamodule_class, "data", subclass_mode=self.subclass_mode_data)
+
+    def _add_arguments(self, parser: LightningArgumentParser) -> None:
+        # default + core + custom arguments
+        self.add_default_arguments_to_parser(parser)
+        self.add_core_arguments_to_parser(parser)
+        self.add_arguments_to_parser(parser)

     def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
-        """Implement to add extra arguments to parser or link arguments
+        """
+        Implement to add extra arguments to the parser or link arguments.

         Args:
-            parser: The argument parser object to which arguments can be added
+            parser: The parser object to which arguments can be added
         """

     def link_optimizers_and_lr_schedulers(self) -> None:
-        """Creates argument links for optimizers and lr_schedulers that specified a link_to"""
+        """Creates argument links for optimizers and learning rate schedulers that specified a ``link_to``."""
         for key, (class_type, link_to) in self.parser.optimizers_and_lr_schedulers.items():
             if link_to == "AUTOMATIC":
                 continue
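A hedged sketch of how a subclass interacts with the refactored parser hooks: `init_parser` now receives the parser kwargs and returns the parser, so an override can adjust the kwargs before construction. `MyCLI` and the `default_config_files` value are illustrative assumptions:

```python
from typing import Any

from pytorch_lightning.utilities.cli import LightningArgumentParser, LightningCLI


class MyCLI(LightningCLI):
    def init_parser(self, **kwargs: Any) -> LightningArgumentParser:
        # Inject an extra jsonargparse option before the parser is built;
        # `setup_parser` then calls `_add_arguments` on the returned parser.
        kwargs.setdefault("default_config_files", ["defaults.yaml"])
        return super().init_parser(**kwargs)
```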
@@ -325,41 +341,40 @@ def link_optimizers_and_lr_schedulers(self) -> None:
             add_class_path = _add_class_path_generator(class_type)
             self.parser.link_arguments(key, link_to, compute_fn=add_class_path)

-    def parse_arguments(self) -> None:
-        """Parses command line arguments and stores it in self.config"""
-        self.config = self.parser.parse_args()
+    def parse_arguments(self, parser: LightningArgumentParser) -> None:
+        """Parses command line arguments and stores them in ``self.config``."""
+        self.config = parser.parse_args()

     def before_instantiate_classes(self) -> None:
-        """Implement to run some code before instantiating the classes"""
+        """Implement to run some code before instantiating the classes."""

     def instantiate_classes(self) -> None:
-        """Instantiates the classes using settings from self.config"""
+        """Instantiates the classes and sets their attributes."""
         self.config_init = self.parser.instantiate_classes(self.config)
         self.datamodule = self.config_init.get("data")
         self.model = self.config_init["model"]
-        self.instantiate_trainer()
-
-    def instantiate_trainer(self) -> None:
-        """Instantiates the trainer using self.config_init['trainer']"""
-        if self.config_init["trainer"].get("callbacks") is None:
-            self.config_init["trainer"]["callbacks"] = []
         callbacks = [self.config_init[c] for c in self.parser.callback_keys]
-        self.config_init["trainer"]["callbacks"].extend(callbacks)
+        self.trainer = self.instantiate_trainer(self.config_init["trainer"], callbacks)
+
+    def instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]) -> Trainer:
+        """Instantiates the trainer."""
+        config["callbacks"] = config["callbacks"] or []
+        config["callbacks"].extend(callbacks)
         if "callbacks" in self.trainer_defaults:
             if isinstance(self.trainer_defaults["callbacks"], list):
-                self.config_init["trainer"]["callbacks"].extend(self.trainer_defaults["callbacks"])
+                config["callbacks"].extend(self.trainer_defaults["callbacks"])
             else:
-                self.config_init["trainer"]["callbacks"].append(self.trainer_defaults["callbacks"])
-        if self.save_config_callback and not self.config_init["trainer"]["fast_dev_run"]:
+                config["callbacks"].append(self.trainer_defaults["callbacks"])
+        if self.save_config_callback and not config["fast_dev_run"]:
             config_callback = self.save_config_callback(
                 self.parser, self.config, self.save_config_filename, overwrite=self.save_config_overwrite
             )
-            self.config_init["trainer"]["callbacks"].append(config_callback)
-        self.trainer = self.trainer_class(**self.config_init["trainer"])
+            config["callbacks"].append(config_callback)
+        return self.trainer_class(**config)

     def add_configure_optimizers_method_to_model(self) -> None:
         """
-        Adds to the model an automatically generated configure_optimizers method
+        Adds to the model an automatically generated ``configure_optimizers`` method.

         If a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC',
         then a `configure_optimizers` method is automatically implemented in the model class.
@@ -390,7 +405,7 @@ def get_automatic(class_type: Union[Type, Tuple[Type, ...]]) -> List[str]:
         )

         if is_overridden("configure_optimizers", self.model):
-            warnings.warn(
+            warnings._warn(
                 f"`{self.model.__class__.__name__}.configure_optimizers` will be overridden by "
                 f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model`."
             )
diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py
index b80c371eb59300..cc636aa9a17ed7 100644
--- a/tests/utilities/test_cli.py
+++ b/tests/utilities/test_cli.py
@@ -402,17 +402,18 @@ def test_lightning_cli_help():
     out = StringIO()
     with mock.patch("sys.argv", cli_args), redirect_stdout(out), pytest.raises(SystemExit):
         any_model_any_data_cli()
+    out = out.getvalue()

-    assert "--print_config" in out.getvalue()
-    assert "--config" in out.getvalue()
-    assert "--seed_everything" in out.getvalue()
-    assert "--model.help" in out.getvalue()
-    assert "--data.help" in out.getvalue()
+    assert "--print_config" in out
+    assert "--config" in out
+    assert "--seed_everything" in out
+    assert "--model.help" in out
+    assert "--data.help" in out

     skip_params = {"self"}
     for param in inspect.signature(Trainer.__init__).parameters.keys():
         if param not in skip_params:
-            assert f"--trainer.{param}" in out
+            assert f"--trainer.{param}" in out

     cli_args = ["any.py", "--data.help=tests.helpers.BoringDataModule"]
     out = StringIO()
@@ -423,7 +424,6 @@ def test_lightning_cli_help():


 def test_lightning_cli_print_config():
-
     cli_args = [
         "any.py",
         "--seed_everything=1234",
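Finally, a sketch (not part of the diff) of overriding the refactored `instantiate_trainer`, which now receives the trainer config and the list of instantiated callbacks instead of reading them off `self.config_init`. `MyCLI` is a hypothetical subclass:

```python
from typing import Any, Dict, List

from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.utilities.cli import LightningCLI


class MyCLI(LightningCLI):
    def instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]) -> Trainer:
        # Append an extra callback before the trainer is constructed.
        return super().instantiate_trainer(config, callbacks + [LearningRateMonitor()])
```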