-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathmy_optim.py
78 lines (62 loc) · 2.46 KB
/
my_optim.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
import numpy as np
class PolyOptimizer(optim.SGD):
    """SGD whose learning rate follows a polynomial decay schedule.

    Before every update, the learning rate of each parameter group is set to
    ``initial_lr * (1 - global_step / max_step) ** momentum``. Once
    ``global_step`` reaches ``max_step`` the learning rate is left untouched.

    Args:
        params: Iterable of parameters or of parameter-group dicts, as
            accepted by ``torch.optim.SGD``.
        lr: Initial learning rate.
        weight_decay: Weight decay forwarded to ``torch.optim.SGD``.
        max_step: Number of steps over which the learning rate is decayed.
        momentum: Exponent of the polynomial decay.
            NOTE(review): despite its name, this value is only used as the
            decay power and is not forwarded to SGD as a momentum factor --
            confirm that this is intended.
    """

    def __init__(self, params, lr, weight_decay, max_step, momentum=0.9):
        # weight_decay has to be passed by keyword: the third positional
        # parameter of SGD is `momentum`, not `weight_decay`.
        super().__init__(params, lr, weight_decay=weight_decay)
        # The base class has already turned `params` into `self.param_groups`
        # (a list of dicts with all defaults filled in), so it is used as is;
        # this also makes generators such as `model.parameters()` work.
        self.global_step = 0
        self.max_step = max_step
        self.momentum = momentum
        # Remember the starting rate of every group so that the decay is
        # always computed from the initial value, not the current one.
        self.__initial_lr = [group['lr'] for group in self.param_groups]

    def step(self, closure=None):
        """Apply the decayed learning rate, then perform one SGD update.

        Args:
            closure: Optional callable forwarded to ``SGD.step``.

        Returns:
            Whatever ``SGD.step`` returns for ``closure``.
        """
        if self.global_step < self.max_step:
            lr_mult = (1 - self.global_step / self.max_step) ** self.momentum
            for group, initial_lr in zip(self.param_groups, self.__initial_lr):
                group['lr'] = initial_lr * lr_mult
        loss = super().step(closure)
        self.global_step += 1
        return loss
def lr_poly(base_lr, iter, max_iter, power=0.9):
    """Return the polynomially decayed learning rate.

    Computes ``base_lr * (1 - iter / max_iter) ** power``.

    Args:
        base_lr: Learning rate at iteration 0.
        iter: Current iteration.
        max_iter: Iteration at which the learning rate reaches zero.
        power: Exponent of the decay.

    Returns:
        The decayed learning rate; ``0`` scaled by ``base_lr`` once ``iter``
        is at or beyond ``max_iter``.
    """
    # Clamp at zero: past max_iter the base would be negative, and a negative
    # float raised to a fractional power is a complex number in Python 3.
    remaining = max(0.0, 1 - float(iter) / max_iter)
    return base_lr * (remaining ** power)
def reduce_lr_poly(args, optimizer, global_iter, max_iter):
    """Overwrite every parameter group's learning rate with the poly schedule.

    The new rate is ``lr_poly(args.lr, global_iter, max_iter)`` with a decay
    power of 0.9, and is written to ``group['lr']`` of each group in
    ``optimizer.param_groups``.
    """
    starting_lr = args.lr
    for group in optimizer.param_groups:
        group['lr'] = lr_poly(starting_lr, global_iter, max_iter, 0.9)
def get_optimizer(args, model):
    """Build an SGD optimizer over the parameters not named ``*features*``.

    Every parameter of ``model`` whose name contains the substring
    ``'features'`` is left out. The optimizer uses ``args.lr`` as learning
    rate, momentum 0.9 and weight decay 0.0005.
    """
    base_lr = args.lr
    selected = []
    for name, para in model.named_parameters():
        if 'features' in name:
            continue
        selected.append(para)
    return optim.SGD(params=selected, lr=base_lr, momentum=0.9, weight_decay=0.0005)
def get_adam(args, model):
    """Build an Adam optimizer over all parameters of ``model``.

    The learning rate is taken from ``args.lr``; weight decay is 0.0005.
    """
    adam_settings = {'lr': args.lr, 'weight_decay': 0.0005}
    return optim.Adam(model.parameters(), **adam_settings)
def reduce_lr(args, optimizer, epoch, factor=0.1):
    """Scale the learning rate by ``factor`` when ``epoch`` is a decay point.

    Args:
        args: Object whose ``decay_points`` attribute is a comma-separated
            string of integer epochs, e.g. ``"30,40"``.
        optimizer: Object exposing ``param_groups``, a sequence of dicts with
            an ``'lr'`` entry.
        epoch: Current epoch.
        factor: Multiplier applied to every group's learning rate.

    Returns:
        ``True`` if the learning rate was reduced, ``False`` otherwise
        (including when ``decay_points`` cannot be parsed).
    """
    values = args.decay_points.strip().split(',')
    try:
        # Parse eagerly: a lazy ``map`` would defer the int() calls, and thus
        # the ValueError, until after this try block has been left.
        change_points = {int(value.strip()) for value in values}
    except ValueError:
        return False

    if epoch not in change_points:
        return False

    for g in optimizer.param_groups:
        g['lr'] = g['lr'] * factor
        print(epoch, g['lr'])
    return True
def adjust_lr(args, optimizer, epoch):
    """Set the learning rate from a dataset-specific step schedule.

    The first keyword below that occurs in ``args.dataset`` selects a list of
    change points; ``args.lr`` is multiplied by 0.1 once for every change
    point that ``epoch`` is strictly greater than. If no keyword matches,
    ``args.lr`` is used unchanged. The result is written to ``'lr'`` of every
    group in ``optimizer.param_groups``.
    """
    # Checked in order; the first matching keyword wins.
    schedules = (
        ('cifar', [80, 120, 160]),
        ('indoor', [60, 80, 100]),
        ('dog', [60, 80, 100]),
        ('voc', [30, 40]),
    )
    milestones = next(
        (points for keyword, points in schedules if keyword in args.dataset),
        None,
    )

    if milestones is None:
        new_lr = args.lr
    else:
        # Number of change points already passed by this epoch.
        passed = np.sum(epoch > np.array(milestones))
        new_lr = args.lr * (0.1 ** passed)

    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr