-
Notifications
You must be signed in to change notification settings - Fork 2
/
steplen.py
311 lines (251 loc) · 10.3 KB
/
steplen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
# Copyright 2014-2016 The ODL development group
#
# This file is part of ODL.
#
# ODL is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ODL is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ODL. If not, see <http://www.gnu.org/licenses/>.
"""Step length computation for optimization schemes."""
# Imports for common Python 2/3 codebase
from __future__ import print_function, division, absolute_import
from future import standard_library
standard_library.install_aliases()
from abc import ABCMeta, abstractmethod
import numpy as np
import warnings
from odl.util.utility import with_metaclass
__all__ = ('LineSearch', 'BacktrackingLineSearch', 'ConstantLineSearch',
'LineSearchFromIterNum')
class LineSearch(with_metaclass(ABCMeta, object)):

    """Abstract interface shared by all step length (line search) rules."""

    @abstractmethod
    def __call__(self, x, direction, dir_derivative):
        """Return the step length to take from ``x`` along ``direction``.

        Parameters
        ----------
        x : `LinearSpaceElement`
            Point at which the line search starts.
        direction : `LinearSpaceElement`
            Direction along which the step length is sought.
        dir_derivative : float
            Directional derivative along ``direction``.

        Returns
        -------
        step : float
            The step length selected by the concrete rule.
        """
# Minor changes have been made in this class, compared to odl version 0.6.0.
class BacktrackingLineSearch(LineSearch):

    """Backtracking line search for step length calculation.

    This method approximately finds the longest step length fulfilling
    the Armijo-Goldstein condition.

    The line search algorithm is described in [BV2004]_, page 464
    (`book available online
    <http://stanford.edu/~boyd/cvxbook/bv_cvxbook.pdf>`_) and
    [GNS2009]_, pages 378--379. See also
    `Backtracking_line_search
    <https://en.wikipedia.org/wiki/Backtracking_line_search>`_.
    """

    def __init__(self, function, tau=0.5, discount=0.01, alpha=1.0,
                 max_num_iter=None, estimate_step=False):
        """Initialize a new instance.

        Parameters
        ----------
        function : callable
            The cost function of the optimization problem to be solved.
            If ``function`` has no ``gradient`` attribute (i.e. it is not
            a `Functional`), the argument ``dir_derivative`` is mandatory
            when the line search is called.
        tau : float, optional
            The amount the step length is decreased in each iteration,
            as long as it does not fulfill the decrease condition.
            The step length is updated as ``step_length *= tau``.
            The default ``max_num_iter`` is derived from
            ``log(eps) / log(tau)``, which is a positive count only
            for ``0 < tau < 1``.
        discount : float, optional
            The "discount factor" on ``step length * direction derivative``,
            yielding the threshold under which the function value must lie to
            be accepted (see the references).
        alpha : float, optional
            The initial guess for the step length. It is only used if
            ``estimate_step`` is ``True``; otherwise every call starts
            from a step length of magnitude 1.0.
        max_num_iter : int, optional
            Maximum number of iterations allowed each time the line
            search method is called. If ``None``, this number is calculated
            to allow a shortest step length of 10 times the floating point
            resolution (``np.finfo(dtype).resolution``) of
            ``function.domain.dtype``, or of ``float`` if ``function``
            has no such attribute.
        estimate_step : bool, optional
            If ``True``, the magnitude of the step returned by the last
            call is used as initial guess for the next call.

        Examples
        --------
        Create line search

        >>> r3 = odl.rn(3)
        >>> func = odl.solvers.L2NormSquared(r3)
        >>> line_search = BacktrackingLineSearch(func)

        Find step in point x and direction d that decreases the function value.

        >>> x = r3.element([1, 2, 3])
        >>> d = r3.element([-1, -1, -1])
        >>> step_len = line_search(x, d)
        >>> step_len
        1.0
        >>> func(x + step_len * d) < func(x)
        True

        Also works with non-functionals as arguments, but then the
        dir_derivative argument is mandatory

        >>> r3 = odl.rn(3)
        >>> func = lambda x: x[0] ** 2 + x[1] ** 2 + x[2] ** 2
        >>> line_search = BacktrackingLineSearch(func)
        >>> x = r3.element([1, 2, 3])
        >>> d = r3.element([-1, -1, -1])
        >>> dir_derivative = -12
        >>> step_len = line_search(x, d, dir_derivative=dir_derivative)
        >>> step_len
        1.0
        >>> func(x + step_len * d) < func(x)
        True
        """
        self.function = function
        self.tau = float(tau)
        self.discount = float(discount)
        self.estimate_step = bool(estimate_step)
        self.alpha = float(alpha)
        # Backtracking iterations accumulated over all calls
        self.total_num_iter = 0

        # Use a default value that allows the shortest step to be < 10 times
        # the floating point resolution.
        if max_num_iter is None:
            try:
                dtype = self.function.domain.dtype
            except AttributeError:
                # Plain callables carry no domain; assume double precision
                dtype = float
            eps = 10 * np.finfo(dtype).resolution
            self.max_num_iter = int(np.ceil(np.log(eps) / np.log(self.tau)))
        else:
            self.max_num_iter = int(max_num_iter)

    def __call__(self, x, direction, dir_derivative=None):
        """Calculate the optimal step length along a line.

        Parameters
        ----------
        x : `LinearSpaceElement`
            The current point
        direction : `LinearSpaceElement`
            Search direction in which the line search should be computed
        dir_derivative : float, optional
            Directional derivative along the ``direction``
            Default: ``function.gradient(x).inner(direction)``

        Returns
        -------
        step : float
            The computed step length. It is negative if ``dir_derivative``
            is positive, i.e. the step is then taken against ``direction``.

        Raises
        ------
        ValueError
            If ``dir_derivative`` is not given and ``function`` has no
            ``gradient``, if the directional derivative is zero, or if
            the function value in ``x`` is not finite.

        Notes
        -----
        A warning is issued if the function returns NaN in a trial point,
        if the maximum number of iterations is exceeded, or if the last
        evaluated function value is not smaller than the value in ``x``.
        The number of backtracking iterations is added to
        ``total_num_iter`` and the magnitude of the returned step is
        stored in ``alpha``.
        """
        fx = self.function(x)
        if dir_derivative is None:
            try:
                gradient = self.function.gradient
            except AttributeError:
                raise ValueError('`dir_derivative` only optional if '
                                 '`function.gradient exists')
            else:
                dir_derivative = gradient(x).inner(direction)
        else:
            dir_derivative = float(dir_derivative)

        if dir_derivative == 0:
            raise ValueError('dir_derivative == 0, no descent can be found')

        if not self.estimate_step:
            alpha = 1.0
        else:
            alpha = self.alpha

        if dir_derivative > 0:
            # We need to move backwards if the direction is an increase
            # direction
            alpha *= -1

        if not np.isfinite(fx):
            raise ValueError('function returned invalid value {} in starting '
                             'point ({})'.format(fx, x))

        # Create temporary that is overwritten with each trial point
        point = x.copy()

        # Make sure `fval` is defined even if the loop below exits before
        # the first evaluation (negative `max_num_iter`); the final check
        # then reports that no decrease has been achieved.
        fval = fx

        num_iter = 0
        while True:
            if num_iter > self.max_num_iter:
                warnings.warn('number of iterations exceeded maximum: {}, '
                              'step length: {}, without finding a '
                              'sufficient decrease'
                              ''.format(self.max_num_iter, alpha))
                break

            point.lincomb(1, x, alpha, direction)  # pt = x + alpha * direction
            fval = self.function(point)

            if np.isnan(fval):
                # NaN cannot be compared against the threshold below, so
                # the trial step is rejected: shrink the step, warn and
                # try again instead of raising an error.
                num_iter += 1
                alpha *= self.tau
                warnings.warn('function returned NaN in point '
                              'point ({})'.format(point))
                continue

            expected_decrease = np.abs(alpha * dir_derivative * self.discount)
            if fval <= fx - expected_decrease:
                # Stop iterating if the value decreases sufficiently.
                break

            num_iter += 1
            alpha *= self.tau

        if not fval < fx:
            warnings.warn('the step has not lead to a decrease in function '
                          'value: fxnew = {} and fx = {}'.format(fval, fx))

        self.total_num_iter += num_iter
        self.alpha = np.abs(alpha)  # Store magnitude
        return alpha
class ConstantLineSearch(LineSearch):

    """Line search object that returns a constant step length."""

    def __init__(self, constant):
        """Initialize a new instance.

        Parameters
        ----------
        constant : float
            The constant step length
        """
        self.constant = float(constant)

    def __call__(self, x, direction, dir_derivative=None):
        """Calculate the step length at a point.

        All arguments are ignored and are only added to fit the interface.
        ``dir_derivative`` is optional so that the object can be called
        with ``(x, direction)`` only, like `BacktrackingLineSearch`.

        Returns
        -------
        step : float
            The constant given at initialization.
        """
        return self.constant
class LineSearchFromIterNum(LineSearch):

    """Line search object that returns a step length from a function.

    The returned step length is ``func(iter_count)``.
    """

    def __init__(self, func):
        """Initialize a new instance.

        Parameters
        ----------
        func : callable
            Function that when called with an iteration count should return the
            step length. The iteration count starts at 0.

        Raises
        ------
        TypeError
            If ``func`` is not callable.

        Examples
        --------
        Make a step size that is 1.0 for the first 5 iterations, then 0.1:

        >>> def step_length(iter_count):
        ...     if iter_count < 5:
        ...         return 1.0
        ...     else:
        ...         return 0.1
        >>> line_search = LineSearchFromIterNum(step_length)
        """
        if not callable(func):
            raise TypeError('`func` must be a callable.')
        self.func = func
        # Number of times the line search has been called so far
        self.iter_count = 0

    def __call__(self, x, direction, dir_derivative=None):
        """Calculate the step length at a point.

        All arguments are ignored and are only added to fit the interface.
        ``dir_derivative`` is optional so that the object can be called
        with ``(x, direction)`` only, like `BacktrackingLineSearch`.

        Returns
        -------
        step
            The value ``func(iter_count)`` for the current iteration
            count, which is increased by one afterwards.
        """
        step = self.func(self.iter_count)
        self.iter_count += 1
        return step
if __name__ == '__main__':
    # When executed as a script, run the doctests of this module.
    # The import is local since it is only needed in that case.
    # pylint: disable=wrong-import-position
    from odl.util.testutils import run_doctests

    run_doctests()