From cde2f7b1bf168c972d37215c6533133d830c5920 Mon Sep 17 00:00:00 2001 From: QBatista Date: Mon, 10 Dec 2018 11:02:26 +0400 Subject: [PATCH] Add error for invalid inputs --- quantecon/optimize/scalar_maximization.py | 12 +++++++-- quantecon/optimize/tests/test_scalar_max.py | 30 ++++++++++++++++----- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/quantecon/optimize/scalar_maximization.py b/quantecon/optimize/scalar_maximization.py index e3d959f9d..56144b330 100644 --- a/quantecon/optimize/scalar_maximization.py +++ b/quantecon/optimize/scalar_maximization.py @@ -34,7 +34,7 @@ def brent_max(func, a, b, args=(), xtol=1e-5, maxiter=500): info : tuple A tuple of the form (status_flag, num_iter). Here status_flag indicates whether or not the maximum number of function calls was - attained. A value of 0 implies that the maximum was not hit. + attained. A value of 0 implies that the maximum was not hit. The value `num_iter` is the number of function calls. Example @@ -49,7 +49,15 @@ def f(x): ``` """ - + if not np.isfinite(a): + raise ValueError("a must be finite.") + + if not np.isfinite(b): + raise ValueError("b must be finite.") + + if not a < b: + raise ValueError("a must be less than b.") + maxfun = maxiter status_flag = 0 diff --git a/quantecon/optimize/tests/test_scalar_max.py b/quantecon/optimize/tests/test_scalar_max.py index b05c00113..f4704f92f 100644 --- a/quantecon/optimize/tests/test_scalar_max.py +++ b/quantecon/optimize/tests/test_scalar_max.py @@ -4,10 +4,12 @@ """ import numpy as np from numpy.testing import assert_almost_equal +from nose.tools import raises from numba import njit from quantecon.optimize import brent_max + @njit def f(x): """ @@ -15,9 +17,10 @@ def f(x): """ return -(x + 2.0)**2 + 1.0 + def test_brent_max(): """ - Uses the function f defined above to test the scalar maximization + Uses the function f defined above to test the scalar maximization routine. """ true_fval = 1.0 @@ -25,17 +28,19 @@ def test_brent_max(): xf, fval, info = brent_max(f, -2, 2) assert_almost_equal(true_fval, fval, decimal=4) assert_almost_equal(true_xf, xf, decimal=4) - + + @njit def g(x, y): """ A multivariate function for testing on. """ return -x**2 + y - + + def test_brent_max(): """ - Uses the function f defined above to test the scalar maximization + Uses the function f defined above to test the scalar maximization routine. """ y = 5 @@ -46,6 +51,21 @@ def test_brent_max(): assert_almost_equal(true_xf, xf, decimal=4) +@raises(ValueError) +def test_invalid_a_brent_max(): + brent_max(f, -np.inf, 2) + + +@raises(ValueError) +def test_invalid_b_brent_max(): + brent_max(f, -2, np.inf) + + +@raises(ValueError) +def test_invalid_a_b_brent_max(): + brent_max(f, 1, 0) + + if __name__ == '__main__': import sys import nose @@ -54,5 +74,3 @@ def test_brent_max(): argv.append('--verbose') argv.append('--nocapture') nose.main(argv=argv, defaultTest=__file__) - -