Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dadapt optimizers #1520

Open
wants to merge 2 commits into
base: pytorch
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .ci-cd/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ function test() {
alf.networks.q_networks_test \
alf.networks.relu_mlp_test \
alf.networks.value_networks_test \
alf.optimizers.dadapt_optimizers_test \
alf.optimizers.nero_plus_test \
alf.optimizers.optimizers_test \
alf.optimizers.trusted_updater_test \
Expand Down
2 changes: 2 additions & 0 deletions alf/optimizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from .optimizers import AdamW
from .optimizers import SGD
from .optimizers import NeroPlus
from .optimizers import DAdaptSGD
from .optimizers import DAdaptAdam

from typing import Any
Optimizer = Any
263 changes: 263 additions & 0 deletions alf/optimizers/dadapt_adam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree:
# https://github.com/facebookresearch/dadaptation/blob/main/LICENSE

import math
from typing import TYPE_CHECKING, Any, Callable, Optional

import torch
import pdb
import logging
import os
import torch.distributed as dist
from torch.optim import Optimizer

if TYPE_CHECKING:
from torch.optim.optimizer import _params_t
else:
_params_t = Any


class DAdaptAdam(Optimizer):
r"""
Implements Adam with D-Adaptation automatic step-sizes.
Leave LR set to 1 unless you encounter instability.

Arguments:
params (iterable):
Iterable of parameters to optimize or dicts defining parameter groups.
lr (float):
Learning rate adjustment parameter. Increases or decreases the D-adapted learning rate.
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float):
Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-8).
weight_decay (float):
Weight decay, i.e. a L2 penalty (default: 0).
log_every (int):
Log using print every k steps, default 0 (no logging).
decouple (boolean):
Use AdamW style decoupled weight decay
use_bias_correction (boolean):
Turn on Adam's bias correction. Off by default.
d0 (float):
Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
growth_rate (float):
prevent the D estimate from growing faster than this multiplicative rate.
Default is inf, for unrestricted. Values like 1.02 give a kind of learning
rate warmup effect.
fsdp_in_use (bool):
If you're using sharded parameters, this should be set to True. The optimizer
will attempt to auto-detect this, but if you're using an implementation other
than PyTorch's builtin version, the auto-detection won't work.
"""

def __init__(self,
params,
lr=1.0,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0,
log_every=0,
decouple=False,
use_bias_correction=False,
d0=1e-6,
growth_rate=float('inf'),
fsdp_in_use=False):
if not 0.0 < d0:
raise ValueError("Invalid d0 value: {}".format(d0))
if not 0.0 < lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 < eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(
betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(
betas[1]))

if decouple:
print(f"Using decoupled weight decay")

defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
d=d0,
k=0,
numerator_weighted=0.0,
log_every=log_every,
growth_rate=growth_rate,
use_bias_correction=use_bias_correction,
decouple=decouple,
fsdp_in_use=fsdp_in_use)
self.d0 = d0
super().__init__(params, defaults)

@property
def supports_memory_efficient_fp16(self):
return False

@property
def supports_flat_params(self):
return True

def step(self, closure=None):
"""Performs a single optimization step.

Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()

sk_l1 = 0.0

group = self.param_groups[0]
use_bias_correction = group['use_bias_correction']
numerator_weighted = group['numerator_weighted']
beta1, beta2 = group['betas']
k = group['k']

d = group['d']
lr = max(group['lr'] for group in self.param_groups)

if use_bias_correction:
bias_correction = ((1 - beta2**(k + 1))**0.5) / (1 - beta1**
(k + 1))
else:
bias_correction = 1

dlr = d * lr * bias_correction

growth_rate = group['growth_rate']
decouple = group['decouple']
log_every = group['log_every']
fsdp_in_use = group['fsdp_in_use']

sqrt_beta2 = beta2**(0.5)

numerator_acum = 0.0

for group in self.param_groups:
decay = group['weight_decay']
k = group['k']
eps = group['eps']
group_lr = group['lr']

if group_lr not in [lr, 0.0]:
raise RuntimeError(
f"Setting different lr values in different parameter groups is only supported for values of 0"
)

for p in group['params']:
if p.grad is None:
continue
if hasattr(p, "_fsdp_flattened"):
fsdp_in_use = True

grad = p.grad.data

# Apply weight decay (coupled variant)
if decay != 0 and not decouple:
grad.add_(p.data, alpha=decay)

state = self.state[p]

# State initialization
if 'step' not in state:
state['step'] = 0
state['s'] = torch.zeros_like(p.data).detach()
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data).detach()
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data).detach()

exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

s = state['s']

if group_lr > 0.0:
denom = exp_avg_sq.sqrt().add_(eps)
numerator_acum += dlr * torch.dot(
grad.flatten(),
s.div(denom).flatten()).item()

# Adam EMA updates
exp_avg.mul_(beta1).add_(grad, alpha=dlr * (1 - beta1))
exp_avg_sq.mul_(beta2).addcmul_(
grad, grad, value=1 - beta2)

s.mul_(sqrt_beta2).add_(grad, alpha=dlr * (1 - sqrt_beta2))
sk_l1 += s.abs().sum().item()

######

numerator_weighted = sqrt_beta2 * numerator_weighted + (
1 - sqrt_beta2) * numerator_acum
d_hat = d

# if we have not done any progres, return
# if we have any gradients available, will have sk_l1 > 0 (unless \|g\|=0)
if sk_l1 == 0:
return loss

if lr > 0.0:
if fsdp_in_use:
dist_tensor = torch.zeros(2).cuda()
dist_tensor[0] = numerator_weighted
dist_tensor[1] = sk_l1
dist.all_reduce(dist_tensor, op=dist.ReduceOp.SUM)
global_numerator_weighted = dist_tensor[0]
global_sk_l1 = dist_tensor[1]
else:
global_numerator_weighted = numerator_weighted
global_sk_l1 = sk_l1

d_hat = global_numerator_weighted / (
(1 - sqrt_beta2) * global_sk_l1)
d = max(d, min(d_hat, d * growth_rate))

if log_every > 0 and k % log_every == 0:
logging.info(
f"lr: {lr} dlr: {dlr} d_hat: {d_hat}, d: {d}. sk_l1={global_sk_l1:1.1e} numerator_weighted={global_numerator_weighted:1.1e}"
)

for group in self.param_groups:
group['numerator_weighted'] = numerator_weighted
group['d'] = d

decay = group['weight_decay']
k = group['k']
eps = group['eps']

for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data

state = self.state[p]

exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

state['step'] += 1

denom = exp_avg_sq.sqrt().add_(eps)

# Apply weight decay (decoupled variant)
if decay != 0 and decouple:
p.data.add_(p.data, alpha=-decay * dlr)

### Take step
p.data.addcdiv_(exp_avg, denom, value=-1)

group['k'] = k + 1

return loss
102 changes: 102 additions & 0 deletions alf/optimizers/dadapt_optimizers_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright (c) 2023 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import parameterized
from absl import logging
import torch
import torch.nn.functional as F

import alf

from alf.optimizers import DAdaptSGD, DAdaptAdam
from alf.utils.datagen import load_mnist


class DadaptOptimizersTest(parameterized.TestCase, alf.test.TestCase):
def test_dadapt_sgd(self):
train_set, test_set = load_mnist(train_bs=256, test_bs=256)
num_classes = len(train_set.dataset.classes)
model = alf.layers.Sequential(
alf.layers.Conv2D(1, 32, 3, strides=2, padding=1),
alf.layers.Conv2D(32, 32, 3, strides=2, padding=1),
alf.layers.Conv2D(32, 32, 3, strides=2, padding=1),
alf.layers.Reshape(-1),
alf.layers.FC(
4 * 4 * 32,
num_classes,
weight_opt_args=dict(
fixed_norm=False,
l2_regularization=1e-3,
zero_mean=True,
max_norm=float('inf'))))
opt = DAdaptSGD()
opt.add_param_group(dict(params=list(model.parameters())))

for epoch in range(5):
for data, target in train_set:
logits = model(data)
loss = F.cross_entropy(logits, target)
opt.zero_grad()
loss.backward()
opt.step()
correct = 0
total = 0
for data, target in test_set:
logits = model(data)
correct += (logits.argmax(dim=1) == target).sum()
total += target.numel()
logging.info("epoch=%s loss=%s acc=%s" % (epoch, loss.item(),
correct.item()))
self.assertGreater(correct / total, 0.97)

@parameterized.parameters((True), (False))
def test_dadapt_adam(self, decouple=False):
train_set, test_set = load_mnist(train_bs=256, test_bs=256)
num_classes = len(train_set.dataset.classes)
model = alf.layers.Sequential(
alf.layers.Conv2D(1, 32, 3, strides=2, padding=1),
alf.layers.Conv2D(32, 32, 3, strides=2, padding=1),
alf.layers.Conv2D(32, 32, 3, strides=2, padding=1),
alf.layers.Reshape(-1),
alf.layers.FC(
4 * 4 * 32,
num_classes,
weight_opt_args=dict(
fixed_norm=False,
l2_regularization=1e-3,
zero_mean=True,
max_norm=float('inf'))))
opt = DAdaptAdam(decouple=decouple)
opt.add_param_group(dict(params=list(model.parameters())))

for epoch in range(5):
for data, target in train_set:
logits = model(data)
loss = F.cross_entropy(logits, target)
opt.zero_grad()
loss.backward()
opt.step()
correct = 0
total = 0
for data, target in test_set:
logits = model(data)
correct += (logits.argmax(dim=1) == target).sum()
total += target.numel()
logging.info("epoch=%s loss=%s acc=%s" % (epoch, loss.item(),
correct.item()))
self.assertGreater(correct / total, 0.97)


if __name__ == '__main__':
alf.test.main()
Loading