✨Introduced Ranger
carefree0910 committed Apr 7, 2021
1 parent 4466c9f commit acdeec4
Showing 1 changed file with 114 additions and 0 deletions.
114 changes: 114 additions & 0 deletions cflearn/modules/optimizers.py
@@ -1,3 +1,4 @@
import math
import torch

from typing import *
@@ -133,4 +134,117 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
        return loss


@register_optimizer("ranger")
class Ranger(Optimizer):
    def __init__(
        self,
        params: Iterable[torch.Tensor],
        lr: float,
        alpha: float = 0.5,
        k: int = 6,
        n_sma_threshold: float = 5.0,  # 4.0
        betas: Tuple[float, float] = (0.95, 0.999),  # (0.90, 0.999)
        eps: float = 1e-5,
        weight_decay: float = 0.0,
        use_gc: bool = True,
        gc_conv_only: bool = False,
    ):
        defaults = dict(
            lr=lr,
            alpha=alpha,
            k=k,
            n_sma_threshold=n_sma_threshold,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            step_counter=0,
        )
        super().__init__(params, defaults)
        self.n_sma_threshold = n_sma_threshold
        self.alpha = alpha

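        # Cache of (step, n_sma, step_size) triples for the RAdam rectification,
        # indexed by step % 10 so parameters at the same step reuse the result.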
        self.radam_buffer = [[None, None, None] for _ in range(10)]
        self.use_gc = use_gc
        self.gc_gradient_threshold = 3 if gc_conv_only else 1

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    msg = "Ranger optimizer does not support sparse gradients"
                    raise RuntimeError(msg)

                p_data_fp32 = p.data.float()
                state = self.state[p]
                if len(state) != 0:
                    state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
                    state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)
                else:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p_data_fp32)
                    state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
                    state["slow_buffer"] = torch.empty_like(p.data)
                    state["slow_buffer"].copy_(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]

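                # Gradient centralization: subtract the mean over all dims except
                # the first; with gc_conv_only=True only conv-style weights
                # (rank > 3) are centralized, otherwise anything with rank > 1.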
                if self.use_gc and grad.dim() > self.gc_gradient_threshold:
                    grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True))

                state["step"] += 1

                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)

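                # RAdam rectification: reuse the cached step size when another
                # parameter has already reached this step, otherwise recompute it
                # from the length of the approximated SMA (n_sma).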
                buffered = self.radam_buffer[int(state["step"] % 10)]

                if state["step"] == buffered[0]:
                    n_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state["step"]
                    beta2_t = beta2 ** state["step"]
                    n_sma_max = 2.0 / (1.0 - beta2) - 1.0
                    n_sma = n_sma_max - 2.0 * state["step"] * beta2_t / (1.0 - beta2_t)
                    buffered[1] = n_sma
                    if n_sma <= self.n_sma_threshold:
                        step_size = 1.0 / (1.0 - beta1 ** state["step"])
                    else:
                        step_size = math.sqrt(
                            (1.0 - beta2_t) * (n_sma - 4.0)
                            / (n_sma_max - 4.0) * (n_sma - 2.0)
                            / n_sma * n_sma_max
                            / (n_sma_max - 2.0)
                        ) / (1.0 - beta1 ** state["step"])
                    buffered[2] = step_size

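                # Decoupled (AdamW-style) weight decay, applied directly to the
                # parameters rather than folded into the gradient.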
if group["weight_decay"] != 0:
p_data_fp32.add_(
p_data_fp32,
alpha=-group["weight_decay"] * group["lr"],
)

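                # While the variance estimate is still unreliable (n_sma below the
                # threshold) take a plain momentum step; otherwise apply the
                # rectified Adam update with the adaptive denominator.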
                if n_sma <= self.n_sma_threshold:
                    p_data_fp32.add_(exp_avg, alpha=-step_size * group["lr"])
                else:
                    denom = exp_avg_sq.sqrt().add_(group["eps"])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group["lr"])

                p.data.copy_(p_data_fp32)

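                # Lookahead: every k steps, move the slow weights a fraction alpha
                # towards the fast weights and sync the fast weights back to them.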
if state["step"] % group["k"] == 0:
slow_p = state["slow_buffer"]
slow_p.add_(p.data - slow_p, alpha=self.alpha)
p.data.copy_(slow_p)

        return loss


__all__ = ["optimizer_dict", "register_optimizer"]

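For reference, a minimal usage sketch of the new optimizer (the toy model and data below are hypothetical; the import path follows the changed file):

import torch
import torch.nn as nn

from cflearn.modules.optimizers import Ranger

# hypothetical toy model and data, just to exercise the optimizer
model = nn.Linear(10, 1)
opt = Ranger(model.parameters(), lr=1e-3)
x, y = torch.randn(32, 10), torch.randn(32, 1)

for _ in range(10):
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()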