✨Introduced MADGRAD
carefree0910 committed Apr 7, 2021
1 parent 7a40dcc commit 4466c9f
Showing 1 changed file with 73 additions and 0 deletions.
cflearn/modules/optimizers.py (73 additions, 0 deletions)
@@ -60,4 +60,77 @@ def step(self, closure: Optional[Callable] = None) -> Optional[Any]:
        return loss


@register_optimizer("madgrad")
class MADGRAD(Optimizer):
    def __init__(
        self,
        params: Iterable[torch.Tensor],
        lr: float,
        momentum: float = 0.0,
        weight_decay: float = 0.0,
        eps: float = 1.0e-6,
    ):
        defaults = dict(lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay)
        super().__init__(params, defaults)

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        loss = None
        if closure is not None:
            loss = closure()

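        # lazily create a global step counter "k", shared by all parameter groups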
if "k" not in self.state:
self.state["k"] = torch.tensor([0], dtype=torch.long)
k = self.state["k"].item()

        for group in self.param_groups:
            eps = group["eps"]
            lr = group["lr"] + eps
            decay = group["weight_decay"]
            momentum = group["momentum"]

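            # ck: EMA coefficient used when momentum is on; lb: lambda_k = lr * sqrt(k + 1)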
            ck = 1.0 - momentum
            lb = lr * math.sqrt(k + 1)

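            # per-parameter buffers: grad_sum_sq and s accumulate lambda-weighted squared
            # gradients and gradients; x0 is only stored explicitly when momentum is used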
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
msg = "MADGRAD optimizer does not support sparse gradients"
raise RuntimeError(msg)

state = self.state[p]
if "grad_sum_sq" not in state:
state["grad_sum_sq"] = torch.zeros_like(p.data).detach()
state["s"] = torch.zeros_like(p.data).detach()
if momentum != 0.0:
state["x0"] = torch.clone(p.data).detach()

                grad_sum_sq = state["grad_sum_sq"]
                s = state["s"]

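                # decoupled (AdamW-style) weight decay, applied directly to the parameters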
                if decay:
                    p.data.mul_(1.0 - lr * decay)

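                # without momentum, reconstruct the initial point x0 = p + s / rms on the fly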
                if momentum == 0.0:
                    rms = grad_sum_sq.pow(1.0 / 3.0).add_(eps)
                    x0 = p.data.addcdiv(s, rms, value=1.0)
                else:
                    x0 = state["x0"]

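                # accumulate lambda-weighted squared gradients, then rebuild the cube-root denominator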
                grad_sum_sq.addcmul_(grad, grad, value=lb)
                rms = grad_sum_sq.pow(1.0 / 3.0).add_(eps)

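                # accumulate the lambda-weighted gradient sum s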
                s.data.add_(grad, alpha=lb)

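                # dual-averaging step: move from x0 by -s / rms; with momentum, p tracks an EMA of z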
                if momentum == 0.0:
                    p.data.copy_(x0.addcdiv(s, rms, value=-1.0))
                else:
                    z = x0.addcdiv(s, rms, value=-1.0)
                    p.data.mul_(1.0 - ck).add_(z, alpha=ck)

        self.state["k"] += 1
        return loss


__all__ = ["optimizer_dict", "register_optimizer"]
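
For context (not part of the commit): a minimal usage sketch, assuming `register_optimizer` stores the class in `optimizer_dict` under the given name, as the `__all__` above suggests. The model, batch, and hyperparameters below are placeholders.

import torch
import torch.nn.functional as F

from cflearn.modules.optimizers import optimizer_dict

# hypothetical setup: a tiny model and one synthetic batch
model = torch.nn.Linear(4, 1)
opt = optimizer_dict["madgrad"](  # registered via @register_optimizer("madgrad")
    model.parameters(),
    lr=1.0e-3,
    momentum=0.9,
    weight_decay=1.0e-4,
)

x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = F.mse_loss(model(x), y)

opt.zero_grad()
loss.backward()
opt.step()  # lambda_k = lr * sqrt(k + 1) grows with the internal step counter k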
