From 3086dd03fdd95bd767d943261130b00333aa1405 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 28 Nov 2024 12:34:51 -0800 Subject: [PATCH 1/3] Cautious optimizer impl plus some typing cleanup. --- tests/test_optim.py | 2 +- timm/optim/_optim_factory.py | 137 +++++++++++++++++++++++++++-------- timm/optim/_types.py | 25 +++++++ timm/optim/adafactor.py | 72 +++++++++++------- timm/optim/adafactor_bv.py | 19 ++++- timm/optim/adamw.py | 62 ++++++++-------- timm/optim/adopt.py | 38 ++++++++-- timm/optim/lamb.py | 72 ++++++++++-------- timm/optim/laprop.py | 23 ++++-- timm/optim/lion.py | 60 ++++++++++----- timm/optim/nadamw.py | 105 ++++++++++++++++++--------- timm/optim/rmsprop_tf.py | 67 ++++++++++------- timm/optim/sgdw.py | 83 ++++++++++++++------- 13 files changed, 526 insertions(+), 239 deletions(-) create mode 100644 timm/optim/_types.py diff --git a/tests/test_optim.py b/tests/test_optim.py index d9827ae84a..1ec227248f 100644 --- a/tests/test_optim.py +++ b/tests/test_optim.py @@ -298,7 +298,7 @@ def test_optim_factory(optimizer): assert isinstance(opt_info, OptimInfo) lr = (1e-2,) * 4 - if optimizer in ('mars',): + if optimizer in ('mars', 'nadam', 'claprop', 'crmsproptf', 'cadafactorbv', 'csgdw', 'clamb'): lr = (1e-3,) * 4 try: diff --git a/timm/optim/_optim_factory.py b/timm/optim/_optim_factory.py index f4784d7c05..ccf2b34cb2 100644 --- a/timm/optim/_optim_factory.py +++ b/timm/optim/_optim_factory.py @@ -5,15 +5,16 @@ import logging from dataclasses import dataclass from functools import partial -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union, Protocol, Iterator +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union from fnmatch import fnmatch import importlib import torch import torch.nn as nn -import torch.optim as optim +import torch.optim from ._param_groups import param_groups_layer_decay, param_groups_weight_decay +from ._types import ParamsT, OptimType, OptimizerCallable from .adabelief import AdaBelief from .adafactor import Adafactor from .adafactor_bv import AdafactorBigVision @@ -39,11 +40,6 @@ _logger = logging.getLogger(__name__) -# Type variables -T = TypeVar('T') -Params = Union[Iterator[nn.Parameter], Iterator[Dict[str, Any]]] -OptimType = TypeVar('OptimType', bound='optim.Optimizer') - def _import_class(class_string: str) -> Type: """Dynamically import a class from a string.""" @@ -55,11 +51,6 @@ def _import_class(class_string: str) -> Type: raise ImportError(f"Could not import {class_string}: {e}") -class OptimizerCallable(Protocol): - """Protocol for optimizer constructor signatures.""" - - def __call__(self, params: Params, **kwargs) -> optim.Optimizer: ... - @dataclass(frozen=True) class OptimInfo: @@ -76,7 +67,7 @@ class OptimInfo: defaults: Optional default parameters for the optimizer """ name: str - opt_class: Union[str, Type[optim.Optimizer]] + opt_class: Union[str, OptimType] description: str = '' has_eps: bool = True has_momentum: bool = False @@ -185,7 +176,7 @@ def get_optimizer_class( self, name_or_info: Union[str, OptimInfo], bind_defaults: bool = True, - ) -> Union[Type[optim.Optimizer], OptimizerCallable]: + ) -> Union[OptimType, OptimizerCallable]: """Get the optimizer class with any default arguments applied. 
This allows direct instantiation of optimizers with their default configs @@ -234,7 +225,7 @@ def get_optimizer_class( def create_optimizer( self, - model_or_params: Union[nn.Module, Params], + model_or_params: Union[nn.Module, ParamsT], opt: str, lr: Optional[float] = None, weight_decay: float = 0., @@ -242,9 +233,9 @@ def create_optimizer( foreach: Optional[bool] = None, weight_decay_exclude_1d: bool = True, layer_decay: Optional[float] = None, - param_group_fn: Optional[Callable[[nn.Module], Params]] = None, + param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None, **kwargs: Any, - ) -> optim.Optimizer: + ) -> torch.optim.Optimizer: """Create an optimizer instance. Args: @@ -347,7 +338,7 @@ def _register_sgd_variants(registry: OptimizerRegistry) -> None: sgd_optimizers = [ OptimInfo( name='sgd', - opt_class=optim.SGD, + opt_class=torch.optim.SGD, description='torch.Optim Stochastic Gradient Descent (SGD) with Nesterov momentum', has_eps=False, has_momentum=True, @@ -355,7 +346,7 @@ def _register_sgd_variants(registry: OptimizerRegistry) -> None: ), OptimInfo( name='momentum', - opt_class=optim.SGD, + opt_class=torch.optim.SGD, description='torch.Optim Stochastic Gradient Descent (SGD) with classical momentum', has_eps=False, has_momentum=True, @@ -386,13 +377,13 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None: adam_optimizers = [ OptimInfo( name='adam', - opt_class=optim.Adam, + opt_class=torch.optim.Adam, description='torch.optim.Adam, Adaptive Moment Estimation', has_betas=True ), OptimInfo( name='adamw', - opt_class=optim.AdamW, + opt_class=torch.optim.AdamW, description='torch.optim.AdamW, Adam with decoupled weight decay', has_betas=True ), @@ -448,7 +439,7 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None: ), OptimInfo( name='adamax', - opt_class=optim.Adamax, + opt_class=torch.optim.Adamax, description='torch.optim.Adamax, Adam with infinity norm for more stable updates', has_betas=True ), @@ -526,6 +517,87 @@ def _register_lamb_lars(registry: OptimizerRegistry) -> None: registry.register(opt) +def _register_cautious_optimizers(registry: OptimizerRegistry) -> None: + cautious_optimizers = [ + OptimInfo( + name='cadafactor', + opt_class=Adafactor, + description='Cautious Adafactor', + defaults={'caution': True} + ), + OptimInfo( + name='cadafactorbv', + opt_class=AdafactorBigVision, + description='Cautious Big Vision Adafactor', + defaults={'caution': True} + ), + OptimInfo( + name='cadamw', + opt_class=AdamWLegacy, + description='Cautious AdamW', + has_betas=True, + defaults={'caution': True} + ), + OptimInfo( + name='cadopt', + opt_class=Adopt, + description='Cautious Adopt', + defaults={'caution': True} + ), + OptimInfo( + name='cadoptw', + opt_class=Adopt, + description='Cautious AdoptW (decoupled decay)', + defaults={'decoupled': True, 'caution': True} + ), + OptimInfo( + name='clamb', + opt_class=Lamb, + description='Cautious LAMB', + has_betas=True, + defaults={'caution': True} + ), + OptimInfo( + name='claprop', + opt_class=LaProp, + description='Cautious LaProp', + has_betas=True, + defaults={'caution': True} + ), + OptimInfo( + name='clion', + opt_class=Lion, + description='Cautious Lion', + has_eps=False, + has_betas=True, + defaults = {'caution': True} + ), + OptimInfo( + name='cnadamw', + opt_class=NAdamW, + description='Cautious NAdamW', + has_betas=True, + defaults={'caution': True} + ), + OptimInfo( + name='crmsproptf', + opt_class=RMSpropTF, + description='Cautious TensorFlow-style RMSprop', + has_momentum=True, + 
defaults={'alpha': 0.9, 'caution': True} + ), + OptimInfo( + name='csgdw', + opt_class=SGDW, + description='Cautious SGD with decoupled weight decay and Nesterov momentum', + has_eps=False, + has_momentum=True, + defaults={'nesterov': True, 'caution': True} + ), + ] + for opt in cautious_optimizers: + registry.register(opt) + def _register_other_optimizers(registry: OptimizerRegistry) -> None: """Register miscellaneous optimizers""" other_optimizers = [ @@ -545,12 +617,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None: ), OptimInfo( name='adadelta', - opt_class=optim.Adadelta, + opt_class=torch.optim.Adadelta, description='torch.optim.Adadelta, Adapts learning rates based on running windows of gradients' ), OptimInfo( name='adagrad', - opt_class=optim.Adagrad, + opt_class=torch.optim.Adagrad, description='torch.optim.Adagrad, Adapts learning rates using cumulative squared gradients', defaults={'eps': 1e-8} ), @@ -617,7 +689,7 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None: ), OptimInfo( name='rmsprop', - opt_class=optim.RMSprop, + opt_class=torch.optim.RMSprop, description='torch.optim.RMSprop, Root Mean Square Propagation', has_momentum=True, defaults={'alpha': 0.9} @@ -765,6 +837,7 @@ def _register_default_optimizers() -> None: _register_other_optimizers(default_registry) _register_apex_optimizers(default_registry) _register_bnb_optimizers(default_registry) + _register_cautious_optimizers(default_registry) # Register aliases default_registry.register_alias('nesterov', 'sgd') @@ -839,7 +912,7 @@ def get_optimizer_info(name: str) -> OptimInfo: def get_optimizer_class( name: str, bind_defaults: bool = True, -) -> Union[Type[optim.Optimizer], OptimizerCallable]: +) -> Union[OptimType, OptimizerCallable]: """Get optimizer class by name with option to bind default arguments. Retrieves the optimizer class or a partial function with default arguments bound. @@ -874,7 +947,7 @@ def get_optimizer_class( def create_optimizer_v2( - model_or_params: Union[nn.Module, Params], + model_or_params: Union[nn.Module, ParamsT], opt: str = 'sgd', lr: Optional[float] = None, weight_decay: float = 0., @@ -882,9 +955,9 @@ def create_optimizer_v2( foreach: Optional[bool] = None, filter_bias_and_bn: bool = True, layer_decay: Optional[float] = None, - param_group_fn: Optional[Callable[[nn.Module], Params]] = None, + param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None, **kwargs: Any, -) -> optim.Optimizer: +) -> torch.optim.Optimizer: """Create an optimizer instance via timm registry. Creates and configures an optimizer with appropriate parameter groups and settings. @@ -985,7 +1058,11 @@ def optimizer_kwargs(cfg): return kwargs -def create_optimizer(args, model, filter_bias_and_bn=True): +def create_optimizer( + args, + model: Union[nn.Module, ParamsT], + filter_bias_and_bn: bool = True, +) -> torch.optim.Optimizer: """ Legacy optimizer factory for backwards compatibility. NOTE: Use create_optimizer_v2 for new code. 
""" diff --git a/timm/optim/_types.py b/timm/optim/_types.py new file mode 100644 index 0000000000..c24eddd108 --- /dev/null +++ b/timm/optim/_types.py @@ -0,0 +1,25 @@ +from typing import Any, Dict, Iterable, Union, Protocol, Type +try: + from typing import TypeAlias, TypeVar +except ImportError: + from typing_extensions import TypeAlias, TypeVar + +import torch +import torch.optim + +try: + from torch.optim.optimizer import ParamsT +except (ImportError, TypeError): + ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]] + + +OptimType = Type[torch.optim.Optimizer] + + +class OptimizerCallable(Protocol): + """Protocol for optimizer constructor signatures.""" + + def __call__(self, params: ParamsT, **kwargs) -> torch.optim.Optimizer: ... + + +__all__ = ['ParamsT', 'OptimType', 'OptimizerCallable'] \ No newline at end of file diff --git a/timm/optim/adafactor.py b/timm/optim/adafactor.py index 01c25ff2fb..e11b0a9f0b 100644 --- a/timm/optim/adafactor.py +++ b/timm/optim/adafactor.py @@ -10,8 +10,12 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import torch import math +from typing import Optional, Tuple + +import torch + +from ._types import ParamsT class Adafactor(torch.optim.Optimizer): @@ -26,33 +30,33 @@ class Adafactor(torch.optim.Optimizer): To use a manual (external) learning rate schedule you should set `scale_parameter=False` and `relative_step=False`. - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining parameter groups - lr (float, optional): external learning rate (default: None) - eps (tuple[float, float]): regularization constants for square gradient - and parameter scale respectively (default: (1e-30, 1e-3)) - clip_threshold (float): threshold of root mean square of final gradient update (default: 1.0) - decay_rate (float): coefficient used to compute running averages of square gradient (default: -0.8) - beta1 (float): coefficient used for computing running averages of gradient (default: None) - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) - scale_parameter (bool): if True, learning rate is scaled by root mean square of parameter (default: True) - warmup_init (bool): time-dependent learning rate computation depends on - whether warm-up initialization is being used (default: False) + Ags: + params: iterable of parameters to optimize or dicts defining parameter groups + lr: external learning rate + eps: regularization constants for square gradient and parameter scale respectively + eps_scale: regularization constants for parameter scale respectively + clip_threshold: threshold of root-mean-square of final gradient update + decay_rate: coefficient used to compute running averages of square gradient + beta1: coefficient used for computing running averages of gradient + weight_decay: weight decay + scale_parameter: if True, learning rate is scaled by root-mean-square of parameter + warmup_init: time-dependent learning rate computation depends on whether warm-up initialization is being used """ def __init__( self, - params, - lr=None, - eps=1e-30, - eps_scale=1e-3, - clip_threshold=1.0, - decay_rate=-0.8, - betas=None, - weight_decay=0.0, - scale_parameter=True, - warmup_init=False, - min_dim_size_to_factor=32, + params: ParamsT, + lr: Optional[float] = None, + eps: float = 1e-30, + eps_scale: float = 1e-3, + clip_threshold: float = 1.0, + decay_rate: float = -0.8, + betas: Optional[Tuple[float, float]] = None, + 
weight_decay: float = 0.0, + scale_parameter: bool = True, + warmup_init: bool = False, + min_dim_size_to_factor: int = 16, + caution: bool = False, ): relative_step = not lr if warmup_init and not relative_step: @@ -71,9 +75,16 @@ def __init__( relative_step=relative_step, warmup_init=warmup_init, min_dim_size_to_factor=min_dim_size_to_factor, + caution=caution, ) super(Adafactor, self).__init__(params, defaults) + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault('caution', False) + group.setdefault('min_dim_size_to_factor', 32) + @staticmethod def _get_lr(param_group, param_state): if param_group['relative_step']: @@ -86,7 +97,7 @@ def _get_lr(param_group, param_state): return param_group['lr'] @staticmethod - def _get_options(param_group, param_shape, min_size_to_factor=32): + def _get_options(param_group, param_shape, min_size_to_factor=16): use_first_moment = param_group['beta1'] is not None factored = None ndim = len(param_shape) @@ -98,7 +109,7 @@ def _get_options(param_group, param_shape, min_size_to_factor=32): # nD convs in torch are ND + 2 dim weights with leading in/out chs factored = 0, 1 elif ndim >= 2 and param_shape[-2] > min_size_to_factor and param_shape[-1] > min_size_to_factor: - # if the criteria above didn't match, test trailing dims for eligibility + # if the criteria above didn't match, test trailing dims for eligibility as per original impl factored = ndim - 2, ndim - 1 return factored, use_first_moment @@ -113,7 +124,6 @@ def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col, dim_col, dim_row): c_factor = exp_avg_sq_col.unsqueeze(dim_col).rsqrt() return torch.mul(r_factor, c_factor) - @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. 
@@ -201,7 +211,13 @@ def _remove_dim(shape, dim): if use_first_moment: exp_avg = state['exp_avg'] exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1']) - update = exp_avg + if group['caution']: + # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085 + mask = (exp_avg * grad > 0).to(grad.dtype) + mask.div_(mask.mean().clamp_(min=1e-3)) + update = exp_avg * mask + else: + update = exp_avg if group['weight_decay'] != 0: p_fp32.add_(p_fp32, alpha=-group['weight_decay'] * lr_t) diff --git a/timm/optim/adafactor_bv.py b/timm/optim/adafactor_bv.py index 3bb6e9592b..298d43bb7e 100644 --- a/timm/optim/adafactor_bv.py +++ b/timm/optim/adafactor_bv.py @@ -6,13 +6,14 @@ Adaptation and PyTorch modifications by Ross Wightman """ - from typing import List, Optional, Tuple, Union import torch from torch import Tensor from torch.optim import Optimizer +from ._types import ParamsT + def _get_scalar_dtype(): """Get the scalar dtype that the optimizer uses for state""" @@ -54,9 +55,9 @@ class AdafactorBigVision(Optimizer): def __init__( self, - params, + params: ParamsT, lr: float = 1.0, - min_dim_size_to_factor: int = 32, + min_dim_size_to_factor: int = 16, decay_rate: float = 0.8, decay_offset: int = 0, beta2_cap: float = 0.999, @@ -66,6 +67,7 @@ def __init__( weight_decay: float = 0.0, clipping_threshold: Optional[float] = None, unscaled_wd: bool = False, + caution: bool = False, *, foreach: Optional[bool] = False, ): @@ -91,6 +93,7 @@ def __init__( weight_decay=weight_decay, clipping_threshold=clipping_threshold, unscaled_wd=unscaled_wd, + caution=caution, foreach=foreach, ) super().__init__(params, defaults) @@ -98,6 +101,7 @@ def __init__( def __setstate__(self, state): super().__setstate__(state) for group in self.param_groups: + group.setdefault('caution', False) group.setdefault('foreach', None) for p in group['params']: p_state = self.state.get(p, {}) @@ -192,6 +196,7 @@ def step(self, closure=None): momentum_dtype=group['momentum_dtype'], clipping_threshold=group['clipping_threshold'], unscaled_wd=group['unscaled_wd'], + caution=group['caution'], ) return loss @@ -216,6 +221,7 @@ def _single_tensor_adafactor( momentum_dtype: Union[str, torch.dtype], clipping_threshold: Optional[float], unscaled_wd: bool, + caution: bool, ): for i, param in enumerate(params): grad = grads[i] @@ -267,6 +273,12 @@ def _single_tensor_adafactor( exp_avg.lerp_(update, 1 - momentum) # ema update = exp_avg.clone() + if caution: + # apply caution as per 'Cautious Optimizers': https://arxiv.org/abs/2411.16085 + mask = (update * grad > 0).to(grad.dtype) + mask.div_(mask.mean().clamp_(min=1e-3)) + update.mul_(mask) + # Scale by learning rate update.mul_(lr) @@ -302,6 +314,7 @@ def _multi_tensor_adafactor( momentum_dtype: Union[str, torch.dtype], clipping_threshold: Optional[float], unscaled_wd: bool, + caution: bool, ): # FIXME TODO assert False, 'multi-tensor fn (foreach=True) not implemented yet' diff --git a/timm/optim/adamw.py b/timm/optim/adamw.py index fe34609c82..07299ad63e 100644 --- a/timm/optim/adamw.py +++ b/timm/optim/adamw.py @@ -4,49 +4,45 @@ NOTE: This impl has been deprecated in favour of torch.optim.AdamW and remains as a reference """ import math +from typing import Tuple + import torch from torch.optim.optimizer import Optimizer +from ._types import ParamsT + class AdamWLegacy(Optimizer): r"""Implements AdamW algorithm. 
NOTE: This impl has been deprecated in favour of torch.optim.NAdam and remains as a reference - The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. - The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - amsgrad (boolean, optional): whether to use the AMSGrad variant of this - algorithm from the paper `On the Convergence of Adam and Beyond`_ - (default: False) - - .. _Adam\: A Method for Stochastic Optimization: - https://arxiv.org/abs/1412.6980 - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ + References: + - Adam: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 + - Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 + - On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ + + Args: + params: iterable of parameters to optimize or dicts defining parameter groups + lr: learning rate + betas: coefficients used for computing running averages of gradient and its square + eps: term added to the denominator to improve numerical stability + weight_decay: weight decay coefficient + amsgrad: whether to use the AMSGrad variant of this algorithm + from the paper `On the Convergence of Adam and Beyond` + caution: apply caution when using AdamW """ def __init__( self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2, - amsgrad=False, + params: ParamsT, + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 1e-2, + amsgrad: bool = False, + caution: bool = False, ): - # NOTE: deprecated in favour of builtin torch.optim.AdamW if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: @@ -61,6 +57,7 @@ def __init__( eps=eps, weight_decay=weight_decay, amsgrad=amsgrad, + caution=caution, ) super(AdamWLegacy, self).__init__(params, defaults) @@ -68,6 +65,7 @@ def __setstate__(self, state): super(AdamWLegacy, self).__setstate__(state) for group in self.param_groups: group.setdefault('amsgrad', False) + group.setdefault('caution', False) @torch.no_grad() def step(self, closure=None): @@ -131,6 +129,12 @@ def step(self, closure=None): step_size = group['lr'] / bias_correction1 + if group['caution']: + # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085 + mask = (exp_avg * grad > 0).to(grad.dtype) + mask.div_(mask.mean().clamp_(min=1e-3)) + exp_avg = exp_avg * mask + p.addcdiv_(exp_avg, denom, value=-step_size) return loss diff --git a/timm/optim/adopt.py b/timm/optim/adopt.py index 5cc7c18b99..6192990ed6 100644 --- a/timm/optim/adopt.py +++ b/timm/optim/adopt.py @@ -10,16 +10,15 @@ title = {ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate}, year = {2024} } - """ - -from typing import cast, Callable, List, Optional, Tuple, Union +from typing import cast, List, Optional, Tuple, Union import torch from torch import Tensor - from 
torch.optim.optimizer import Optimizer +from ._types import ParamsT + __all__ = ["Adopt", "adopt"] def _view_as_real(params, *state_and_grads): @@ -60,7 +59,7 @@ class Adopt(Optimizer): """ def __init__( self, - params, + params: ParamsT, lr: Union[float, Tensor] = 1e-3, betas: Tuple[float, float] = (0.9, 0.9999), eps: float = 1e-6, @@ -68,7 +67,8 @@ def __init__( weight_decay: float = 0.0, decoupled: bool = False, *, - foreach: Optional[bool] = None, + caution: bool = False, + foreach: Optional[bool] = False, maximize: bool = False, capturable: bool = False, differentiable: bool = False, @@ -98,6 +98,7 @@ def __init__( weight_decay=weight_decay, clip_exp=clip_exp, decoupled=decoupled, + caution=caution, maximize=maximize, foreach=foreach, capturable=capturable, @@ -105,7 +106,6 @@ def __init__( ) super().__init__(params, defaults) - def __setstate__(self, state): super().__setstate__(state) for group in self.param_groups: @@ -114,6 +114,7 @@ def __setstate__(self, state): group.setdefault("capturable", False) group.setdefault("differentiable", False) group.setdefault("clip_exp", None) + group.setdefault("caution", False) for p in group["params"]: p_state = self.state.get(p, []) if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): @@ -223,6 +224,7 @@ def step(self, closure=None): clip_exp=group["clip_exp"], decoupled=group["decoupled"], eps=group["eps"], + caution=group["caution"], maximize=group["maximize"], foreach=group["foreach"], capturable=group["capturable"], @@ -251,6 +253,7 @@ def _single_tensor_adopt( clip_exp: Optional[float], decoupled: bool, eps: float, + caution: bool, maximize: bool, capturable: bool, differentiable: bool, @@ -306,6 +309,13 @@ def _single_tensor_adopt( normed_grad.clamp_(-clip_val, clip_val) exp_avg.lerp_(normed_grad, 1 - beta1) + + if caution: + # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085 + mask = (exp_avg * grad > 0).to(grad.dtype) + mask.div_(mask.mean().clamp_(min=1e-3)) + exp_avg = exp_avg * mask + param.add_(exp_avg, alpha=-lr) exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) @@ -328,6 +338,7 @@ def _multi_tensor_adopt( clip_exp: Optional[float], decoupled: bool, eps: float, + caution: bool, maximize: bool, capturable: bool, differentiable: bool, @@ -403,6 +414,7 @@ def _multi_tensor_adopt( exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) torch._foreach_maximum_(exp_avg_sq_sqrt, eps) + normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt) if clip_exp is not None: @@ -411,6 +423,16 @@ def _multi_tensor_adopt( torch._foreach_minimum_(normed_grad, clip_val) torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1) + + if caution: + # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085 + masks = torch._foreach_mul(device_exp_avgs, device_grads) + masks = [(m > 0).to(g.dtype) for m, g in zip(masks, device_grads)] + mask_scale = [m.mean() for m in masks] + torch._foreach_maximum_(mask_scale, 1e-3) + torch._foreach_div_(masks, mask_scale) + device_exp_avgs = torch._foreach_mul(device_exp_avgs, masks) + torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr) torch._foreach_mul_(device_exp_avg_sqs, beta2) @@ -440,6 +462,7 @@ def adopt( clip_exp: Optional[float], decoupled: bool, eps: float, + caution: bool, maximize: bool, ): r"""Functional API that performs ADOPT algorithm computation. 
@@ -477,6 +500,7 @@ def adopt( clip_exp=clip_exp, decoupled=decoupled, eps=eps, + caution=caution, maximize=maximize, capturable=capturable, differentiable=differentiable, diff --git a/timm/optim/lamb.py b/timm/optim/lamb.py index 9d3a3421df..ee89225ec6 100644 --- a/timm/optim/lamb.py +++ b/timm/optim/lamb.py @@ -52,50 +52,48 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. import math +from typing import Optional, Tuple import torch from torch.optim import Optimizer +from ._types import ParamsT + class Lamb(Optimizer): """Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py - LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining parameter groups. - lr (float, optional): learning rate. (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its norm. (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability. (default: 1e-8) - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) - grad_averaging (bool, optional): whether apply (1-beta2) to grad when - calculating running averages of gradient. (default: True) - max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0) - trust_clip (bool): enable LAMBC trust ratio clipping (default: False) - always_adapt (boolean, optional): Apply adaptive learning rate to 0.0 - weight decay parameter (default: False) - - .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: - https://arxiv.org/abs/1904.00962 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ + LAMB was proposed in: + - Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: https://arxiv.org/abs/1904.00962 + - On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ + + Args: + params: Iterable of parameters to optimize or dicts defining parameter groups. + lr: Learning rate + betas: Coefficients used for computing running averages of gradient and its norm. + eps: Term added to the denominator to improve numerical stability. + weight_decay: Weight decay + grad_averaging: Whether apply (1-beta2) to grad when calculating running averages of gradient. + max_grad_norm: Value used to clip global grad norm. + trust_clip: Enable LAMBC trust ratio clipping. + always_adapt: Apply adaptive learning rate to 0.0 weight decay parameter. + caution: Apply caution. 
""" def __init__( self, - params, - lr=1e-3, - bias_correction=True, - betas=(0.9, 0.999), - eps=1e-6, - weight_decay=0.01, - grad_averaging=True, - max_grad_norm=1.0, - trust_clip=False, - always_adapt=False, + params: ParamsT, + lr: float = 1e-3, + bias_correction: bool = True, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-6, + weight_decay: float = 0.01, + grad_averaging: bool = True, + max_grad_norm: Optional[float] = 1.0, + trust_clip: bool = False, + always_adapt: bool = False, + caution: bool = False, ): defaults = dict( lr=lr, @@ -107,9 +105,15 @@ def __init__( max_grad_norm=max_grad_norm, trust_clip=trust_clip, always_adapt=always_adapt, + caution=caution, ) super().__init__(params, defaults) + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault('caution', False) + def _get_clip_grad_norm(self): max_grad_norm = self.defaults['max_grad_norm'] if max_grad_norm is None: @@ -187,6 +191,12 @@ def step(self, closure=None): denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) update = (exp_avg / bias_correction1).div_(denom) + if group['caution']: + # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085 + mask = (update * grad > 0).to(grad.dtype) + mask.div_(mask.mean().clamp_(min=1e-3)) + update.mul_(mask) + weight_decay = group['weight_decay'] if weight_decay != 0: update.add_(p, alpha=weight_decay) diff --git a/timm/optim/laprop.py b/timm/optim/laprop.py index fd760c398c..cdb30f587f 100644 --- a/timm/optim/laprop.py +++ b/timm/optim/laprop.py @@ -12,9 +12,13 @@ } """ +from typing import Tuple + from torch.optim import Optimizer import torch +from ._types import ParamsT + class LaProp(Optimizer): """ LaProp Optimizer @@ -23,11 +27,12 @@ class LaProp(Optimizer): """ def __init__( self, - params, - lr=4e-4, - betas=(0.9, 0.999), - eps=1e-15, - weight_decay=0, + params: ParamsT, + lr: float = 4e-4, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-15, + weight_decay: float = 0., + caution: bool = False, ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) @@ -42,6 +47,7 @@ def __init__( betas=betas, eps=eps, weight_decay=weight_decay, + caution=caution, ) super(LaProp, self).__init__(params, defaults) @@ -101,7 +107,14 @@ def step(self, closure=None): step_of_this_grad = grad / denom exp_avg.mul_(beta1).add_(step_of_this_grad, alpha=group['lr'] * one_minus_beta1) + if group['caution']: + # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085 + mask = (exp_avg * grad > 0).to(grad.dtype) + mask.div_(mask.mean().clamp_(min=1e-3)) + exp_avg = exp_avg * mask + p.add_(exp_avg, alpha=-step_size) + if group['weight_decay'] != 0: p.add_(p, alpha=-group['weight_decay']) diff --git a/timm/optim/lion.py b/timm/optim/lion.py index 3bcb273cac..e5847a4475 100644 --- a/timm/optim/lion.py +++ b/timm/optim/lion.py @@ -16,33 +16,35 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -from typing import List +from typing import List, Optional, Tuple import torch from torch.optim.optimizer import Optimizer +from ._types import ParamsT + class Lion(Optimizer): r"""Implements Lion algorithm.""" def __init__( self, - params, - lr=1e-4, - betas=(0.9, 0.99), - weight_decay=0.0, - maximize=False, - foreach=None, + params: ParamsT, + lr: float = 1e-4, + betas: Tuple[float, float] = (0.9, 0.99), + weight_decay: float = 0.0, + caution: bool = False, + maximize: bool = False, + foreach: Optional[bool] = None, ): """Initialize the hyperparameters. Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-4) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.99)) - weight_decay (float, optional): weight decay coefficient (default: 0) + params: iterable of parameters to optimize or dicts defining parameter groups + lr: learning rate + betas: coefficients used for computing running averages of gradient and its square + weight_decay: weight decay coefficient + caution: apply caution """ if not 0.0 <= lr: @@ -55,6 +57,7 @@ def __init__( lr=lr, betas=betas, weight_decay=weight_decay, + caution=caution, foreach=foreach, maximize=maximize, ) @@ -63,6 +66,7 @@ def __init__( def __setstate__(self, state): super().__setstate__(state) for group in self.param_groups: + group.setdefault('caution', False) group.setdefault('maximize', False) group.setdefault('foreach', None) @@ -71,8 +75,7 @@ def step(self, closure=None): """Performs a single optimization step. Args: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. + closure: A closure that reevaluates the model and returns the loss. Returns: the loss. @@ -112,6 +115,7 @@ def step(self, closure=None): beta2=beta2, lr=group['lr'], weight_decay=group['weight_decay'], + caution=group['caution'], maximize=group['maximize'], foreach=group['foreach'], ) @@ -132,6 +136,7 @@ def lion( beta2: float, lr: float, weight_decay: float, + caution: bool, ): r"""Functional API that performs Lion algorithm computation. 
""" @@ -155,6 +160,7 @@ def lion( beta2=beta2, lr=lr, weight_decay=weight_decay, + caution=caution, maximize=maximize, ) @@ -168,6 +174,7 @@ def _single_tensor_lion( beta2: float, lr: float, weight_decay: float, + caution: bool, maximize: bool, ): for i, param in enumerate(params): @@ -183,8 +190,15 @@ def _single_tensor_lion( param.mul_(1 - lr * weight_decay) # Weight update - update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1) - param.add_(torch.sign(update), alpha=-lr) + update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1).sign_() + + if caution: + # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085 + mask = (update * grad > 0).to(grad.dtype) + mask.div_(mask.mean().clamp_(min=1e-3)) + update.mul_(mask) + + param.add_(update, alpha=-lr) # Decay the momentum running average coefficient exp_avg.lerp_(grad, 1 - beta2) @@ -199,6 +213,7 @@ def _multi_tensor_lion( beta2: float, lr: float, weight_decay: float, + caution: bool, maximize: bool, ): if len(params) == 0: @@ -217,8 +232,17 @@ def _multi_tensor_lion( # Weight update updates = torch._foreach_mul(exp_avgs, beta1) torch._foreach_add_(updates, grads, alpha=1 - beta1) + updates = [u.sign_() for u in updates] + + if caution: + # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085 + masks = torch._foreach_mul(updates, grads) + masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)] + mask_scale = [m.mean() for m in masks] + torch._foreach_maximum_(mask_scale, 1e-3) + torch._foreach_div_(masks, mask_scale) + torch._foreach_mul_(updates, masks) - updates = [u.sign() for u in updates] torch._foreach_add_(params, updates, alpha=-lr) # Decay the momentum running average coefficient diff --git a/timm/optim/nadamw.py b/timm/optim/nadamw.py index c823f3d5b2..b98d8a0fda 100644 --- a/timm/optim/nadamw.py +++ b/timm/optim/nadamw.py @@ -5,44 +5,43 @@ Added multi-tensor (foreach) path. """ import math -from typing import List, Optional +from typing import List, Optional, Tuple import torch from torch import Tensor +from ._types import ParamsT + # Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py. class NAdamW(torch.optim.Optimizer): - r"""Implements NAdamW algorithm. - - See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of - the NAdam algorithm (there is also a comment in the code which highlights - the only difference of NAdamW and AdamW). - For further details regarding the algorithm we refer to - `Decoupled Weight Decay Regularization`_. - - Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 1e-2) - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ + """ Implements NAdamW algorithm. + + See Table 1 in https://arxiv.org/abs/1910.05446 for the implementation of + the NAdam algorithm (there is also a comment in the code which highlights + the only difference of NAdamW and AdamW). 
+ + For further details regarding the algorithm we refer to + - Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 + - On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ + + Args: + params: iterable of parameters to optimize or dicts defining parameter groups + lr: learning rate + betas: coefficients used for computing running averages of gradient and its square + eps: term added to the denominator to improve numerical stability + weight_decay: weight decay coefficient + caution: enable caution """ def __init__( self, - params, - lr=1e-3, - betas=(0.9, 0.999), - eps=1e-8, - weight_decay=1e-2, + params: ParamsT, + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 1e-2, + caution: bool = False, maximize: bool = False, foreach: Optional[bool] = None, capturable: bool = False, @@ -62,6 +61,7 @@ def __init__( betas=betas, eps=eps, weight_decay=weight_decay, + caution=caution, foreach=foreach, maximize=maximize, capturable=capturable, @@ -71,11 +71,12 @@ def __init__( def __setstate__(self, state): super().__setstate__(state) state_values = list(self.state.values()) - step_is_tensor = (len(state_values) != 0) and torch.is_tensor( - state_values[0]['step']) + step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step']) if not step_is_tensor: for s in state_values: s['step'] = torch.tensor(float(s['step'])) + for group in self.param_groups: + group.setdefault('caution', False) @torch.no_grad() def step(self, closure=None): @@ -133,6 +134,7 @@ def step(self, closure=None): lr=group['lr'], weight_decay=group['weight_decay'], eps=group['eps'], + caution=group['caution'], maximize=group['maximize'], capturable=group['capturable'], ) @@ -154,6 +156,7 @@ def nadamw( lr: float, weight_decay: float, eps: float, + caution: bool, maximize: bool, ) -> None: r"""Functional API that performs NAdamW algorithm computation. @@ -183,6 +186,7 @@ def nadamw( lr=lr, weight_decay=weight_decay, eps=eps, + caution=caution, maximize=maximize, capturable=capturable, ) @@ -200,6 +204,7 @@ def _single_tensor_nadamw( lr: float, weight_decay: float, eps: float, + caution: bool, maximize: bool, capturable: bool ): @@ -238,6 +243,14 @@ def _single_tensor_nadamw( exp_avg = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1) denom = (exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)).add_(eps / step_size_neg) + + if caution: + # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085 + # FIXME not 100% sure if this remains capturable? + mask = (exp_avg * grad > 0).to(grad.dtype) + mask.div_(mask.mean().clamp_(min=1e-3)) + exp_avg.mul_(mask) + param.addcdiv_(exp_avg, denom) else: step = step_t.item() @@ -246,11 +259,17 @@ def _single_tensor_nadamw( step_size = lr / bias_correction1 bias_correction2_sqrt = math.sqrt(bias_correction2) - # Only difference between NAdamW and AdamW in this implementation. + # Apply Nesterov. Only difference between NAdamW and AdamW in this implementation. # The official PyTorch implementation of NAdam uses a different algorithm. 
exp_avg = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1) - denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + + if caution: + # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085 + mask = (exp_avg * grad > 0).to(grad.dtype) + mask.div_(mask.mean().clamp_(min=1e-3)) + exp_avg.mul_(mask) + param.addcdiv_(exp_avg, denom, value=-step_size) @@ -266,6 +285,7 @@ def _multi_tensor_nadamw( lr: float, weight_decay: float, eps: float, + caution: bool, maximize: bool, capturable: bool, ): @@ -322,12 +342,22 @@ def _multi_tensor_nadamw( exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sqs) torch._foreach_div_( - exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size) + exp_avg_sq_sqrt, + torch._foreach_mul(bias_correction2_sqrt, step_size) ) eps_over_step_size = torch._foreach_div(step_size, eps) torch._foreach_reciprocal_(eps_over_step_size) denom = torch._foreach_add(exp_avg_sq_sqrt, eps_over_step_size) + if caution: + # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085 + masks = torch._foreach_mul(exp_avgs, grads) + masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)] # capturable? + mask_scale = [m.mean() for m in masks] + torch._foreach_maximum_(mask_scale, 1e-3) + torch._foreach_div_(masks, mask_scale) + torch._foreach_mul_(exp_avgs, masks) + torch._foreach_addcdiv_(params, exp_avgs, denom) else: bias_correction1 = [1 - beta1 ** step.item() for step in state_steps] @@ -337,7 +367,7 @@ def _multi_tensor_nadamw( bias_correction2_sqrt = [math.sqrt(bc) for bc in bias_correction2] - # Only difference between NAdamW and AdamW in this implementation. + # Apply Nesterov. Only difference between NAdamW and AdamW in this implementation. # The official PyTorch implementation of NAdam uses a different algorithm. exp_avgs = torch._foreach_mul(exp_avgs, beta1) torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1) @@ -346,4 +376,13 @@ def _multi_tensor_nadamw( torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt) denom = torch._foreach_add(exp_avg_sq_sqrt, eps) + if caution: + # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085 + masks = torch._foreach_mul(exp_avgs, grads) + masks = [(m > 0).to(g.dtype) for m, g in zip(masks, grads)] + mask_scale = [m.mean() for m in masks] + torch._foreach_maximum_(mask_scale, 1e-3) + torch._foreach_div_(masks, mask_scale) + torch._foreach_mul_(exp_avgs, masks) + torch._foreach_addcdiv_(params, exp_avgs, denom, step_size) diff --git a/timm/optim/rmsprop_tf.py b/timm/optim/rmsprop_tf.py index 8511b3b482..07b0279c85 100644 --- a/timm/optim/rmsprop_tf.py +++ b/timm/optim/rmsprop_tf.py @@ -10,6 +10,8 @@ import torch from torch.optim import Optimizer +from ._types import ParamsT + class RMSpropTF(Optimizer): """Implements RMSprop algorithm (TensorFlow style epsilon) @@ -28,34 +30,31 @@ class RMSpropTF(Optimizer): The centered version first appears in `Generating Sequences With Recurrent Neural Networks `_. 
- Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-2) - momentum (float, optional): momentum factor (default: 0) - alpha (float, optional): smoothing (decay) constant (default: 0.9) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-10) - centered (bool, optional) : if ``True``, compute the centered RMSProp, - the gradient is normalized by an estimation of its variance - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) - decoupled_decay (bool, optional): decoupled weight decay as per https://arxiv.org/abs/1711.05101 - lr_in_momentum (bool, optional): learning rate scaling is included in the momentum buffer - update as per defaults in Tensorflow - + Args: + params: iterable of parameters to optimize or dicts defining parameter groups + lr: learning rate + momentum: momentum factor + alpha: smoothing (decay) constant + eps: term added to the denominator to improve numerical stability + centered: if ``True``, compute the centered RMSProp, the gradient is normalized by an estimation of its variance + weight_decay: weight decay (L2 penalty) (default: 0) + decoupled_decay: decoupled weight decay as per https://arxiv.org/abs/1711.05101 + lr_in_momentum: learning rate scaling is included in the momentum buffer update as per defaults in Tensorflow + caution: apply caution """ def __init__( self, - params, - lr=1e-2, - alpha=0.9, - eps=1e-10, - weight_decay=0, - momentum=0., - centered=False, - decoupled_decay=False, - lr_in_momentum=True, + params: ParamsT, + lr: float = 1e-2, + alpha: float = 0.9, + eps: float = 1e-10, + weight_decay: float = 0, + momentum: float = 0., + centered: bool = False, + decoupled_decay: bool = False, + lr_in_momentum: bool = True, + caution: bool = False, ): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) @@ -77,6 +76,7 @@ def __init__( weight_decay=weight_decay, decoupled_decay=decoupled_decay, lr_in_momentum=lr_in_momentum, + caution=caution, ) super(RMSpropTF, self).__init__(params, defaults) @@ -85,6 +85,7 @@ def __setstate__(self, state): for group in self.param_groups: group.setdefault('momentum', 0) group.setdefault('centered', False) + group.setdefault('caution', False) @torch.no_grad() def step(self, closure=None): @@ -142,13 +143,25 @@ def step(self, closure=None): if group['momentum'] > 0: buf = state['momentum_buffer'] - # Tensorflow accumulates the LR scaling in the momentum buffer + buf.mul_(group['momentum']) + + def _apply_caution(_m, _g): + # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085 + mask = (_m * _g > 0).to(_g.dtype) + mask.div_(mask.mean().clamp_(min=1e-3)) + return _m * mask + if group['lr_in_momentum']: - buf.mul_(group['momentum']).addcdiv_(grad, avg, value=group['lr']) + # Tensorflow accumulates the LR scaling in the momentum buffer + buf.addcdiv_(grad, avg, value=group['lr']) + if group['caution']: + buf = _apply_caution(buf, grad) p.add_(-buf) else: # PyTorch scales the param update by LR - buf.mul_(group['momentum']).addcdiv_(grad, avg) + buf.addcdiv_(grad, avg) + if group['caution']: + buf = _apply_caution(buf, grad) p.add_(buf, alpha=-group['lr']) else: p.addcdiv_(grad, avg, value=-group['lr']) diff --git a/timm/optim/sgdw.py b/timm/optim/sgdw.py index c5b44063d6..b771c43c67 100644 --- a/timm/optim/sgdw.py +++ b/timm/optim/sgdw.py @@ -1,4 +1,5 @@ -from functools import update_wrapper, wraps +from typing 
import List, Optional + import torch from torch import Tensor from torch.optim.optimizer import Optimizer @@ -8,7 +9,7 @@ except ImportError: has_recent_pt = False -from typing import List, Optional +from ._types import ParamsT __all__ = ['SGDW', 'sgdw'] @@ -16,13 +17,14 @@ class SGDW(Optimizer): def __init__( self, - params, - lr=1e-3, - momentum=0, - dampening=0, - weight_decay=0, - nesterov=False, + params: ParamsT, + lr: float = 1e-3, + momentum: float = 0., + dampening: float = 0., + weight_decay: float = 0., + nesterov: bool = False, *, + caution: bool = False, maximize: bool = False, foreach: Optional[bool] = None, differentiable: bool = False, @@ -40,6 +42,7 @@ def __init__( dampening=dampening, weight_decay=weight_decay, nesterov=nesterov, + caution=caution, maximize=maximize, foreach=foreach, differentiable=differentiable, @@ -51,18 +54,19 @@ def __init__( def __setstate__(self, state): super().__setstate__(state) for group in self.param_groups: + group.setdefault('caution', False) group.setdefault('nesterov', False) group.setdefault('maximize', False) group.setdefault('foreach', None) group.setdefault('differentiable', False) - def _init_group(self, group, params_with_grad, d_p_list, momentum_buffer_list): + def _init_group(self, group, params_with_grad, grads, momentum_buffer_list): has_sparse_grad = False for p in group['params']: if p.grad is not None: params_with_grad.append(p) - d_p_list.append(p.grad) + grads.append(p.grad) if p.grad.is_sparse: has_sparse_grad = True @@ -91,20 +95,21 @@ def step(self, closure=None): for group in self.param_groups: params_with_grad = [] - d_p_list = [] + grads = [] momentum_buffer_list = [] - has_sparse_grad = self._init_group(group, params_with_grad, d_p_list, momentum_buffer_list) + has_sparse_grad = self._init_group(group, params_with_grad, grads, momentum_buffer_list) sgdw( params_with_grad, - d_p_list, + grads, momentum_buffer_list, weight_decay=group['weight_decay'], momentum=group['momentum'], lr=group['lr'], dampening=group['dampening'], nesterov=group['nesterov'], + caution=group['caution'], maximize=group['maximize'], has_sparse_grad=has_sparse_grad, foreach=group['foreach'], @@ -120,7 +125,7 @@ def step(self, closure=None): def sgdw( params: List[Tensor], - d_p_list: List[Tensor], + grads: List[Tensor], momentum_buffer_list: List[Optional[Tensor]], # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 # setting this as kwarg for now as functional API is compiled by torch/distributed/optim @@ -132,6 +137,7 @@ def sgdw( lr: float, dampening: float, nesterov: bool, + caution: bool, maximize: bool ): r"""Functional API that performs SGD algorithm computation. @@ -159,13 +165,14 @@ def sgdw( func( params, - d_p_list, + grads, momentum_buffer_list, weight_decay=weight_decay, momentum=momentum, lr=lr, dampening=dampening, nesterov=nesterov, + caution=caution, has_sparse_grad=has_sparse_grad, maximize=maximize, ) @@ -173,7 +180,7 @@ def sgdw( def _single_tensor_sgdw( params: List[Tensor], - d_p_list: List[Tensor], + grads: List[Tensor], momentum_buffer_list: List[Optional[Tensor]], *, weight_decay: float, @@ -181,11 +188,12 @@ def _single_tensor_sgdw( lr: float, dampening: float, nesterov: bool, + caution: bool, maximize: bool, has_sparse_grad: bool ): for i, param in enumerate(params): - d_p = d_p_list[i] if not maximize else -d_p_list[i] + grad = grads[i] if not maximize else -grads[i] param.mul_(1. 
- lr * weight_decay) @@ -193,17 +201,25 @@ def _single_tensor_sgdw( buf = momentum_buffer_list[i] if buf is None: - buf = torch.clone(d_p).detach() + buf = torch.clone(grad).detach() momentum_buffer_list[i] = buf else: - buf.mul_(momentum).add_(d_p, alpha=1 - dampening) - - if nesterov: - d_p = d_p.add(buf, alpha=momentum) + buf.mul_(momentum).add_(grad, alpha=1 - dampening) + + if caution: + if nesterov: + buf = grad.add(buf, alpha=momentum) + # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085 + mask = (buf * grad > 0).to(grad.dtype) + mask.div_(mask.mean().clamp_(min=1e-3)) + grad = buf * mask else: - d_p = buf + if nesterov: + grad = grad.add(buf, alpha=momentum) + else: + grad = buf - param.add_(d_p, alpha=-lr) + param.add_(grad, alpha=-lr) def _multi_tensor_sgdw( @@ -216,6 +232,7 @@ def _multi_tensor_sgdw( lr: float, dampening: float, nesterov: bool, + caution: bool, maximize: bool, has_sparse_grad: bool ): @@ -258,10 +275,22 @@ def _multi_tensor_sgdw( bufs.append(buf) - if nesterov: - torch._foreach_add_(device_grads, bufs, alpha=momentum) + if caution: + if nesterov: + # Can't do nesterov in-place if we want to compare against orig grad for caution + bufs = torch._foreach_add(device_grads, bufs, alpha=momentum) + # Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085 + masks = torch._foreach_mul(bufs, device_grads) + masks = [(m > 0).to(g.dtype) for m, g in zip(masks, device_grads)] + mask_scale = [m.mean() for m in masks] + torch._foreach_maximum_(mask_scale, 1e-3) + torch._foreach_div_(masks, mask_scale) + device_grads = torch._foreach_mul(bufs, masks) else: - device_grads = bufs + if nesterov: + torch._foreach_add_(device_grads, bufs, alpha=momentum) + else: + device_grads = bufs if not device_has_sparse_grad: torch._foreach_add_(device_params, device_grads, alpha=-lr) From b0a121bed055092dc7f82067ac0a58aa879e54e9 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 28 Nov 2024 13:39:44 -0800 Subject: [PATCH 2/3] Work around _foreach_maximum issue, need scalar other support --- timm/optim/lion.py | 7 +++++-- timm/optim/nadamw.py | 7 ++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/timm/optim/lion.py b/timm/optim/lion.py index e5847a4475..980a071322 100644 --- a/timm/optim/lion.py +++ b/timm/optim/lion.py @@ -141,8 +141,11 @@ def lion( r"""Functional API that performs Lion algorithm computation. 
""" if foreach is None: - # Placeholder for more complex foreach logic to be added when value is not set - foreach = True + try: + # cannot do foreach if this overload doesn't exist when caution enabled + foreach = not caution or 'Scalar' in torch.ops.aten._foreach_maximum.overloads() + except: + foreach = False if foreach and torch.jit.is_scripting(): raise RuntimeError('torch.jit.script not supported with foreach optimizers') diff --git a/timm/optim/nadamw.py b/timm/optim/nadamw.py index b98d8a0fda..17eb6fd091 100644 --- a/timm/optim/nadamw.py +++ b/timm/optim/nadamw.py @@ -169,7 +169,12 @@ def nadamw( ' singleton tensors') if foreach is None: - foreach = True + try: + # cannot do foreach if this overload doesn't exist when caution enabled + foreach = not caution or 'Scalar' in torch.ops.aten._foreach_maximum.overloads() + except: + foreach = False + if foreach and not torch.jit.is_scripting(): func = _multi_tensor_nadamw else: From 9b27f848760ceff70ee6a70771be6a137e11790b Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 28 Nov 2024 13:46:17 -0800 Subject: [PATCH 3/3] To be technically correct, need to check the in-place _ ver of op --- timm/optim/lion.py | 2 +- timm/optim/nadamw.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/optim/lion.py b/timm/optim/lion.py index 980a071322..1860723203 100644 --- a/timm/optim/lion.py +++ b/timm/optim/lion.py @@ -143,7 +143,7 @@ def lion( if foreach is None: try: # cannot do foreach if this overload doesn't exist when caution enabled - foreach = not caution or 'Scalar' in torch.ops.aten._foreach_maximum.overloads() + foreach = not caution or 'Scalar' in torch.ops.aten._foreach_maximum_.overloads() except: foreach = False diff --git a/timm/optim/nadamw.py b/timm/optim/nadamw.py index 17eb6fd091..d9933026c6 100644 --- a/timm/optim/nadamw.py +++ b/timm/optim/nadamw.py @@ -171,7 +171,7 @@ def nadamw( if foreach is None: try: # cannot do foreach if this overload doesn't exist when caution enabled - foreach = not caution or 'Scalar' in torch.ops.aten._foreach_maximum.overloads() + foreach = not caution or 'Scalar' in torch.ops.aten._foreach_maximum_.overloads() except: foreach = False