Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bisection and brent's method for root finding #424

Merged
merged 3 commits into from
Jul 30, 2018
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion quantecon/optimize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
"""

from .scalar_maximization import brent_max
from .root_finding import newton, newton_halley, newton_secant
from .root_finding import newton, newton_halley, newton_secant, bisect, brentq
246 changes: 238 additions & 8 deletions quantecon/optimize/root_finding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,31 @@
from numba import jit, njit
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jit seems not to be used in this file.

from collections import namedtuple

__all__ = ['newton', 'newton_halley', 'newton_secant']
__all__ = ['newton', 'newton_halley', 'newton_secant', 'bisect', 'brentq']

_ECONVERGED = 0
_ECONVERR = -1

results = namedtuple('results',
('root function_calls iterations converged'))
_iter = 100
_xtol = 2e-12
_rtol = 4*np.finfo(float).eps

results = namedtuple('results', 'root function_calls iterations converged')


@njit
def _results(r):
r"""Select from a tuple of(root, funccalls, iterations, flag)"""
x, funcalls, iterations, flag = r
return results(x, funcalls, iterations, flag == 0)


@njit
def newton(func, x0, fprime, args=(), tol=1.48e-8, maxiter=50,
disp=True):
"""
Find a zero from the Newton-Raphson method using the jitted version of
Scipy's newton for scalars. Note that this does not provide an alternative
Find a zero from the Newton-Raphson method using the jitted version of
Scipy's newton for scalars. Note that this does not provide an alternative
method such as secant. Thus, it is important that `fprime` can be provided.

Note that `func` and `fprime` must be jitted via Numba.
Expand Down Expand Up @@ -85,18 +90,19 @@ def newton(func, x0, fprime, args=(), tol=1.48e-8, maxiter=50,
break
newton_step = fval / fder
# Newton step
p = p0 - newton_step
p = p0 - newton_step
if abs(p - p0) < tol:
status = _ECONVERGED
break
p0 = p

if disp and status == _ECONVERR:
msg = "Failed to converge"
raise RuntimeError(msg)

return _results((p, funcalls, itr + 1, status))


@njit
def newton_halley(func, x0, fprime, fprime2, args=(), tol=1.48e-8,
maxiter=50, disp=True):
Expand Down Expand Up @@ -179,6 +185,7 @@ def newton_halley(func, x0, fprime, fprime2, args=(), tol=1.48e-8,

return _results((p, funcalls, itr + 1, status))


@njit
def newton_secant(func, x0, args=(), tol=1.48e-8, maxiter=50,
disp=True):
Expand Down Expand Up @@ -254,4 +261,227 @@ def newton_secant(func, x0, args=(), tol=1.48e-8, maxiter=50,
msg = "Failed to converge"
raise RuntimeError(msg)

return _results((p, funcalls, itr + 1, status))
return _results((p, funcalls, itr + 1, status))


@njit
def _bisect_interval(a, b, fa, fb):
"""Conditional checks for intervals in methods involving bisection"""
if fa*fb > 0:
raise ValueError("f(a) and f(b) must have different signs")
root = 0.0
status = _ECONVERR

# Root found at either end of [a,b]
if fa == 0:
root = a
status = _ECONVERGED
if fb == 0:
root = b
status = _ECONVERGED

return root, status


@njit
def bisect(f, a, b, args=(), xtol=_xtol,
rtol=_rtol, maxiter=_iter, disp=True):
"""
Find root of a function within an interval adapted from Scipy's bisect.

Basic bisection routine to find a zero of the function `f` between the
arguments `a` and `b`. `f(a)` and `f(b)` cannot have the same signs.

`f` must be jitted via numba.

Parameters
----------
f : jitted and callable
Python function returning a number. `f` must be continuous.
a : number
One end of the bracketing interval [a,b].
b : number
The other end of the bracketing interval [a,b].
args : tuple, optional
Extra arguments to be used in the function call.
xtol : number, optional
The computed root ``x0`` will satisfy ``np.allclose(x, x0,
atol=xtol, rtol=rtol)``, where ``x`` is the exact root. The
parameter must be nonnegative.
rtol : number, optional
The computed root ``x0`` will satisfy ``np.allclose(x, x0,
atol=xtol, rtol=rtol)``, where ``x`` is the exact root.
maxiter : number, optional
Maximum number of iterations.
disp : bool, optional
If True, raise a RuntimeError if the algorithm didn't converge.

Returns
-------
results : namedtuple

"""

if xtol <= 0:
raise ValueError("xtol is too small (<= 0)")

if maxiter < 1:
raise ValueError("maxiter must be greater than 0")

# Convert to float
xa = a * 1.0
xb = b * 1.0

fa = f(xa, *args)
fb = f(xb, *args)
funcalls = 2
root, status = _bisect_interval(xa, xb, fa, fb)

# Check for sign error and early termination
if status == _ECONVERGED:
itr = 0
else:
# Perform bisection
dm = xb - xa
for itr in range(maxiter):
dm *= 0.5
xm = xa + dm
fm = f(xm, *args)
funcalls += 1

if fm * fa >= 0:
xa = xm

if fm == 0 or abs(dm) < xtol + rtol * abs(xm):
root = xm
status = _ECONVERGED
itr += 1
break

if disp and status == _ECONVERR:
raise RuntimeError("Failed to converge")

return _results((root, funcalls, itr, status))


@njit
def brentq(f, a, b, args=(), xtol=_xtol,
rtol=_rtol, maxiter=_iter, disp=True):
"""
Find a root of a function in a bracketing interval using Brent's method
adapted from Scipy's brentq.

Uses the classic Brent's method to find a zero of the function `f` on
the sign changing interval [a , b].

`f` must be jitted via numba.

Parameters
----------
f : jitted and callable
Python function returning a number. `f` must be continuous.
a : number
One end of the bracketing interval [a,b].
b : number
The other end of the bracketing interval [a,b].
args : tuple, optional
Extra arguments to be used in the function call.
xtol : number, optional
The computed root ``x0`` will satisfy ``np.allclose(x, x0,
atol=xtol, rtol=rtol)``, where ``x`` is the exact root. The
parameter must be nonnegative.
rtol : number, optional
The computed root ``x0`` will satisfy ``np.allclose(x, x0,
atol=xtol, rtol=rtol)``, where ``x`` is the exact root.
maxiter : number, optional
Maximum number of iterations.
disp : bool, optional
If True, raise a RuntimeError if the algorithm didn't converge.

Returns
-------
results : namedtuple

"""
if xtol <= 0:
raise ValueError("xtol is too small (<= 0)")
if maxiter < 1:
raise ValueError("maxiter must be greater than 0")

# Convert to float
xpre = a * 1.0
xcur = b * 1.0

fpre = f(xpre, *args)
fcur = f(xcur, *args)
funcalls = 2

root, status = _bisect_interval(xpre, xcur, fpre, fcur)

# Check for sign error and early termination
if status == _ECONVERGED:
itr = 0
else:
# Perform Brent's method
for itr in range(maxiter):

if fpre * fcur < 0:
xblk = xpre
fblk = fpre
spre = scur = xcur - xpre
if abs(fblk) < abs(fcur):
xpre = xcur
xcur = xblk
xblk = xpre

fpre = fcur
fcur = fblk
fblk = fpre

delta = (xtol + rtol * abs(xcur)) / 2
sbis = (xblk - xcur) / 2

# Root found
if fcur == 0 or abs(sbis) < delta:
status = _ECONVERGED
root = xcur
itr += 1
break

if abs(spre) > delta and abs(fcur) < abs(fpre):
if xpre == xblk:
# interpolate
stry = -fcur * (xcur - xpre) / (fcur - fpre)
else:
# extrapolate
dpre = (fpre - fcur) / (xpre - xcur)
dblk = (fblk - fcur) / (xblk - xcur)
stry = -fcur * (fblk * dblk - fpre * dpre) / \
(dblk * dpre * (fblk - fpre))

if (2 * abs(stry) < min(abs(spre), 3 * abs(sbis) - delta)):
# good short step
spre = scur
scur = stry
else:
# bisect
spre = sbis
scur = sbis
else:
# bisect
spre = sbis
scur = sbis

xpre = xcur
fpre = fcur
if (abs(scur) > delta):
xcur += scur
else:
xcur += (delta if sbis > 0 else -delta)
fcur = f(xcur, *args)
funcalls += 1

if disp and status == _ECONVERR:
raise RuntimeError("Failed to converge")

return _results((root, funcalls, itr, status))
30 changes: 27 additions & 3 deletions quantecon/optimize/tests/test_root_finding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from numpy.testing import assert_almost_equal, assert_allclose
from numba import njit

from quantecon.optimize import newton, newton_halley, newton_secant
from quantecon.optimize import *


@njit
def func(x):
Expand All @@ -19,13 +20,15 @@ def func_prime(x):
"""
return (3*x**2)


@njit
def func_prime2(x):
"""
Second order derivative for func.
"""
return 6*x


@njit
def func_two(x):
"""
Expand All @@ -41,6 +44,7 @@ def func_two_prime(x):
"""
return 4*np.cos(4*(x - 1/4)) + 20*x**19 + 1


@njit
def func_two_prime2(x):
"""
Expand All @@ -67,15 +71,16 @@ def test_newton_basic_two():
true_fval = 1.0
fval = newton(func, 5, func_prime)
assert_allclose(true_fval, fval.root, rtol=1e-5, atol=0)


def test_newton_hard():
"""
Harder test for convergence.
"""
true_fval = 0.408
fval = newton(func_two, 0.4, func_two_prime)
assert_allclose(true_fval, fval.root, rtol=1e-5, atol=0.01)


def test_halley_basic():
"""
Expand All @@ -85,6 +90,7 @@ def test_halley_basic():
fval = newton_halley(func, 5, func_prime, func_prime2)
assert_almost_equal(true_fval, fval.root, decimal=4)


def test_halley_hard():
"""
Harder test for halley method
Expand All @@ -93,6 +99,7 @@ def test_halley_hard():
fval = newton_halley(func_two, 0.4, func_two_prime, func_two_prime2)
assert_allclose(true_fval, fval.root, rtol=1e-5, atol=0.01)


def test_secant_basic():
"""
Basic test for secant option.
Expand All @@ -111,8 +118,25 @@ def test_secant_hard():
assert_allclose(true_fval, fval.root, rtol=1e-5, atol=0.01)


def run_check(method, name):
a = -1
b = np.sqrt(3)
true_fval = 0.408
r = method(func_two, a, b)
assert_allclose(true_fval, r.root, atol=0.01, rtol=1e-5,
err_msg='method %s' % name)


def test_bisect_basic():
run_check(bisect, 'bisect')


def test_brentq_basic():
run_check(brentq, 'brentq')

# executing testcases.


if __name__ == '__main__':
import sys
import nose
Expand Down