-
Notifications
You must be signed in to change notification settings - Fork 0
/
LBFGS.py
246 lines (208 loc) · 8.84 KB
/
LBFGS.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
class LBFGS(Optimizer):
    r"""Implements L-BFGS algorithm, heavily inspired by `minFunc
    <https://www.cs.ubc.ca/~schmidtm/Software/minFunc.html>`_.

    All parameters are handled as one flat vector, so exactly one parameter
    group is supported; passing several groups raises ``ValueError``.

    Arguments:
        params (iterable): parameters to optimize, forwarded to ``Optimizer``.
        lr (float): scale of the trial step length. The very first step uses
            ``min(1, 1 / ||g||_1) * lr``; later steps start from ``lr``
            (default: 1).
        max_iter (int): maximal number of iterations per call to ``step``
            (default: 20).
        max_eval (int): maximal number of closure evaluations per call to
            ``step`` (default: ``max_iter * 5 // 4``).
        tolerance_grad (float): iteration stops once the largest absolute
            gradient entry is at or below this value (default: 1e-7).
        tolerance_change (float): iteration stops once the directional
            derivative, the largest step entry, or the change in loss falls
            within this value (default: 1e-9).
        history_size (int): maximal number of stored curvature pairs
            (default: 100).
        line_search_fn (str): ``"strong_wolfe"`` or ``None`` for a fixed step
            (default: None). Any other value makes ``step`` raise
            ``RuntimeError``.
    """

    def __init__(self, params, lr=1, max_iter=20, max_eval=None,
                 tolerance_grad=1e-7, tolerance_change=1e-9,
                 history_size=100, line_search_fn=None):
        if max_eval is None:
            max_eval = max_iter * 5 // 4
        defaults = dict(
            lr=lr,
            max_iter=max_iter,
            max_eval=max_eval,
            tolerance_grad=tolerance_grad,
            tolerance_change=tolerance_change,
            history_size=history_size,
            line_search_fn=line_search_fn,
        )
        super(LBFGS, self).__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError("LBFGS doesn't support per-parameter options (parameter groups)")

        self._params = self.param_groups[0]['params']
        # Total element count over all parameters; filled lazily by _numel().
        self._numel_cache = None

    def _numel(self):
        """Return the total number of elements over all parameters (cached)."""
        if self._numel_cache is None:
            self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
        return self._numel_cache

    def _gather_flat_grad(self):
        """Return all gradients concatenated into one new 1-D tensor.

        Parameters without a gradient contribute zeros; sparse gradients are
        converted to dense first.
        """
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.new_zeros(p.numel())
            elif p.grad.is_sparse:
                view = p.grad.to_dense().view(-1)
            else:
                view = p.grad.view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _add_grad(self, step_size, update):
        """Apply ``param += step_size * update`` in place.

        ``update`` is a flat tensor laid out in the same parameter order as
        the result of ``_gather_flat_grad``.
        """
        offset = 0
        for p in self._params:
            numel = p.numel()
            # view as to avoid deprecated pointwise semantics; the scale is
            # passed via ``alpha`` because add_(Number, Tensor) is deprecated.
            p.data.add_(update[offset:offset + numel].view_as(p.data), alpha=step_size)
            offset += numel
        assert offset == self._numel()

    def _clone_param(self):
        """Return a snapshot (list of copies) of the current parameter values."""
        # detach() first so the snapshot does not record an autograd node.
        return [p.detach().clone() for p in self._params]

    def _set_param(self, params_data):
        """Overwrite the parameter values in place from ``params_data``."""
        for p, pdata in zip(self._params, params_data):
            p.data.copy_(pdata)

    def _directional_evaluate(self, closure, x, t, d):
        """Evaluate loss and flat gradient at ``x + t * d``.

        The parameters are moved by ``t * d``, the closure is evaluated, and
        the parameters are then reset to the snapshot ``x``.

        Returns:
            tuple ``(loss, flat_grad)`` with ``loss`` as a Python float.
        """
        self._add_grad(t, d)
        loss = float(closure())
        flat_grad = self._gather_flat_grad()
        self._set_param(x)
        return loss, flat_grad

    def step(self, closure):
        """Performs a single optimization step.

        Arguments:
            closure (callable): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by the first closure call of this step, i.e.
            the loss before any parameter update made here.

        Raises:
            RuntimeError: if ``line_search_fn`` is neither ``None`` nor
                ``"strong_wolfe"``.
        """
        assert len(self.param_groups) == 1

        group = self.param_groups[0]
        lr = group['lr']
        max_iter = group['max_iter']
        max_eval = group['max_eval']
        tolerance_grad = group['tolerance_grad']
        tolerance_change = group['tolerance_change']
        line_search_fn = group['line_search_fn']
        history_size = group['history_size']

        # NOTE: LBFGS has only global state, but we register it as state for
        # the first param, because this helps with casting in load_state_dict
        state = self.state[self._params[0]]
        state.setdefault('func_evals', 0)
        state.setdefault('n_iter', 0)

        # evaluate initial f(x) and df/dx
        orig_loss = closure()
        loss = float(orig_loss)
        current_evals = 1
        state['func_evals'] += 1

        flat_grad = self._gather_flat_grad()
        opt_cond = flat_grad.abs().max() <= tolerance_grad

        # optimal condition
        if opt_cond:
            return orig_loss

        # tensors cached in state (for tracing)
        d = state.get('d')
        t = state.get('t')
        old_dirs = state.get('old_dirs')
        old_stps = state.get('old_stps')
        ro = state.get('ro')
        H_diag = state.get('H_diag')
        prev_flat_grad = state.get('prev_flat_grad')
        prev_loss = state.get('prev_loss')

        n_iter = 0
        # optimize for a max of max_iter iterations
        while n_iter < max_iter:
            # keep track of nb of iterations: n_iter counts this call only,
            # state['n_iter'] counts across calls
            n_iter += 1
            state['n_iter'] += 1

            ############################################################
            # compute gradient descent direction
            ############################################################
            if state['n_iter'] == 1:
                # very first iteration ever: steepest descent, empty history
                d = flat_grad.neg()
                old_dirs = []
                old_stps = []
                ro = []
                H_diag = 1
            else:
                # do lbfgs update (update memory)
                y = flat_grad.sub(prev_flat_grad)
                s = d.mul(t)
                ys = y.dot(s)  # y*s
                if ys > 1e-10:
                    # updating memory
                    if len(old_dirs) == history_size:
                        # shift history by one (limited-memory)
                        old_dirs.pop(0)
                        old_stps.pop(0)
                        ro.pop(0)

                    # store new direction/step
                    old_dirs.append(y)
                    old_stps.append(s)
                    ro.append(1. / ys)

                    # update scale of initial Hessian approximation
                    H_diag = ys / y.dot(y)  # (y*y)

                # compute the approximate (L-BFGS) inverse Hessian
                # multiplied by the gradient
                num_old = len(old_dirs)

                if 'al' not in state:
                    state['al'] = [None] * history_size
                al = state['al']

                # iteration in L-BFGS loop collapsed to use just one buffer
                q = flat_grad.neg()
                for i in range(num_old - 1, -1, -1):
                    al[i] = old_stps[i].dot(q) * ro[i]
                    q.add_(old_dirs[i], alpha=-al[i])

                # multiply by initial Hessian
                # r/d is the final direction (both names alias one tensor)
                d = r = torch.mul(q, H_diag)
                for i in range(num_old):
                    be_i = old_dirs[i].dot(r) * ro[i]
                    r.add_(old_stps[i], alpha=al[i] - be_i)

            if prev_flat_grad is None:
                prev_flat_grad = flat_grad.clone()
            else:
                prev_flat_grad.copy_(flat_grad)
            prev_loss = loss

            ############################################################
            # compute step length
            ############################################################
            # reset initial guess for step size
            if state['n_iter'] == 1:
                t = min(1., 1. / flat_grad.abs().sum()) * lr
            else:
                t = lr

            # directional derivative
            gtd = flat_grad.dot(d)  # g * d

            # directional derivative is below tolerance
            if gtd > -tolerance_change:
                break

            # optional line search: user function
            ls_func_evals = 0
            if line_search_fn is not None:
                # perform line search, using user function
                if line_search_fn != "strong_wolfe":
                    raise RuntimeError("only 'strong_wolfe' is supported")
                else:
                    x_init = self._clone_param()

                    def obj_func(x, t, d):
                        return self._directional_evaluate(closure, x, t, d)

                    # _strong_wolfe is not defined in this class; it is
                    # looked up as a module-level name.
                    loss, flat_grad, t, ls_func_evals = _strong_wolfe(
                        obj_func, x_init, t, d, loss, flat_grad, gtd)
                self._add_grad(t, d)
                opt_cond = flat_grad.abs().max() <= tolerance_grad
            else:
                # no line search, simply move with fixed-step
                self._add_grad(t, d)
                if n_iter != max_iter:
                    # re-evaluate function only if not in last iteration
                    # the reason we do this: in a stochastic setting,
                    # no use to re-evaluate that function here
                    loss = float(closure())
                    flat_grad = self._gather_flat_grad()
                    opt_cond = flat_grad.abs().max() <= tolerance_grad
                    ls_func_evals = 1

            # update func eval
            current_evals += ls_func_evals
            state['func_evals'] += ls_func_evals

            ############################################################
            # check conditions
            ############################################################
            if n_iter == max_iter:
                break

            if current_evals >= max_eval:
                break

            # optimal condition
            if opt_cond:
                break

            # lack of progress
            if d.mul(t).abs().max() <= tolerance_change:
                break

            if abs(loss - prev_loss) < tolerance_change:
                break

        # persist the search state so the next call to step() continues from it
        state['d'] = d
        state['t'] = t
        state['old_dirs'] = old_dirs
        state['old_stps'] = old_stps
        state['ro'] = ro
        state['H_diag'] = H_diag
        state['prev_flat_grad'] = prev_flat_grad
        state['prev_loss'] = prev_loss

        return orig_loss