Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CLI] Shorthand notation to instantiate callbacks [3/3] #8815

Merged
merged 88 commits into from
Sep 17, 2021
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
88 commits
Select commit Hold shift + click to select a range
d7f00be
add registries
tchaton Aug 9, 2021
a07d305
simplify LightningCLI with defaults
tchaton Aug 9, 2021
ce39c47
cleanup
tchaton Aug 9, 2021
3081475
update
tchaton Aug 9, 2021
51f82d5
updates
tchaton Aug 10, 2021
7197d6e
cleanup
tchaton Aug 10, 2021
9a6e81e
update on comments
tchaton Aug 10, 2021
41f5d78
update
tchaton Aug 10, 2021
06e4999
cleanup
tchaton Aug 10, 2021
e91ea47
update on comments
tchaton Aug 10, 2021
3f35ecd
Merge branch 'master' into lightning_cli_registries
tchaton Aug 10, 2021
705c0bd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 10, 2021
78a4398
add docs
tchaton Aug 10, 2021
e96dc28
doc updates
tchaton Aug 10, 2021
631aa72
update
tchaton Aug 10, 2021
2fc3c0a
update
tchaton Aug 10, 2021
43dd8b4
resolve comments
tchaton Aug 10, 2021
c6ae669
comment
tchaton Aug 10, 2021
5c21b1c
add comment
tchaton Aug 10, 2021
b370deb
typo
tchaton Aug 10, 2021
f8e7ca7
update on comments
tchaton Aug 11, 2021
e428d2f
resolve bug
tchaton Aug 11, 2021
3e97905
typo
tchaton Aug 11, 2021
3b1bdb6
update
tchaton Aug 11, 2021
0d1db29
resolve comments
tchaton Aug 11, 2021
3d35c82
add unittesting
tchaton Aug 11, 2021
4c0f960
resolve tests
tchaton Aug 11, 2021
d3a62ca
resolve comments
tchaton Aug 12, 2021
39781a1
update on comments
tchaton Aug 13, 2021
68c03de
doc updates
tchaton Aug 13, 2021
b01828b
update
tchaton Aug 13, 2021
d213c73
Merge branch 'master' into lightning_cli_registries
tchaton Aug 13, 2021
5935ec4
update on comments
tchaton Aug 17, 2021
b6616f0
Merge branch 'lightning_cli_registries' of https://github.com/PyTorch…
tchaton Aug 17, 2021
0d89423
Merge branch 'master' into lightning_cli_registries
carmocca Aug 19, 2021
37fd679
Fix mypy
carmocca Aug 19, 2021
f16db3d
Revert unrelated change which had broken mypy
carmocca Aug 19, 2021
572488c
Convert to staticmethod
carmocca Aug 19, 2021
2fc4608
Replace context managers for functional static transformations
carmocca Aug 19, 2021
9f383dc
Split tests
carmocca Aug 19, 2021
2a7dfa8
Refactor optimizer tests
carmocca Aug 19, 2021
423ab7b
Cleaning tests
carmocca Aug 19, 2021
7c2e39e
Delete broken test
carmocca Aug 19, 2021
048e159
Docs improvements
carmocca Aug 19, 2021
86fce55
Docs improvements
carmocca Aug 19, 2021
624b0d8
Restructure docs
carmocca Aug 19, 2021
2cc0dc5
Docs for callbacks
carmocca Aug 19, 2021
f9b49fe
Add reload test when add_optimizer_args is added by the user
carmocca Aug 19, 2021
afcc4ba
Add failing config test - needs to be fixed
carmocca Aug 19, 2021
9f41b88
Merge branch 'master' into lightning_cli_registries
carmocca Aug 28, 2021
0ed4ae8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 28, 2021
4dd0732
Use property
carmocca Aug 19, 2021
e0fae4f
Fixes after merge
carmocca Aug 28, 2021
4f053bb
Merge branch 'master' into lightning_cli_registries
carmocca Sep 15, 2021
a22fdb3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 15, 2021
160b3f6
Update jsonargparse version
carmocca Sep 15, 2021
f185c2d
Use properties in registry
carmocca Sep 15, 2021
803385c
Keep hacks together
carmocca Sep 15, 2021
8eb8b05
Add FIXMEs
carmocca Sep 15, 2021
9d84127
add_class_choices
carmocca Sep 15, 2021
33ff2f4
Merge branch 'master' into lightning_cli_registries
carmocca Sep 15, 2021
cf82e1a
Remove contains registry. Avoid nested_key clash for optimizers and l…
carmocca Sep 15, 2021
b1cd083
Remove sanitize argv
carmocca Sep 15, 2021
95d31a7
Better support for new callback format
carmocca Sep 16, 2021
231e0ed
Avoid evaluating
carmocca Sep 16, 2021
2af596f
Minor cleaning
carmocca Sep 16, 2021
6add619
Mark argv as private
carmocca Sep 16, 2021
525358a
Fix mypy
carmocca Sep 16, 2021
84b8120
Fix mypy
carmocca Sep 16, 2021
7e48c0e
Fix mypy
carmocca Sep 16, 2021
40ce3c7
Merge branch 'master' into lightning_cli_registries
carmocca Sep 16, 2021
3e77e8e
Support shorthand notation to instantiate optimizers and learning rat…
carmocca Sep 16, 2021
1512a80
Update CHANGELOG
carmocca Sep 16, 2021
c6b86b1
Fix install
carmocca Sep 16, 2021
6f1600c
Fix install
carmocca Sep 16, 2021
a3a791f
Use release
carmocca Sep 16, 2021
f67a90f
Merge branch 'feat/cli-shorthand-optimizers' into lightning_cli_regis…
carmocca Sep 16, 2021
fedae46
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 16, 2021
ee7a068
Introduce set_choices
carmocca Sep 16, 2021
6e67617
Undo change
carmocca Sep 16, 2021
e7f6d61
Replace add_class_choices with set_choices
carmocca Sep 16, 2021
8e87359
Replace add_class_choices with set_choices
carmocca Sep 16, 2021
c74426b
Merge
carmocca Sep 16, 2021
66cdb52
Docstrings
carmocca Sep 16, 2021
9217304
Merge branch 'master' into lightning_cli_registries
carmocca Sep 17, 2021
7b50401
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 17, 2021
1406be9
Fix mypy
carmocca Sep 17, 2021
a000446
Undo change
carmocca Sep 17, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 162 additions & 10 deletions pytorch_lightning/utilities/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,26 @@
# limitations under the License.
import inspect
import os
import re
import sys
from argparse import Namespace
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import partial
from types import MethodType
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Type, Union
from unittest import mock

from torch.optim import Optimizer

from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
from pytorch_lightning.utilities import _JSONARGPARSE_AVAILABLE, warnings
from pytorch_lightning.utilities.cli_registries import (
CALLBACK_REGISTRIES,
OPTIMIZER_REGISTRIES,
Registry,
SCHEDULER_REGISTRIES,
)
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
Expand All @@ -34,6 +46,31 @@
ArgumentParser = object


@dataclass
class ClassInfo:
tchaton marked this conversation as resolved.
Show resolved Hide resolved
"""This class is an helper to easily build the mocked command line"""

class_arg: str
cls: str
tchaton marked this conversation as resolved.
Show resolved Hide resolved
class_init_args: List[str] = field(default_factory=lambda: [])

def add_class_init_args(self, args: Dict[str, str]) -> None:
if args != self.class_arg:
tchaton marked this conversation as resolved.
Show resolved Hide resolved
self.class_init_args.append(args)

@property
def class_init(self) -> Dict[str, str]:
class_init = {}
class_init["class_path"] = self.cls.__module__ + "." + self.cls.__name__
tchaton marked this conversation as resolved.
Show resolved Hide resolved
init_args = {}
for init_arg in self.class_init_args:
separator = "=" if "=" in init_arg else " "
arg_path, value = init_arg.split(separator)
init_args[arg_path.split(".")[-1]] = value
class_init["init_args"] = init_args
return class_init


class LightningArgumentParser(ArgumentParser):
"""Extension of jsonargparse's ArgumentParser for pytorch-lightning"""

Expand Down Expand Up @@ -289,6 +326,14 @@ def __init__(
self.fit()
self.after_fit()

@property
def optimizer_registered(self) -> Tuple[Type[Optimizer]]:
return tuple(OPTIMIZER_REGISTRIES.values())

@property
def lr_scheduler_registered(self) -> Tuple[LRSchedulerType]:
return tuple(SCHEDULER_REGISTRIES.values())

def init_parser(self, **kwargs: Any) -> LightningArgumentParser:
"""Method that instantiates the argument parser."""
return LightningArgumentParser(**kwargs)
Expand Down Expand Up @@ -332,6 +377,25 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:

def link_optimizers_and_lr_schedulers(self) -> None:
"""Creates argument links for optimizers and learning rate schedulers that specified a ``link_to``."""
if any(
True
for v in sys.argv
for optim_name in OPTIMIZER_REGISTRIES
if re.match(fr"^--optimizer[^\S+=]*?{optim_name}?", v)
tchaton marked this conversation as resolved.
Show resolved Hide resolved
):
if "optimizer" not in self.parser.groups:
self.parser.add_optimizer_args(self.optimizer_registered)

if any(
True
for v in sys.argv
for sch_name in SCHEDULER_REGISTRIES
if re.match(fr"^--lr_scheduler[^\S+=]*{sch_name}?", v)
):
lr_schdulers = tuple(v for v in SCHEDULER_REGISTRIES.values())
if "lr_scheduler" not in self.parser.groups:
self.parser.add_lr_scheduler_args(lr_schdulers)

for key, (class_type, link_to) in self.parser.optimizers_and_lr_schedulers.items():
if link_to == "AUTOMATIC":
continue
Expand All @@ -341,9 +405,90 @@ def link_optimizers_and_lr_schedulers(self) -> None:
add_class_path = _add_class_path_generator(class_type)
self.parser.link_arguments(key, link_to, compute_fn=add_class_path)

@contextmanager
def prepare_from_registry(self, registry: Registry):
"""
This context manager is used to simplify unique class instantiation.
"""

# find if the users is using shortcut command line.
map_user_key_to_info = {}
for registered_name, registered_cls in registry.items():
for v in sys.argv:
separator = "=" if "=" in v else " "
if f"{separator}{registered_name}" in v:
key = v.split(separator)[0]
map_user_key_to_info[key] = ClassInfo(class_arg=v, cls=registered_cls)

if len(map_user_key_to_info) > 0:
# for each shortcut command line, add its init arguments and skip them from `sys.argv`.
argv = []
for v in sys.argv:
skip = False
for key in map_user_key_to_info:
if key in v:
skip = True
map_user_key_to_info[key].add_class_init_args(v)
if not skip:
argv.append(v)

# re-create the global command line and mock `sys.argv`.
argv += [f"{user_key}={info.class_init}" for user_key, info in map_user_key_to_info.items()]
with mock.patch("sys.argv", argv):
yield
else:
yield

@contextmanager
def prepare_class_list_from_registry(self, pattern: str, registry: Registry):
"""
This context manager is used to simplify instantiation of a list of class.
"""
argv = [v for v in sys.argv if pattern not in v]
all_matched_args = [v for v in sys.argv if pattern in v]
all_simplified_args = [v for v in all_matched_args if f"{pattern}" in v and f"{pattern}=[" not in v]
all_cls_simplified_args = [v for v in all_simplified_args if f"{pattern}=" in v]
all_non_simplified_args = [v for v in all_matched_args if f"{pattern}=" in v and f"{pattern}=[" in v]

num_simplified_cls = len(all_simplified_args)
should_replace = num_simplified_cls > 0 and not all("class_path" in v for v in all_matched_args)

if should_replace:
# verify the user is properly ordering arguments.
assert all_cls_simplified_args[0] == all_simplified_args[0]
if len(all_non_simplified_args) > 1:
raise MisconfigurationException(f"When provided {pattern} as list, please group them under 1 argument.")

# group arguments per callbacks
infos = []
for class_arg in all_cls_simplified_args:
class_name = class_arg.split("=")[1]
registered_cls = registry[class_name]
infos.append(ClassInfo(class_arg=class_arg, cls=registered_cls))

for v in all_simplified_args:
if v in all_cls_simplified_args:
current_info = infos[all_cls_simplified_args.index(v)]
current_info.add_class_init_args(v)

class_args = [info.class_init for info in infos]
# add other callback arguments.
class_args.extend(eval(all_non_simplified_args[0].split("=")[-1]))

argv += [f"{pattern}={class_args}"]
with mock.patch("sys.argv", argv):
yield
else:
yield

def parse_arguments(self, parser: LightningArgumentParser) -> None:
"""Parses command line arguments and stores it in ``self.config``."""
self.config = parser.parse_args()
# fmt: off
with self.prepare_from_registry(OPTIMIZER_REGISTRIES), \
self.prepare_from_registry(SCHEDULER_REGISTRIES), \
self.prepare_class_list_from_registry("--trainer.callbacks", CALLBACK_REGISTRIES):
self.config = parser.parse_args()
# fmt: on

def before_instantiate_classes(self) -> None:
"""Implement to run some code before instantiating the classes."""
Expand Down Expand Up @@ -375,7 +520,6 @@ def instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback])
def add_configure_optimizers_method_to_model(self) -> None:
"""
Adds to the model an automatically generated ``configure_optimizers`` method.

If a single optimizer and optionally a scheduler argument groups are added to the parser as 'AUTOMATIC',
then a `configure_optimizers` method is automatically implemented in the model class.
"""
Expand Down Expand Up @@ -421,17 +565,25 @@ def get_automatic(class_type: Union[Type, Tuple[Type, ...]]) -> List[str]:
if not isinstance(lr_scheduler_class, tuple):
lr_scheduler_init = _global_add_class_path(lr_scheduler_class, lr_scheduler_init)

def configure_optimizers(
self: LightningModule,
) -> Union[Optimizer, Tuple[List[Optimizer], List[LRSchedulerType]]]:
optimizer = instantiate_class(self.parameters(), optimizer_init)
if not lr_scheduler_init:
return optimizer
lr_scheduler = instantiate_class(optimizer, lr_scheduler_init)
return [optimizer], [lr_scheduler]
configure_optimizers = partial(
self.configure_optimizers, optimizer_init=optimizer_init, lr_scheduler_init=lr_scheduler_init
)
configure_optimizers.__code__ = self.model.configure_optimizers.__code__

self.model.configure_optimizers = MethodType(configure_optimizers, self.model)

@staticmethod
def configure_optimizers(
pl_module: LightningModule,
optimizer_init: Union[str, List[str]],
lr_scheduler_init: Optional[Union[str, List[str]]] = None,
) -> Union[Optimizer, Tuple[List[Optimizer], List[LRSchedulerType]]]:
optimizer = instantiate_class(pl_module.parameters(), optimizer_init)
if not lr_scheduler_init:
return optimizer
lr_scheduler = instantiate_class(optimizer, lr_scheduler_init)
return [optimizer], [lr_scheduler]

def prepare_fit_kwargs(self) -> None:
"""Prepares fit_kwargs including datamodule using self.config_init['data'] if given"""
self.fit_kwargs = {"model": self.model}
Expand Down
79 changes: 79 additions & 0 deletions pytorch_lightning/utilities/cli_registries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright The PyTorch Lightning team.
tchaton marked this conversation as resolved.
Show resolved Hide resolved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from collections import UserDict
from typing import Callable, List, Optional, Type

import torch

import pytorch_lightning as pl
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class Registry(UserDict):
tchaton marked this conversation as resolved.
Show resolved Hide resolved
def __call__(
self,
cls: Optional[Type] = None,
key: Optional[str] = None,
override: bool = False,
) -> Callable:
"""
Registers a plugin mapped to a name and with required metadata.

Args:
key : the name that identifies a plugin, e.g. "deepspeed_stage_3"
value : plugin class
"""
tchaton marked this conversation as resolved.
Show resolved Hide resolved
if key is None:
key = cls.__name__
elif not isinstance(key, str):
raise TypeError(f"`key` must be a str, found {key}")

if key in self and not override:
raise MisconfigurationException(f"'{key}' is already present in the registry. HINT: Use `override=True`.")

def do_register(key, cls) -> Callable:
self[key] = cls
return cls

do_register(key, cls)

return do_register

def register_package(self, module, base_cls: Type) -> None:
for obj_name in dir(module):
obj_cls = getattr(module, obj_name)
if inspect.isclass(obj_cls) and issubclass(obj_cls, base_cls):
self(cls=obj_cls)
tchaton marked this conversation as resolved.
Show resolved Hide resolved

def remove(self, name: str) -> None:
"""Removes the registered plugin by name"""
self.pop(name)
tchaton marked this conversation as resolved.
Show resolved Hide resolved

def available_objects(self) -> List:
"""Returns a list of registered plugins"""
return list(self.keys())

def __str__(self) -> str:
return "Registered Plugins: {}".format(", ".join(self.keys()))


CALLBACK_REGISTRIES = Registry()
CALLBACK_REGISTRIES.register_package(pl.callbacks, pl.callbacks.Callback)

OPTIMIZER_REGISTRIES = Registry()
OPTIMIZER_REGISTRIES.register_package(torch.optim, torch.optim.Optimizer)

SCHEDULER_REGISTRIES = Registry()
SCHEDULER_REGISTRIES.register_package(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler)
Loading