optimizer.py
import torch
from torch.optim.optimizer import Optimizer

class ArmijoSGD(Optimizer):
    """Implements SGD with an Armijo backtracking line search (see
    `backtracking line search
    <https://en.wikipedia.org/wiki/Backtracking_line_search#CITEREFArmijo1966>`_).

    Args:
        lr (float): maximum allowed step size for the line search (default: 1)
        max_iter (int): maximal number of line-search iterations per
            optimization step (default: 1000)
        tau (float): backtracking shrink factor for the step size (default: 0.5)
        c (float): sufficient-decrease constant in the Armijo condition
            (default: 0.5)
    """

    def __init__(self,
                 params,
                 lr=1,
                 max_iter=1000,
                 tau=.5,
                 c=.5):
        defaults = dict(
            lr=lr,
            max_iter=max_iter,
            tau=tau,
            c=c)
        super(ArmijoSGD, self).__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError("ArmijoSGD doesn't support per-parameter options "
                             "(parameter groups)")

        self._params = self.param_groups[0]['params']
        self.curr_lr = self.param_groups[0]['lr']

    def _clone_param(self):
        return [p.clone(memory_format=torch.contiguous_format) for p in self._params]

    def _set_param(self, params_data):
        for p, pdata in zip(self._params, params_data):
            p.copy_(pdata)

    @torch.no_grad()
    def step(self, closure):
        """Performs a single optimization step.

        Args:
            closure (Callable): A closure that reevaluates the model
                and returns the loss.
        """
        assert len(self.param_groups) == 1

        # Make sure the closure is always called with grad enabled
        closure = torch.enable_grad()(closure)

        group = self.param_groups[0]
        lr = group['lr']
        max_iter = group['max_iter']
        tau = group['tau']
        c = group['c']

        # Evaluate the initial f(x) and df/dx. The line search below assumes a
        # single flat parameter tensor, so only self._params[0] is used.
        orig_loss = closure()
        # f(x): original loss value
        orig_loss_float = float(orig_loss)
        # p: descent direction is the negative gradient
        descent_direction = -1 * self._params[0].grad.detach().clone()
        # m: directional derivative along p (the negative squared gradient norm)
        m = torch.dot(descent_direction, self._params[0].grad)
        # t: sufficient-decrease slope -c*m (positive); the Armijo condition is
        # f(x) - f(x + alpha*p) >= alpha*t
        t = -1 * c * m

        x_init = self._clone_param()[0]
        x = self._params[0]

        # Phase 1: grow the step size by factors of 1/tau until the Armijo
        # condition fails or the step exceeds the configured maximum lr.
        n_iter = 0
        while n_iter < max_iter:
            alpha = self.curr_lr * ((1 / tau) ** n_iter)
            # Trial gradient step
            x.copy_(x_init + alpha * descent_direction)
            # Updated loss
            loss = closure()
            # Armijo condition violated (or step too large): back off one factor of tau
            if orig_loss_float - float(loss) < alpha * t or alpha > lr:
                self.curr_lr = alpha * tau
                break
            n_iter += 1

        # Phase 2: standard backtracking -- shrink the step size by factors of
        # tau until the Armijo sufficient-decrease condition holds.
        n_iter = 0
        while n_iter < max_iter:
            alpha = self.curr_lr * (tau ** n_iter)
            # Trial gradient step
            x.copy_(x_init + alpha * descent_direction)
            # Updated loss
            loss = closure()
            # Armijo condition satisfied: accept this step size
            if orig_loss_float - float(loss) >= alpha * t:
                self.curr_lr = alpha
                break
            n_iter += 1

        return orig_loss
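

# A minimal usage sketch (illustrative, not part of the original module):
# ArmijoSGD, like torch.optim.LBFGS, expects a closure that recomputes the
# loss and its gradient on every call, and the line search assumes a single
# flat (1-D) parameter tensor. The toy least-squares problem and the names
# A, b, and w below are hypothetical.
if __name__ == "__main__":
    torch.manual_seed(0)
    A = torch.randn(50, 10)
    b = torch.randn(50)
    w = torch.zeros(10, requires_grad=True)

    optimizer = ArmijoSGD([w], lr=1, max_iter=1000, tau=0.5, c=0.5)

    def closure():
        optimizer.zero_grad()
        loss = ((A @ w - b) ** 2).mean()
        loss.backward()
        return loss

    for i in range(20):
        loss = optimizer.step(closure)
        print(f"iteration {i}: loss {float(loss):.6f}")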