# _lbfgs.py
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The Limited-Memory Broyden-Fletcher-Goldfarb-Shanno minimization algorithm."""
from typing import Any, Callable, NamedTuple, Optional, Union
from functools import partial
import jax
import jax.numpy as jnp
from jax import lax
from jax._src.scipy.optimize.line_search import line_search
_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
Array = Any
class LBFGSResults(NamedTuple):
"""Results from L-BFGS optimization
Parameters:
converged: True if minimization converged
failed: True if non-zero status and not converged
k: integer number of iterations of the main loop (optimisation steps)
nfev: integer total number of objective evaluations performed.
ngev: integer total number of jacobian evaluations
x_k: array containing the last argument value found during the search. If
the search converged, then this value is the argmin of the objective
function.
f_k: array containing the value of the objective function at `x_k`. If the
search converged, then this is the (local) minimum of the objective
function.
g_k: array containing the gradient of the objective function at `x_k`. If
the search converged the l2-norm of this tensor should be below the
tolerance.
status: integer describing the status:
0 = nominal , 1 = max iters reached , 2 = max fun evals reached
3 = max grad evals reached , 4 = insufficient progress (ftol)
5 = line search failed
ls_status: integer describing the end status of the last line search
"""
converged: Union[bool, Array]
failed: Union[bool, Array]
k: Union[int, Array]
nfev: Union[int, Array]
ngev: Union[int, Array]
x_k: Array
f_k: Array
g_k: Array
s_history: Array
y_history: Array
rho_history: Array
gamma: Union[float, Array]
status: Union[int, Array]
ls_status: Union[int, Array]
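
# A minimal illustrative sketch (not part of this module's API) of how a caller
# might summarize an ``LBFGSResults`` value using the fields documented above.
# The helper name ``_describe_results`` is hypothetical and exists only for
# demonstration.
def _describe_results(results: LBFGSResults) -> str:
  if bool(results.converged):
    return f"converged after {int(results.k)} steps with f(x_k) = {float(results.f_k)}"
  return f"stopped with status {int(results.status)} after {int(results.nfev)} function evaluations"

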
def _minimize_lbfgs(
    fun: Callable,
    x0: Array,
    maxiter: Optional[float] = None,
    norm=jnp.inf,
    maxcor: int = 10,
    ftol: float = 2.220446049250313e-09,
    gtol: float = 1e-05,
    maxfun: Optional[float] = None,
    maxgrad: Optional[float] = None,
    maxls: int = 20,
):
"""
Minimize a function using L-BFGS
Implements the L-BFGS algorithm from
Algorithm 7.5 from Wright and Nocedal, 'Numerical Optimization', 1999, pg. 176-185
And generalizes to complex variables from
Sorber, L., Barel, M.V. and Lathauwer, L.D., 2012.
"Unconstrained optimization of real functions in complex variables"
SIAM Journal on Optimization, 22(3), pp.879-898.
Args:
fun: function of the form f(x) where x is a flat ndarray and returns a real scalar.
The function should be composed of operations with vjp defined.
x0: initial guess
maxiter: maximum number of iterations
norm: order of norm for convergence check. Default inf.
maxcor: maximum number of metric corrections ("history size")
ftol: terminates the minimization when `(f_k - f_{k+1}) < ftol`
gtol: terminates the minimization when `|g_k|_norm < gtol`
maxfun: maximum number of function evaluations
maxgrad: maximum number of gradient evaluations
maxls: maximum number of line search steps (per iteration)
Returns:
Optimization results.
"""
  d = len(x0)
  dtype = jnp.dtype(x0)

  # ensure there is at least one termination condition
  if (maxiter is None) and (maxfun is None) and (maxgrad is None):
    maxiter = d * 200

  # set others to inf, such that >= is supported
  if maxiter is None:
    maxiter = jnp.inf
  if maxfun is None:
    maxfun = jnp.inf
  if maxgrad is None:
    maxgrad = jnp.inf

  # initial evaluation
  f_0, g_0 = jax.value_and_grad(fun)(x0)

  state_initial = LBFGSResults(
    converged=False,
    failed=False,
    k=0,
    nfev=1,
    ngev=1,
    x_k=x0,
    f_k=f_0,
    g_k=g_0,
    s_history=jnp.zeros((maxcor, d), dtype=dtype),
    y_history=jnp.zeros((maxcor, d), dtype=dtype),
    rho_history=jnp.zeros((maxcor,), dtype=dtype),
    gamma=1.,
    status=0,
    ls_status=0,
  )

  def cond_fun(state: LBFGSResults):
    return (~state.converged) & (~state.failed)

  def body_fun(state: LBFGSResults):
    # find search direction
    p_k = _two_loop_recursion(state)

    # line search
    ls_results = line_search(
      f=fun,
      xk=state.x_k,
      pk=p_k,
      old_fval=state.f_k,
      gfk=state.g_k,
      maxiter=maxls,
    )

    # evaluate at next iterate
    s_k = ls_results.a_k.astype(p_k.dtype) * p_k
    x_kp1 = state.x_k + s_k
    f_kp1 = ls_results.f_k
    g_kp1 = ls_results.g_k
    y_k = g_kp1 - state.g_k
    rho_k_inv = jnp.real(_dot(y_k, s_k))
    rho_k = jnp.reciprocal(rho_k_inv).astype(y_k.dtype)
    gamma = rho_k_inv / jnp.real(_dot(jnp.conj(y_k), y_k))

    # replacements for next iteration
    status = 0
    status = jnp.where(state.f_k - f_kp1 < ftol, 4, status)
    status = jnp.where(state.ngev >= maxgrad, 3, status)  # type: ignore
    status = jnp.where(state.nfev >= maxfun, 2, status)  # type: ignore
    status = jnp.where(state.k >= maxiter, 1, status)  # type: ignore
    status = jnp.where(ls_results.failed, 5, status)
    converged = jnp.linalg.norm(g_kp1, ord=norm) < gtol

    # TODO(jakevdp): use a fixed-point procedure rather than type-casting?
    state = state._replace(
      converged=converged,
      failed=(status > 0) & (~converged),
      k=state.k + 1,
      nfev=state.nfev + ls_results.nfev,
      ngev=state.ngev + ls_results.ngev,
      x_k=x_kp1.astype(state.x_k.dtype),
      f_k=f_kp1.astype(state.f_k.dtype),
      g_k=g_kp1.astype(state.g_k.dtype),
      s_history=_update_history_vectors(history=state.s_history, new=s_k),
      y_history=_update_history_vectors(history=state.y_history, new=y_k),
      rho_history=_update_history_scalars(history=state.rho_history, new=rho_k),
      gamma=gamma,
      status=jnp.where(converged, 0, status),
      ls_status=ls_results.status,
    )

    return state

  return lax.while_loop(cond_fun, body_fun, state_initial)
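
# A minimal usage sketch, assuming a toy Rosenbrock objective; the helper name
# ``_example_usage`` and the objective below are illustrative only and are not
# part of this module.
def _example_usage():
  def rosenbrock(x):
    return jnp.sum(100. * (x[1:] - x[:-1] ** 2) ** 2 + (1. - x[:-1]) ** 2)

  results = _minimize_lbfgs(rosenbrock, jnp.zeros(5), maxiter=200)
  # When ``results.converged`` is True, ``results.x_k`` approximates the argmin
  # (all ones for the Rosenbrock function).
  return results

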
def _two_loop_recursion(state: LBFGSResults):
  """Estimate the L-BFGS search direction from the stored ``(s, y)`` history
  using the two-loop recursion, without forming the inverse Hessian
  approximation explicitly."""
  dtype = state.rho_history.dtype
  his_size = len(state.rho_history)
  curr_size = jnp.where(state.k < his_size, state.k, his_size)
  q = -jnp.conj(state.g_k)
  a_his = jnp.zeros_like(state.rho_history)

  def body_fun1(j, carry):
    i = his_size - 1 - j
    _q, _a_his = carry
    a_i = state.rho_history[i] * _dot(jnp.conj(state.s_history[i]), _q).real.astype(dtype)
    _a_his = _a_his.at[i].set(a_i)
    _q = _q - a_i * jnp.conj(state.y_history[i])
    return _q, _a_his

  q, a_his = lax.fori_loop(0, curr_size, body_fun1, (q, a_his))
  q = state.gamma * q

  def body_fun2(j, _q):
    i = his_size - curr_size + j
    b_i = state.rho_history[i] * _dot(state.y_history[i], _q).real.astype(dtype)
    _q = _q + (a_his[i] - b_i) * state.s_history[i]
    return _q

  q = lax.fori_loop(0, curr_size, body_fun2, q)
  return q
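
# A minimal sanity-check sketch (illustrative only): with an empty history
# (``k == 0``) the two loops above do not execute, so the recursion reduces to
# scaled steepest descent and returns ``-gamma * conj(g_k)``. The helper name
# ``_example_two_loop_steepest_descent`` is hypothetical.
def _example_two_loop_steepest_descent():
  g = jnp.array([1., -2., 3.])
  state = LBFGSResults(
      converged=False, failed=False, k=0, nfev=1, ngev=1,
      x_k=jnp.zeros(3), f_k=jnp.array(0.), g_k=g,
      s_history=jnp.zeros((10, 3)), y_history=jnp.zeros((10, 3)),
      rho_history=jnp.zeros((10,)), gamma=1., status=0, ls_status=0,
  )
  return _two_loop_recursion(state)  # equal to -g here, since gamma == 1

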
def _update_history_vectors(history, new):
  # TODO(Jakob-Unfried) use rolling buffer instead? See #6053
  return jnp.roll(history, -1, axis=0).at[-1, :].set(new)


def _update_history_scalars(history, new):
  # TODO(Jakob-Unfried) use rolling buffer instead? See #6053
  return jnp.roll(history, -1, axis=0).at[-1].set(new)
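
# A small illustrative sketch of the rolling-history semantics above: the oldest
# entry (row 0) is dropped and the newest entry becomes the last row. The helper
# name ``_example_history_update`` is hypothetical.
def _example_history_update():
  hist = jnp.zeros((3, 2))
  hist = _update_history_vectors(hist, jnp.array([1., 1.]))
  hist = _update_history_vectors(hist, jnp.array([2., 2.]))
  # hist is now [[0., 0.], [1., 1.], [2., 2.]]
  return hist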