[CLI] Shorthand notation to instantiate optimizers and lr schedulers [2/3] (#9565)

carmocca authored Sep 17, 2021
1 parent 77c719f commit bbcb977
Showing 5 changed files with 443 additions and 51 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -54,6 +54,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Added `LightningCLI(run=False|True)` to choose whether to run a `Trainer` subcommand ([#8751](https://github.com/PyTorchLightning/pytorch-lightning/pull/8751))
* Added support to call any trainer function from the `LightningCLI` via subcommands ([#7508](https://github.com/PyTorchLightning/pytorch-lightning/pull/7508))
* Allow easy trainer re-instantiation ([#7508](https://github.com/PyTorchLightning/pytorch-lightning/pull/9241))
* Automatically register all optimizers and learning rate schedulers ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
* Allow registering custom optimizers and learning rate schedulers without subclassing the CLI ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))
* Support shorthand notation to instantiate optimizers and learning rate schedulers ([#9565](https://github.com/PyTorchLightning/pytorch-lightning/pull/9565))


- Fault-tolerant training:
139 changes: 102 additions & 37 deletions docs/source/common/lightning_cli.rst
Expand Up @@ -665,69 +665,135 @@ Optimizers and learning rate schedulers
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Optimizers and learning rate schedulers can also be made configurable. The most common case is when a model only has a
single optimizer and optionally a single learning rate scheduler. In this case the model's
:class:`~pytorch_lightning.core.lightning.LightningModule` could be left without implementing the
:code:`configure_optimizers` method since it is normally always the same and just adds boilerplate. The following code
snippet shows how to implement it:
single optimizer and optionally a single learning rate scheduler. In this case, the model's
:meth:`~pytorch_lightning.core.lightning.LightningModule.configure_optimizers` could be left unimplemented since it is
normally always the same and just adds boilerplate.
.. testcode::
import torch
class MyLightningCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.add_optimizer_args(torch.optim.Adam)
parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR)
The CLI works out-of-the-box with PyTorch's built-in optimizers and learning rate schedulers when
at most one of each is used.
Only the optimizer or scheduler name needs to be passed, optionally with its ``__init__`` arguments:
.. code-block:: bash
cli = MyLightningCLI(MyModel)
$ python trainer.py fit --optimizer=Adam --optimizer.lr=0.01 --lr_scheduler=ExponentialLR --lr_scheduler.gamma=0.1
With this the :code:`configure_optimizers` method is automatically implemented and in the config the :code:`optimizer`
and :code:`lr_scheduler` groups would accept all of the options for the given classes, in this example :code:`Adam` and
:code:`ExponentialLR`. Therefore, the config file would be structured like:
A corresponding example of the config file would be:
.. code-block:: yaml
optimizer:
lr: 0.01
class_path: torch.optim.Adam
init_args:
lr: 0.01
lr_scheduler:
gamma: 0.2
class_path: torch.optim.lr_scheduler.ExponentialLR
init_args:
gamma: 0.1
model:
...
trainer:
...
And any of these arguments could be passed directly through command line. For example:
.. note::
This shorthand notation is only supported in the shell and not inside a configuration file. The configuration file
generated by calling the previous command with ``--print_config`` will have the ``class_path`` notation.
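As an illustration of the note above, a command along these lines would print the expanded
``class_path``/``init_args`` configuration rather than the shorthand:

.. code-block:: bash

    $ python trainer.py fit --optimizer=Adam --optimizer.lr=0.01 --print_config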
Furthermore, you can register your own optimizers and/or learning rate schedulers as follows:
.. code-block:: python
from pytorch_lightning.utilities.cli import OPTIMIZER_REGISTRY, LR_SCHEDULER_REGISTRY
@OPTIMIZER_REGISTRY
class CustomAdam(torch.optim.Adam):
...
@LR_SCHEDULER_REGISTRY
class CustomCosineAnnealingLR(torch.optim.lr_scheduler.CosineAnnealingLR):
...
# register all `Optimizer` subclasses from the `torch.optim` package
# This is done automatically!
OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer)
cli = LightningCLI(...)
.. code-block:: bash
$ python trainer.py fit --optimizer.lr=0.01 --lr_scheduler.gamma=0.2
$ python trainer.py fit --optimizer=CustomAdam --optimizer.lr=0.01 --lr_scheduler=CustomCosineAnnealingLR
If you need to customize the key names or link arguments together, you can choose from all available optimizers and
learning rate schedulers by accessing the registries.
.. code-block:: python
class MyLightningCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.add_optimizer_args(
OPTIMIZER_REGISTRY.classes,
nested_key="gen_optimizer",
link_to="model.optimizer1_init"
)
parser.add_optimizer_args(
OPTIMIZER_REGISTRY.classes,
nested_key="gen_discriminator",
link_to="model.optimizer2_init"
)
.. code-block:: bash
$ python trainer.py fit \
--gen_optimizer=Adam \
--gen_optimizer.lr=0.01 \
--gen_discriminator=AdamW \
--gen_discriminator.lr=0.0001
You can also pass the class path directly, for example, if the optimizer hasn't been registered to the
``OPTIMIZER_REGISTRY``:
.. code-block:: bash
$ python trainer.py fit \
--gen_optimizer.class_path=torch.optim.Adam \
--gen_optimizer.init_args.lr=0.01 \
--gen_discriminator.class_path=torch.optim.AdamW \
--gen_discriminator.init_args.lr=0.0001
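To illustrate how the linked keys above are consumed, here is a minimal sketch (not part of this diff) of a
hypothetical ``LightningModule`` that receives ``optimizer1_init`` and ``optimizer2_init`` through ``link_to`` and
instantiates them with :func:`~pytorch_lightning.utilities.cli.instantiate_class`:

.. code-block:: python

    import torch

    from pytorch_lightning import LightningModule
    from pytorch_lightning.utilities.cli import instantiate_class


    class MyGAN(LightningModule):
        def __init__(self, optimizer1_init: dict, optimizer2_init: dict):
            super().__init__()
            # placeholder submodules; a real GAN would define proper networks here
            self.generator = torch.nn.Linear(2, 2)
            self.discriminator = torch.nn.Linear(2, 2)
            # dictionaries with `class_path` and `init_args` filled in by the parser
            self.optimizer1_init = optimizer1_init
            self.optimizer2_init = optimizer2_init

        def configure_optimizers(self):
            # `instantiate_class` imports the `class_path` and calls it with the given
            # positional argument (the parameters) plus the `init_args`
            optimizer1 = instantiate_class(self.generator.parameters(), self.optimizer1_init)
            optimizer2 = instantiate_class(self.discriminator.parameters(), self.optimizer2_init)
            return optimizer1, optimizer2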
There is also the possibility of selecting among multiple classes by giving them as a tuple. For example:
If you will not be changing the class, you can manually add the arguments for specific optimizers and/or
learning rate schedulers by subclassing the CLI. This has the advantage of providing the proper help message for those
classes. The following code snippet shows how to implement it:
.. testcode::
class MyLightningCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.add_optimizer_args((torch.optim.SGD, torch.optim.Adam))
parser.add_optimizer_args(torch.optim.Adam)
parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR)
In this case in the config the :code:`optimizer` group instead of having directly init settings, it should specify
:code:`class_path` and optionally :code:`init_args`. Sub-classes of the classes in the tuple would also be accepted.
A corresponding example of the config file would be:
With this, the :code:`optimizer` and :code:`lr_scheduler` groups in the config would accept all of the options for the
given classes, in this example :code:`Adam` and :code:`ExponentialLR`.
Therefore, the config file would be structured like:
.. code-block:: yaml
optimizer:
class_path: torch.optim.Adam
init_args:
lr: 0.01
lr: 0.01
lr_scheduler:
gamma: 0.2
model:
...
trainer:
...
And the same through command line:
The arguments can then be passed directly through the command line without specifying the class. For example:
.. code-block:: bash
$ python trainer.py fit --optimizer.class_path=torch.optim.Adam --optimizer.init_args.lr=0.01
$ python trainer.py fit --optimizer.lr=0.01 --lr_scheduler.gamma=0.2
The automatic implementation of :code:`configure_optimizers` can be disabled by linking the configuration group. An
example is :code:`ReduceLROnPlateau`, which requires specifying a monitor. This would be:
@@ -763,12 +829,11 @@ example can be :code:`ReduceLROnPlateau` which requires to specify a monitor. Th
cli = MyLightningCLI(MyModel)
For both possibilities of using :meth:`pytorch_lightning.utilities.cli.LightningArgumentParser.add_optimizer_args` with
a single class or a tuple of classes, the value given to :code:`optimizer_init` will always be a dictionary including
:code:`class_path` and :code:`init_args` entries. The function
:func:`~pytorch_lightning.utilities.cli.instantiate_class` takes care of importing the class defined in
:code:`class_path` and instantiating it using some positional arguments, in this case :code:`self.parameters()`, and the
:code:`init_args`. Any number of optimizers and learning rate schedulers can be added when using :code:`link_to`.
The value given to :code:`optimizer_init` will always be a dictionary including :code:`class_path` and
:code:`init_args` entries. The function :func:`~pytorch_lightning.utilities.cli.instantiate_class`
takes care of importing the class defined in :code:`class_path` and instantiating it using some positional arguments,
in this case :code:`self.parameters()`, and the :code:`init_args`.
Any number of optimizers and learning rate schedulers can be added when using :code:`link_to`.
Notes related to reproducibility
Expand Down
120 changes: 119 additions & 1 deletion pytorch_lightning/utilities/cli.py
@@ -11,11 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
import sys
from argparse import Namespace
from types import MethodType
from types import MethodType, ModuleType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from unittest import mock

import torch
from torch.optim import Optimizer

from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
@@ -35,9 +39,57 @@
ArgumentParser = object


class _Registry(dict):
def __call__(self, cls: Type, key: Optional[str] = None, override: bool = False) -> None:
"""Registers a class mapped to a name.
Args:
cls: the class to be mapped.
key: the name that identifies the provided class.
override: Whether to override an existing key.
"""
if key is None:
key = cls.__name__
elif not isinstance(key, str):
raise TypeError(f"`key` must be a str, found {key}")

if key in self and not override:
raise MisconfigurationException(f"'{key}' is already present in the registry. HINT: Use `override=True`.")
self[key] = cls

def register_classes(self, module: ModuleType, base_cls: Type, override: bool = False) -> None:
"""This function is an utility to register all classes from a module."""
for _, cls in inspect.getmembers(module, predicate=inspect.isclass):
if issubclass(cls, base_cls) and cls != base_cls:
self(cls=cls, override=override)

@property
def names(self) -> List[str]:
"""Returns the registered names."""
return list(self.keys())

@property
def classes(self) -> Tuple[Type, ...]:
"""Returns the registered classes."""
return tuple(self.values())

def __str__(self) -> str:
return f"Registered objects: {self.names}"


OPTIMIZER_REGISTRY = _Registry()
OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer)

LR_SCHEDULER_REGISTRY = _Registry()
LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler)
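# Illustrative note (not part of this diff): `_Registry` subclasses `dict`, so after the automatic
# registration above, `OPTIMIZER_REGISTRY["Adam"]` is `torch.optim.Adam`, `"ExponentialLR"` appears in
# `LR_SCHEDULER_REGISTRY.names`, and `OPTIMIZER_REGISTRY.classes` is the tuple of classes that
# `_add_arguments` passes to `add_optimizer_args` by default (see below).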


class LightningArgumentParser(ArgumentParser):
"""Extension of jsonargparse's ArgumentParser for pytorch-lightning."""

# use class attribute because `parse_args` is only called on the main parser
_choices: Dict[str, Tuple[Type, ...]] = {}

def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None:
"""Initialize argument parser that supports configuration file input.
@@ -118,6 +170,7 @@ def add_optimizer_args(
kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}}
if isinstance(optimizer_class, tuple):
self.add_subclass_arguments(optimizer_class, nested_key, **kwargs)
self.set_choices(nested_key, optimizer_class)
else:
self.add_class_arguments(optimizer_class, nested_key, **kwargs)
self._optimizers[nested_key] = (optimizer_class, link_to)
@@ -142,10 +195,70 @@ def add_lr_scheduler_args(
kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}}
if isinstance(lr_scheduler_class, tuple):
self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs)
self.set_choices(nested_key, lr_scheduler_class)
else:
self.add_class_arguments(lr_scheduler_class, nested_key, **kwargs)
self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to)

def parse_args(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
argv = sys.argv
for k, classes in self._choices.items():
if not any(arg.startswith(f"--{k}") for arg in argv):
# the key wasn't passed - maybe defined in a config, maybe it's optional
continue
argv = self._convert_argv_issue_84(classes, k, argv)
self._choices.clear()
with mock.patch("sys.argv", argv):
return super().parse_args(*args, **kwargs)

def set_choices(self, nested_key: str, classes: Tuple[Type, ...]) -> None:
self._choices[nested_key] = classes

@staticmethod
def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]:
"""Placeholder for https://github.com/omni-us/jsonargparse/issues/84.
This should be removed once implemented.
"""
passed_args, clean_argv = {}, []
argv_key = f"--{nested_key}"
# get the argv args for this nested key
i = 0
while i < len(argv):
arg = argv[i]
if arg.startswith(argv_key):
if "=" in arg:
key, value = arg.split("=")
else:
key = arg
i += 1
value = argv[i]
passed_args[key] = value
else:
clean_argv.append(arg)
i += 1
# generate the associated config file
argv_class = passed_args.pop(argv_key, None)
if argv_class is None:
# the user passed a config as a str
class_path = passed_args[f"{argv_key}.class_path"]
init_args_key = f"{argv_key}.init_args"
init_args = {k[len(init_args_key) + 1 :]: v for k, v in passed_args.items() if k.startswith(init_args_key)}
config = str({"class_path": class_path, "init_args": init_args})
elif argv_class.startswith("{"):
# the user passed a config as a dict
config = argv_class
else:
# the user passed the shorthand format
init_args = {k[len(argv_key) + 1 :]: v for k, v in passed_args.items()} # +1 to account for the period
for cls in classes:
if cls.__name__ == argv_class:
config = str(_global_add_class_path(cls, init_args))
break
else:
raise ValueError(f"Could not generate a config for {repr(argv_class)}")
return clean_argv + [argv_key, config]
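        # Illustrative example (not part of this diff) of the rewrite performed above, assuming
        # `nested_key="optimizer"` and `classes` containing `torch.optim.Adam`:
        #
        #   in:  ["--optimizer=Adam", "--optimizer.lr=0.01"]
        #   out: ["--optimizer", "{'class_path': 'torch.optim.adam.Adam', 'init_args': {'lr': '0.01'}}"]
        #
        # (the exact `class_path` string depends on how `_global_add_class_path` resolves the class)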


class SaveConfigCallback(Callback):
"""Saves a LightningCLI config to the log_dir when training starts.
@@ -328,6 +441,11 @@ def _add_arguments(self, parser: LightningArgumentParser) -> None:
self.add_default_arguments_to_parser(parser)
self.add_core_arguments_to_parser(parser)
self.add_arguments_to_parser(parser)
# add default optimizer args if necessary
if not parser._optimizers: # already added by the user in `add_arguments_to_parser`
parser.add_optimizer_args(OPTIMIZER_REGISTRY.classes)
if not parser._lr_schedulers: # already added by the user in `add_arguments_to_parser`
parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.classes)
self.link_optimizers_and_lr_schedulers(parser)
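        # Illustrative note (not part of this diff): because the registries' classes are added by default,
        # `--optimizer=<Name>` and `--lr_scheduler=<Name>` work out of the box without overriding
        # `add_arguments_to_parser`, as documented in docs/source/common/lightning_cli.rst above.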

def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
2 changes: 1 addition & 1 deletion requirements/extra.txt
@@ -7,6 +7,6 @@ torchtext>=0.7
onnx>=1.7.0
onnxruntime>=1.3.0
hydra-core>=1.0
jsonargparse[signatures]>=3.19.0
jsonargparse[signatures]>=3.19.3
gcsfs>=2021.5.0
rich>=10.2.2