Cautious optimizer impl plus some typing cleanup. #2349

Merged
merged 3 commits on Nov 28, 2024
2 changes: 1 addition & 1 deletion tests/test_optim.py
@@ -298,7 +298,7 @@ def test_optim_factory(optimizer):
assert isinstance(opt_info, OptimInfo)

lr = (1e-2,) * 4
if optimizer in ('mars',):
if optimizer in ('mars', 'nadam', 'claprop', 'crmsproptf', 'cadafactorbv', 'csgdw', 'clamb'):
lr = (1e-3,) * 4

try:
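
For anyone who wants to smoke-test the new names locally, here is a minimal sketch (not part of this diff) that drives one of the cautious registrations through the existing `create_optimizer_v2` factory; the choice of `cadamw` and the hyperparameter values are purely illustrative:

```python
import torch
import torch.nn as nn
from timm.optim import create_optimizer_v2

# Tiny model, just enough to produce gradients for a single optimizer step.
model = nn.Linear(10, 2)
opt = create_optimizer_v2(model, opt='cadamw', lr=1e-3, weight_decay=0.01)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
opt.step()
opt.zero_grad()
```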
137 changes: 107 additions & 30 deletions timm/optim/_optim_factory.py
@@ -5,15 +5,16 @@
import logging
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union, Protocol, Iterator
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from fnmatch import fnmatch
import importlib

import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim

from ._param_groups import param_groups_layer_decay, param_groups_weight_decay
from ._types import ParamsT, OptimType, OptimizerCallable
from .adabelief import AdaBelief
from .adafactor import Adafactor
from .adafactor_bv import AdafactorBigVision
@@ -39,11 +40,6 @@

_logger = logging.getLogger(__name__)

# Type variables
T = TypeVar('T')
Params = Union[Iterator[nn.Parameter], Iterator[Dict[str, Any]]]
OptimType = TypeVar('OptimType', bound='optim.Optimizer')


def _import_class(class_string: str) -> Type:
"""Dynamically import a class from a string."""
@@ -55,11 +51,6 @@ def _import_class(class_string: str) -> Type:
raise ImportError(f"Could not import {class_string}: {e}")


class OptimizerCallable(Protocol):
"""Protocol for optimizer constructor signatures."""

def __call__(self, params: Params, **kwargs) -> optim.Optimizer: ...


@dataclass(frozen=True)
class OptimInfo:
@@ -76,7 +67,7 @@ class OptimInfo:
defaults: Optional default parameters for the optimizer
"""
name: str
opt_class: Union[str, Type[optim.Optimizer]]
opt_class: Union[str, OptimType]
description: str = ''
has_eps: bool = True
has_momentum: bool = False
@@ -185,7 +176,7 @@ def get_optimizer_class(
self,
name_or_info: Union[str, OptimInfo],
bind_defaults: bool = True,
) -> Union[Type[optim.Optimizer], OptimizerCallable]:
) -> Union[OptimType, OptimizerCallable]:
"""Get the optimizer class with any default arguments applied.
This allows direct instantiation of optimizers with their default configs
@@ -234,17 +225,17 @@ def get_optimizer_class(

def create_optimizer(
self,
model_or_params: Union[nn.Module, Params],
model_or_params: Union[nn.Module, ParamsT],
opt: str,
lr: Optional[float] = None,
weight_decay: float = 0.,
momentum: float = 0.9,
foreach: Optional[bool] = None,
weight_decay_exclude_1d: bool = True,
layer_decay: Optional[float] = None,
param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None,
**kwargs: Any,
) -> optim.Optimizer:
) -> torch.optim.Optimizer:
"""Create an optimizer instance.
Args:
@@ -347,15 +338,15 @@ def _register_sgd_variants(registry: OptimizerRegistry) -> None:
sgd_optimizers = [
OptimInfo(
name='sgd',
opt_class=optim.SGD,
opt_class=torch.optim.SGD,
description='torch.Optim Stochastic Gradient Descent (SGD) with Nesterov momentum',
has_eps=False,
has_momentum=True,
defaults={'nesterov': True}
),
OptimInfo(
name='momentum',
opt_class=optim.SGD,
opt_class=torch.optim.SGD,
description='torch.Optim Stochastic Gradient Descent (SGD) with classical momentum',
has_eps=False,
has_momentum=True,
@@ -386,13 +377,13 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
adam_optimizers = [
OptimInfo(
name='adam',
opt_class=optim.Adam,
opt_class=torch.optim.Adam,
description='torch.optim.Adam, Adaptive Moment Estimation',
has_betas=True
),
OptimInfo(
name='adamw',
opt_class=optim.AdamW,
opt_class=torch.optim.AdamW,
description='torch.optim.AdamW, Adam with decoupled weight decay',
has_betas=True
),
@@ -448,7 +439,7 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
),
OptimInfo(
name='adamax',
opt_class=optim.Adamax,
opt_class=torch.optim.Adamax,
description='torch.optim.Adamax, Adam with infinity norm for more stable updates',
has_betas=True
),
@@ -526,6 +517,87 @@ def _register_lamb_lars(registry: OptimizerRegistry) -> None:
registry.register(opt)


def _register_cautious_optimizers(registry: OptimizerRegistry) -> None:
cautious_optimizers = [
OptimInfo(
name='cadafactor',
opt_class=Adafactor,
description='Cautious Adafactor',
defaults={'caution': True}
),
OptimInfo(
name='cadafactorbv',
opt_class=AdafactorBigVision,
description='Cautious Big Vision Adafactor',
defaults={'caution': True}
),
OptimInfo(
name='cadamw',
opt_class=AdamWLegacy,
description='Cautious AdamW',
has_betas=True,
defaults={'caution': True}
),
OptimInfo(
name='cadopt',
opt_class=Adopt,
description='Cautious Adopt',
defaults={'caution': True}
),
OptimInfo(
name='cadoptw',
opt_class=Adopt,
description='Cautious AdoptW (decoupled decay)',
defaults={'decoupled': True, 'caution': True}
),
OptimInfo(
name='clamb',
opt_class=Lamb,
description='Cautious LAMB',
has_betas=True,
defaults={'caution': True}
),
OptimInfo(
name='claprop',
opt_class=LaProp,
description='Cautious LaProp',
has_betas=True,
defaults={'caution': True}
),
OptimInfo(
name='clion',
opt_class=Lion,
description='Cautious Lion',
has_eps=False,
has_betas=True,
defaults={'caution': True}
),
OptimInfo(
name='cnadamw',
opt_class=NAdamW,
description='Cautious NAdamW',
has_betas=True,
defaults={'caution': True}
),
OptimInfo(
name='crmsproptf',
opt_class=RMSpropTF,
description='Cautious TensorFlow-style RMSprop',
has_momentum=True,
defaults={'alpha': 0.9, 'caution': True}
),
OptimInfo(
name='csgdw',
opt_class=SGDW,
description='Cautious SGD with decoupled weight decay and Nesterov momentum',
has_eps=False,
has_momentum=True,
defaults={'nesterov': True, 'caution': True}
),
]
for opt in cautious_optimizers:
registry.register(opt)

def _register_other_optimizers(registry: OptimizerRegistry) -> None:
"""Register miscellaneous optimizers"""
other_optimizers = [
@@ -545,12 +617,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
),
OptimInfo(
name='adadelta',
opt_class=optim.Adadelta,
opt_class=torch.optim.Adadelta,
description='torch.optim.Adadelta, Adapts learning rates based on running windows of gradients'
),
OptimInfo(
name='adagrad',
opt_class=optim.Adagrad,
opt_class=torch.optim.Adagrad,
description='torch.optim.Adagrad, Adapts learning rates using cumulative squared gradients',
defaults={'eps': 1e-8}
),
@@ -617,7 +689,7 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
),
OptimInfo(
name='rmsprop',
opt_class=optim.RMSprop,
opt_class=torch.optim.RMSprop,
description='torch.optim.RMSprop, Root Mean Square Propagation',
has_momentum=True,
defaults={'alpha': 0.9}
@@ -765,6 +837,7 @@ def _register_default_optimizers() -> None:
_register_other_optimizers(default_registry)
_register_apex_optimizers(default_registry)
_register_bnb_optimizers(default_registry)
_register_cautious_optimizers(default_registry)

# Register aliases
default_registry.register_alias('nesterov', 'sgd')
@@ -839,7 +912,7 @@ def get_optimizer_info(name: str) -> OptimInfo:
def get_optimizer_class(
name: str,
bind_defaults: bool = True,
) -> Union[Type[optim.Optimizer], OptimizerCallable]:
) -> Union[OptimType, OptimizerCallable]:
"""Get optimizer class by name with option to bind default arguments.
Retrieves the optimizer class or a partial function with default arguments bound.
@@ -874,17 +947,17 @@ def get_optimizer_class(


def create_optimizer_v2(
model_or_params: Union[nn.Module, Params],
model_or_params: Union[nn.Module, ParamsT],
opt: str = 'sgd',
lr: Optional[float] = None,
weight_decay: float = 0.,
momentum: float = 0.9,
foreach: Optional[bool] = None,
filter_bias_and_bn: bool = True,
layer_decay: Optional[float] = None,
param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None,
**kwargs: Any,
) -> optim.Optimizer:
) -> torch.optim.Optimizer:
"""Create an optimizer instance via timm registry.
Creates and configures an optimizer with appropriate parameter groups and settings.
@@ -985,7 +1058,11 @@ def optimizer_kwargs(cfg):
return kwargs


def create_optimizer(args, model, filter_bias_and_bn=True):
def create_optimizer(
args,
model: Union[nn.Module, ParamsT],
filter_bias_and_bn: bool = True,
) -> torch.optim.Optimizer:
""" Legacy optimizer factory for backwards compatibility.
NOTE: Use create_optimizer_v2 for new code.
"""
25 changes: 25 additions & 0 deletions timm/optim/_types.py
@@ -0,0 +1,25 @@
from typing import Any, Dict, Iterable, Union, Protocol, Type
try:
from typing import TypeAlias, TypeVar
except ImportError:
from typing_extensions import TypeAlias, TypeVar

import torch
import torch.optim

try:
from torch.optim.optimizer import ParamsT
except (ImportError, TypeError):
ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]


OptimType = Type[torch.optim.Optimizer]


class OptimizerCallable(Protocol):
"""Protocol for optimizer constructor signatures."""

def __call__(self, params: ParamsT, **kwargs) -> torch.optim.Optimizer: ...


__all__ = ['ParamsT', 'OptimType', 'OptimizerCallable']
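
A short sketch of how downstream code might consume the shared types, accepting either a concrete optimizer class or a partially-bound constructor; the `build` helper here is illustrative and not part of this PR:

```python
from functools import partial
from typing import Union

import torch
import torch.nn as nn

from timm.optim._types import OptimType, OptimizerCallable, ParamsT


def build(opt: Union[OptimType, OptimizerCallable], params: ParamsT, **kwargs) -> torch.optim.Optimizer:
    # A class like torch.optim.AdamW and a functools.partial with bound defaults
    # both satisfy the OptimizerCallable protocol, so one call site handles both.
    return opt(params, **kwargs)


model = nn.Linear(8, 2)
build(torch.optim.AdamW, model.parameters(), lr=1e-3)
build(partial(torch.optim.SGD, momentum=0.9, nesterov=True), model.parameters(), lr=0.1)
```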