# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Collections of utilities related to optimization."""
from bisect import bisect_right
import torch


def update_ema(model, model_ema, decay):
    """Apply exponential moving average update.

    Based on the implementation in the MDETR repo: https://bit.ly/3J5fspI.
    The weights are updated in-place as follows:

        w_ema = w_ema * decay + (1 - decay) * w

    Args:
        model: active model that is being optimized
        model_ema: running average model
        decay: exponential decay parameter
    """
    with torch.no_grad():
        if hasattr(model, "module"):
            # unwrapping DDP
            model = model.module
        msd = model.state_dict()
        for k, ema_v in model_ema.state_dict().items():
            model_v = msd[k].detach()
            ema_v.copy_(ema_v * decay + (1.0 - decay) * model_v)
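

# Example usage (illustrative sketch): `model`, `optimizer`, `data_loader`, and
# `compute_loss`, as well as the decay value, are assumptions and not part of
# this module. The EMA copy is refreshed once per optimizer step.
#
#     import copy
#
#     model_ema = copy.deepcopy(model)
#     for batch in data_loader:
#         loss = compute_loss(model, batch)
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()
#         update_ema(model, model_ema, decay=0.999)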


def adjust_learning_rate(
    optimizer,
    epoch: int,
    curr_step: int,
    num_training_steps: int,
    args,
):
    """Adjust the lr according to the schedule.

    Based on the implementation in the MDETR repo: https://bit.ly/3J5fspI.

    Args:
        optimizer: torch optimizer to update.
        epoch(int): number of the current epoch.
        curr_step(int): number of optimization steps taken so far.
        num_training_steps(int): total number of optimization steps.
        args: additional training dependent args:
            - lr_drop(int): number of epochs before dropping the learning rate.
            - fraction_warmup_steps(float): fraction of steps over which the lr will be increased to its peak.
            - lr(float): base learning rate.
            - lr_backbone(float): learning rate of the backbone.
            - text_encoder_lr(float): learning rate of the text encoder.
            - schedule(str): the requested learning rate schedule:
                "step": all lrs divided by 10 after lr_drop epochs.
                "multistep": divided by 2 after lr_drop epochs, then by 2 after every 50 epochs.
                "linear_with_warmup": same as "step" for backbone + transformer, but for the text encoder,
                    linearly increase for a fraction of the training, then linearly decrease back to 0.
                "all_linear_with_warmup": same as "linear_with_warmup" for all learning rates involved.
    """
    num_warmup_steps: int = round(args.fraction_warmup_steps * num_training_steps)
    if args.schedule == "step":
        gamma = 0.1 ** (epoch // args.lr_drop)
        text_encoder_gamma = gamma
    elif args.schedule == "multistep":
        milestones = list(range(args.lr_drop, args.epochs, 50))
        gamma = 0.5 ** bisect_right(milestones, epoch)
        text_encoder_gamma = gamma
    elif args.schedule == "linear_with_warmup":
        gamma = 0.1 ** (epoch // args.lr_drop)
        if curr_step < num_warmup_steps:
            text_encoder_gamma = float(curr_step) / float(max(1, num_warmup_steps))
        else:
            text_encoder_gamma = max(
                0.0,
                float(num_training_steps - curr_step)
                / float(max(1, num_training_steps - num_warmup_steps)),
            )
    elif args.schedule == "all_linear_with_warmup":
        if curr_step < num_warmup_steps:
            text_encoder_gamma = float(curr_step) / float(max(1, num_warmup_steps))
        else:
            text_encoder_gamma = max(
                0.0,
                float(num_training_steps - curr_step)
                / float(max(1, num_training_steps - num_warmup_steps)),
            )
        gamma = text_encoder_gamma
    else:
        raise NotImplementedError
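
    # Worked example with assumed numbers (not taken from any config): if
    # num_training_steps = 1000 and fraction_warmup_steps = 0.1, then
    # num_warmup_steps = 100. Under "linear_with_warmup", at curr_step = 50 the
    # text-encoder factor is 50 / 100 = 0.5 (warmup ramp); at curr_step = 550 it
    # is (1000 - 550) / (1000 - 100) = 0.5 (linear decay back toward 0).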
    base_lrs = [args.lr, args.lr_backbone, args.text_encoder_lr]
    gammas = [gamma, gamma, text_encoder_gamma]
    assert len(optimizer.param_groups) == len(base_lrs)
    for param_group, lr, gamma_group in zip(optimizer.param_groups, base_lrs, gammas):
        param_group["lr"] = lr * gamma_group


def build_optimizer(model, args):
    """Build an AdamW optimizer with three param groups: the main model
    (everything that is neither backbone nor text encoder), the backbone, and
    the text encoder, each with its own learning rate. The group order matches
    the base_lrs order expected by adjust_learning_rate.
    """
    param_dicts = [
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if "backbone" not in n and "text_encoder" not in n and p.requires_grad
            ]
        },
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if "backbone" in n and p.requires_grad
            ],
            "lr": args.lr_backbone,
        },
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if "text_encoder" in n and p.requires_grad
            ],
            "lr": args.text_encoder_lr,
        },
    ]
    optimizer = torch.optim.AdamW(
        param_dicts, lr=args.lr, weight_decay=args.weight_decay
    )
    return optimizer
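

# Example of the fields this module reads from `args` (illustrative sketch; the
# values below are assumptions chosen only to show the expected fields, not
# recommended settings):
#
#     from argparse import Namespace
#
#     args = Namespace(
#         lr=1e-4,                      # base lr for the main param group
#         lr_backbone=1e-5,             # build_optimizer / adjust_learning_rate
#         text_encoder_lr=5e-5,         # build_optimizer / adjust_learning_rate
#         weight_decay=1e-4,            # build_optimizer
#         epochs=40,                    # "multistep" milestones
#         lr_drop=35,                   # "step" / "multistep" schedules
#         fraction_warmup_steps=0.01,   # warmup fraction for the "*_with_warmup" schedules
#         schedule="linear_with_warmup",
#     )
#     optimizer = build_optimizer(model, args)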