diff --git a/.github/workflows/array_api.yml b/.github/workflows/array_api.yml index 984d6b7b0f46..da7c57615c98 100644 --- a/.github/workflows/array_api.yml +++ b/.github/workflows/array_api.yml @@ -26,6 +26,7 @@ env: -t scipy.differentiate.tests.test_differentiate -t scipy.integrate.tests.test_tanhsinh -t scipy.integrate.tests.test_cubature + -t scipy.optimize.tests.test_bracket -t scipy.optimize.tests.test_chandrupatla -t scipy.optimize.tests.test_optimize -t scipy.stats diff --git a/scipy/optimize/_bracket.py b/scipy/optimize/_bracket.py index 5660f6c07968..c312e7fd1b17 100644 --- a/scipy/optimize/_bracket.py +++ b/scipy/optimize/_bracket.py @@ -1,6 +1,7 @@ import numpy as np import scipy._lib._elementwise_iterative_method as eim from scipy._lib._util import _RichResult +from scipy._lib._array_api import array_namespace, xp_ravel _ELIMITS = -1 # used in _bracket_root _ESTOPONESIDE = 2 # used in _bracket_root @@ -13,40 +14,47 @@ def _bracket_root_iv(func, xl0, xr0, xmin, xmax, factor, args, maxiter): if not np.iterable(args): args = (args,) - xl0 = np.asarray(xl0)[()] - if not np.issubdtype(xl0.dtype, np.number) or np.iscomplex(xl0).any(): + xp = array_namespace(xl0) + xl0 = xp.asarray(xl0)[()] + if (not xp.isdtype(xl0.dtype, "numeric") + or xp.isdtype(xl0.dtype, "complex floating")): raise ValueError('`xl0` must be numeric and real.') xr0 = xl0 + 1 if xr0 is None else xr0 - xmin = -np.inf if xmin is None else xmin - xmax = np.inf if xmax is None else xmax + xmin = -xp.inf if xmin is None else xmin + xmax = xp.inf if xmax is None else xmax factor = 2. if factor is None else factor - xl0, xr0, xmin, xmax, factor = np.broadcast_arrays(xl0, xr0, xmin, xmax, factor) + xl0, xr0, xmin, xmax, factor = xp.broadcast_arrays( + xl0, xp.asarray(xr0), xp.asarray(xmin), xp.asarray(xmax), xp.asarray(factor)) - if not np.issubdtype(xr0.dtype, np.number) or np.iscomplex(xr0).any(): + if (not xp.isdtype(xr0.dtype, "numeric") + or xp.isdtype(xr0.dtype, "complex floating")): raise ValueError('`xr0` must be numeric and real.') - if not np.issubdtype(xmin.dtype, np.number) or np.iscomplex(xmin).any(): + if (not xp.isdtype(xmin.dtype, "numeric") + or xp.isdtype(xmin.dtype, "complex floating")): raise ValueError('`xmin` must be numeric and real.') - if not np.issubdtype(xmax.dtype, np.number) or np.iscomplex(xmax).any(): + if (not xp.isdtype(xmax.dtype, "numeric") + or xp.isdtype(xmax.dtype, "complex floating")): raise ValueError('`xmax` must be numeric and real.') - if not np.issubdtype(factor.dtype, np.number) or np.iscomplex(factor).any(): + if (not xp.isdtype(factor.dtype, "numeric") + or xp.isdtype(factor.dtype, "complex floating")): raise ValueError('`factor` must be numeric and real.') - if not np.all(factor > 1): + if not xp.all(factor > 1): raise ValueError('All elements of `factor` must be greater than 1.') - maxiter = np.asarray(maxiter) + maxiter = xp.asarray(maxiter) message = '`maxiter` must be a non-negative integer.' - if (not np.issubdtype(maxiter.dtype, np.number) or maxiter.shape != tuple() - or np.iscomplex(maxiter)): + if (not xp.isdtype(maxiter.dtype, "numeric") or maxiter.shape != tuple() + or xp.isdtype(maxiter.dtype, "complex floating")): raise ValueError(message) maxiter_int = int(maxiter[()]) if not maxiter == maxiter_int or maxiter < 0: raise ValueError(message) - return func, xl0, xr0, xmin, xmax, factor, args, maxiter + return func, xl0, xr0, xmin, xmax, factor, args, maxiter, xp def _bracket_root(func, xl0, xr0=None, *, xmin=None, xmax=None, factor=None, @@ -152,14 +160,14 @@ def _bracket_root(func, xl0, xr0=None, *, xmin=None, xmax=None, factor=None, callback = None # works; I just don't want to test it temp = _bracket_root_iv(func, xl0, xr0, xmin, xmax, factor, args, maxiter) - func, xl0, xr0, xmin, xmax, factor, args, maxiter = temp + func, xl0, xr0, xmin, xmax, factor, args, maxiter, xp = temp xs = (xl0, xr0) temp = eim._initialize(func, xs, args) func, xs, fs, args, shape, dtype, xp = temp # line split for PEP8 xl0, xr0 = xs - xmin = np.broadcast_to(xmin, shape).astype(dtype, copy=False).ravel() - xmax = np.broadcast_to(xmax, shape).astype(dtype, copy=False).ravel() + xmin = xp_ravel(xp.astype(xp.broadcast_to(xmin, shape), dtype, copy=False), xp=xp) + xmax = xp_ravel(xp.astype(xp.broadcast_to(xmax, shape), dtype, copy=False), xp=xp) invalid_bracket = ~((xmin <= xl0) & (xl0 < xr0) & (xr0 <= xmax)) # The approach is to treat the left and right searches as though they were @@ -167,28 +175,29 @@ def _bracket_root(func, xl0, xr0=None, *, xmin=None, xmax=None, factor=None, # is considered when checking for termination and preparing the result # object.) # `x` is the "moving" end of the bracket - x = np.concatenate(xs) - f = np.concatenate(fs) - invalid_bracket = np.concatenate((invalid_bracket, invalid_bracket)) - n = len(x) // 2 + x = xp.concat(xs) + f = xp.concat(fs) + invalid_bracket = xp.concat((invalid_bracket, invalid_bracket)) + n = x.shape[0] // 2 # `x_last` is the previous location of the moving end of the bracket. If # the signs of `f` and `f_last` are different, `x` and `x_last` form a # bracket. - x_last = np.concatenate((x[n:], x[:n])) - f_last = np.concatenate((f[n:], f[:n])) + x_last = xp.concat((x[n:], x[:n])) + f_last = xp.concat((f[n:], f[:n])) # `x0` is the "fixed" end of the bracket. x0 = x_last # We don't need to retain the corresponding function value, since the # fixed end of the bracket is only needed to compute the new value of the # moving end; it is never returned. - limit = np.concatenate((xmin, xmax)) + limit = xp.concat((xmin, xmax)) - factor = np.broadcast_to(factor, shape).astype(dtype, copy=False).ravel() - factor = np.concatenate((factor, factor)) + factor = xp_ravel(xp.broadcast_to(factor, shape), xp=xp) + factor = xp.astype(factor, dtype, copy=False) + factor = xp.concat((factor, factor)) - active = np.arange(2*n) - args = [np.concatenate((arg, arg)) for arg in args] + active = xp.arange(2*n) + args = [xp.concat((arg, arg)) for arg in args] # This is needed due to inner workings of `eim._loop`. # We're abusing it a tiny bit. @@ -199,20 +208,20 @@ def _bracket_root(func, xl0, xr0=None, *, xmin=None, xmax=None, factor=None, # bracket `x0` and the moving end `x` will grow by `factor` each iteration. # For searches with a limit, the distance between the `limit` and moving # end of the bracket `x` will shrink by `factor` each iteration. - i = np.isinf(limit) + i = xp.isinf(limit) ni = ~i - d = np.zeros_like(x) + d = xp.zeros_like(x) d[i] = x[i] - x0[i] d[ni] = limit[ni] - x[ni] - status = np.full_like(x, eim._EINPROGRESS, dtype=int) # in progress + status = xp.full_like(x, eim._EINPROGRESS, dtype=xp.int32) # in progress status[invalid_bracket] = eim._EINPUTERR nit, nfev = 0, 1 # one function evaluation per side performed above work = _RichResult(x=x, x0=x0, f=f, limit=limit, factor=factor, active=active, d=d, x_last=x_last, f_last=f_last, nit=nit, nfev=nfev, status=status, args=args, - xl=None, xr=None, fl=None, fr=None, n=n) + xl=xp.nan, xr=xp.nan, fl=xp.nan, fr=xp.nan, n=n) res_work_pairs = [('status', 'status'), ('xl', 'xl'), ('xr', 'xr'), ('nit', 'nit'), ('nfev', 'nfev'), ('fl', 'fl'), ('fr', 'fr'), ('x', 'x'), ('f', 'f'), @@ -220,11 +229,11 @@ def _bracket_root(func, xl0, xr0=None, *, xmin=None, xmax=None, factor=None, def pre_func_eval(work): # Initialize moving end of bracket - x = np.zeros_like(work.x) + x = xp.zeros_like(work.x) # Unlimited brackets grow by `factor` by increasing distance from fixed # end to moving end. - i = np.isinf(work.limit) # indices of unlimited brackets + i = xp.isinf(work.limit) # indices of unlimited brackets work.d[i] *= work.factor[i] x[i] = work.x0[i] + work.d[i] @@ -250,8 +259,8 @@ def check_termination(work): stop = (work.status == eim._EINPUTERR) # Condition 1: a valid bracket (or the root itself) has been found - sf = np.sign(work.f) - sf_last = np.sign(work.f_last) + sf = xp.sign(work.f) + sf_last = xp.sign(work.f_last) i = ((sf_last == -sf) | (sf_last == 0) | (sf == 0)) & ~stop work.status[i] = eim._ECONVERGED stop[i] = True @@ -275,14 +284,14 @@ def check_termination(work): # Check whether they are still active. # To start, we need to find out where in `work.active` they would # appear if they are indeed there. - j = np.searchsorted(work.active, also_stop) + j = xp.searchsorted(work.active, also_stop) # If the location exceeds the length of the `work.active`, they are # not there. - j = j[j < len(work.active)] + j = j[j < work.active.shape[0]] # Check whether they are still there. j = j[also_stop == work.active[j]] # Now convert these to boolean indices to use with `work.status`. - i = np.zeros_like(stop) + i = xp.zeros_like(stop) i[j] = True # boolean indices of elements that can also stop i = i & ~stop work.status[i] = _ESTOPONESIDE @@ -294,7 +303,7 @@ def check_termination(work): stop[i] = True # Condition 4: non-finite value encountered - i = ~(np.isfinite(work.x) & np.isfinite(work.f)) & ~stop + i = ~(xp.isfinite(work.x) & xp.isfinite(work.f)) & ~stop work.status[i] = eim._EVALUEERR stop[i] = True @@ -304,7 +313,7 @@ def post_termination_check(work): pass def customize_result(res, shape): - n = len(res['x']) // 2 + n = res['x'].shape[0] // 2 # To avoid ambiguity, below we refer to `xl0`, the initial left endpoint # as `a` and `xr0`, the initial right endpoint, as `b`. @@ -334,10 +343,10 @@ def customize_result(res, shape): # has been evaluated. This gives the user some information about what # interval of the real line has been searched and shows that there is # no sign change between the two ends. - xl = xal.copy() - fl = fal.copy() - xr = xbr.copy() - fr = fbr.copy() + xl = xp.asarray(xal, copy=True) + fl = xp.asarray(fal, copy=True) + xr = xp.asarray(xbr, copy=True) + fr = xp.asarray(fbr, copy=True) # `status` indicates whether the bracket is valid or not. If so, # we want to adjust the bracket we return to be the narrowest possible @@ -365,11 +374,11 @@ def customize_result(res, shape): res['fl'] = fl res['fr'] = fr - res['nit'] = np.maximum(res['nit'][:n], res['nit'][n:]) + res['nit'] = xp.maximum(res['nit'][:n], res['nit'][n:]) res['nfev'] = res['nfev'][:n] + res['nfev'][n:] # If the status on one side is zero, the status is zero. In any case, # report the status from one side only. - res['status'] = np.choose(sa == 0, (sb, sa)) + res['status'] = xp.where(sa == 0, sa, sb) res['success'] = (res['status'] == 0) del res['x'] diff --git a/scipy/optimize/tests/test_bracket.py b/scipy/optimize/tests/test_bracket.py index 0f2341ae39a7..293df8553628 100644 --- a/scipy/optimize/tests/test_bracket.py +++ b/scipy/optimize/tests/test_bracket.py @@ -1,12 +1,16 @@ import pytest import numpy as np -from numpy.testing import assert_array_less, assert_allclose, assert_equal +from numpy.testing import assert_allclose, assert_equal from scipy.optimize._bracket import _ELIMITS from scipy.optimize.elementwise import bracket_root, bracket_minimum import scipy._lib._elementwise_iterative_method as eim from scipy import stats +from scipy._lib._array_api_no_0d import (xp_assert_close, xp_assert_equal, + xp_assert_less, array_namespace) +from scipy._lib._array_api import xp_ravel, is_torch +from scipy.conftest import array_api_compatible # These tests were originally written for the private `optimize._bracket` @@ -38,12 +42,19 @@ def _bracket_minimum(*args, **kwargs): return res +array_api_strict_skip_reason = 'Array API does not support fancy indexing assignment.' +jax_skip_reason = 'JAX arrays do not support item assignment.' + +@pytest.mark.skip_xp_backends('array_api_strict', reason=array_api_strict_skip_reason) +@pytest.mark.skip_xp_backends('jax.numpy', reason=jax_skip_reason) +@array_api_compatible +@pytest.mark.usefixtures("skip_xp_backends") class TestBracketRoot: @pytest.mark.parametrize("seed", (615655101, 3141866013, 238075752)) @pytest.mark.parametrize("use_xmin", (False, True)) @pytest.mark.parametrize("other_side", (False, True)) @pytest.mark.parametrize("fix_one_side", (False, True)) - def test_nfev_expected(self, seed, use_xmin, other_side, fix_one_side): + def test_nfev_expected(self, seed, use_xmin, other_side, fix_one_side, xp): # Property-based test to confirm that _bracket_root is behaving as # expected. The basic case is when root < a < b. # The number of times bracket expands (per side) can be found by @@ -53,9 +64,8 @@ def test_nfev_expected(self, seed, use_xmin, other_side, fix_one_side): # into the expression for the ends of the bracket. # `other_side=True` is the case that a < b < root # Special cases like a < root < b are tested separately - rng = np.random.default_rng(seed) - xl0, d, factor = rng.random(size=3) * [1e5, 10, 5] + xl0, d, factor = xp.asarray(rng.random(size=3) * [1e5, 10, 5]) factor = 1 + factor # factor must be greater than 1 xr0 = xl0 + d # xr0 must be greater than a in basic case @@ -64,12 +74,12 @@ def f(x): return x # root is 0 if use_xmin: - xmin = -rng.random() - n = np.ceil(np.log(-(xl0 - xmin) / xmin) / np.log(factor)) + xmin = xp.asarray(-rng.random()) + n = xp.ceil(xp.log(-(xl0 - xmin) / xmin) / xp.log(factor)) l, u = xmin + (xl0 - xmin)*factor**-n, xmin + (xl0 - xmin)*factor**-(n - 1) kwargs = dict(xl0=xl0, xr0=xr0, factor=factor, xmin=xmin) else: - n = np.ceil(np.log(xr0/d) / np.log(factor)) + n = xp.ceil(xp.log(xr0/d) / xp.log(factor)) l, u = xr0 - d*factor**n, xr0 - d*factor**(n-1) kwargs = dict(xl0=xl0, xr0=xr0, factor=factor) @@ -105,33 +115,33 @@ def f(x): # Compare reported bracket to theoretical bracket and reported function # values to function evaluated at bracket. - bracket = np.asarray([res.xl, res.xr]) - assert_allclose(bracket, (l, u)) - f_bracket = np.asarray([res.fl, res.fr]) - assert_allclose(f_bracket, f(bracket)) + bracket = xp.asarray([res.xl, res.xr]) + xp_assert_close(bracket, xp.asarray([l, u])) + f_bracket = xp.asarray([res.fl, res.fr]) + xp_assert_close(f_bracket, f(bracket)) # Check that bracket is valid and that status and success are correct assert res.xr > res.xl - signs = np.sign(f_bracket) + signs = xp.sign(f_bracket) assert signs[0] == -signs[1] assert res.status == 0 assert res.success def f(self, q, p): - return stats.norm.cdf(q) - p + return stats._stats_py._SimpleNormal().cdf(q) - p @pytest.mark.parametrize('p', [0.6, np.linspace(0.05, 0.95, 10)]) @pytest.mark.parametrize('xmin', [-5, None]) @pytest.mark.parametrize('xmax', [5, None]) @pytest.mark.parametrize('factor', [1.2, 2]) - def test_basic(self, p, xmin, xmax, factor): + def test_basic(self, p, xmin, xmax, factor, xp): # Test basic functionality to bracket root (distribution PPF) - res = _bracket_root(self.f, -0.01, 0.01, xmin=xmin, xmax=xmax, - factor=factor, args=(p,)) - assert_equal(-np.sign(res.fl), np.sign(res.fr)) + res = _bracket_root(self.f, xp.asarray(-0.01), 0.01, xmin=xmin, xmax=xmax, + factor=factor, args=(xp.asarray(p),)) + xp_assert_equal(-xp.sign(res.fl), xp.sign(res.fr)) @pytest.mark.parametrize('shape', [tuple(), (12,), (3, 4), (3, 2, 2)]) - def test_vectorization(self, shape): + def test_vectorization(self, shape, xp): # Test for correct functionality, output shapes, and dtypes for various # input shapes. p = np.linspace(-0.05, 1.05, 12).reshape(shape) if shape else 0.6 @@ -157,76 +167,85 @@ def f(*args, **kwargs): i = rng.random(size=shape) > 0.5 xmin[i], xmax[i] = -np.inf, np.inf factor = rng.random(size=shape) + 1.5 + refs = bracket_root_single(xl0, xr0, xmin, xmax, factor, p).ravel() + xl0, xr0, xmin, xmax, factor = (xp.asarray(xl0), xp.asarray(xr0), + xp.asarray(xmin), xp.asarray(xmax), + xp.asarray(factor)) + args = tuple(map(xp.asarray, args)) res = _bracket_root(f, xl0, xr0, xmin=xmin, xmax=xmax, factor=factor, args=args, maxiter=maxiter) - refs = bracket_root_single(xl0, xr0, xmin, xmax, factor, p).ravel() attrs = ['xl', 'xr', 'fl', 'fr', 'success', 'nfev', 'nit'] for attr in attrs: - ref_attr = [getattr(ref, attr) for ref in refs] + ref_attr = [xp.asarray(getattr(ref, attr)) for ref in refs] res_attr = getattr(res, attr) - assert_allclose(res_attr.ravel(), ref_attr) - assert_equal(res_attr.shape, shape) + rtol = 5e-7 if is_torch(xp) else None # consider looking into this + xp_assert_close(xp_ravel(res_attr, xp=xp), xp.stack(ref_attr), rtol=rtol) + xp_assert_equal(res_attr.shape, shape) - assert np.issubdtype(res.success.dtype, np.bool_) + xp_test = array_namespace(xp.asarray(1.)) + assert res.success.dtype == xp_test.bool if shape: - assert np.all(res.success[1:-1]) - assert np.issubdtype(res.status.dtype, np.integer) - assert np.issubdtype(res.nfev.dtype, np.integer) - assert np.issubdtype(res.nit.dtype, np.integer) - assert_equal(np.max(res.nit), f.f_evals - 2) - assert_array_less(res.xl, res.xr) - assert_allclose(res.fl, self.f(res.xl, *args)) - assert_allclose(res.fr, self.f(res.xr, *args)) - - def test_flags(self): + assert xp.all(res.success[1:-1]) + assert res.status.dtype == xp.int32 + assert res.nfev.dtype == xp.int32 + assert res.nit.dtype == xp.int32 + assert xp.max(res.nit) == f.f_evals - 2 + xp_assert_less(res.xl, res.xr) + xp_assert_close(res.fl, xp.asarray(self.f(res.xl, *args))) + xp_assert_close(res.fr, xp.asarray(self.f(res.xr, *args))) + + def test_flags(self, xp): # Test cases that should produce different status flags; show that all # can be produced simultaneously. def f(xs, js): funcs = [lambda x: x - 1.5, lambda x: x - 1000, lambda x: x - 1000, - lambda x: np.nan, + lambda x: x * xp.nan, lambda x: x] - return [funcs[j](x) for x, j in zip(xs, js)] + return [funcs[int(j)](x) for x, j in zip(xs, js)] - args = (np.arange(5, dtype=np.int64),) + args = (xp.arange(5, dtype=xp.int64),) res = _bracket_root(f, - xl0=[-1, -1, -1, -1, 4], - xr0=[1, 1, 1, 1, -4], - xmin=[-np.inf, -1, -np.inf, -np.inf, 6], - xmax=[np.inf, 1, np.inf, np.inf, 2], + xl0=xp.asarray([-1., -1., -1., -1., 4.]), + xr0=xp.asarray([1, 1, 1, 1, -4]), + xmin=xp.asarray([-xp.inf, -1, -xp.inf, -xp.inf, 6]), + xmax=xp.asarray([xp.inf, 1, xp.inf, xp.inf, 2]), args=args, maxiter=3) - ref_flags = np.array([eim._ECONVERGED, - _ELIMITS, - eim._ECONVERR, - eim._EVALUEERR, - eim._EINPUTERR]) + ref_flags = xp.asarray([eim._ECONVERGED, + _ELIMITS, + eim._ECONVERR, + eim._EVALUEERR, + eim._EINPUTERR], + dtype=xp.int32) - assert_equal(res.status, ref_flags) + xp_assert_equal(res.status, ref_flags) @pytest.mark.parametrize("root", (0.622, [0.622, 0.623])) @pytest.mark.parametrize('xmin', [-5, None]) @pytest.mark.parametrize('xmax', [5, None]) - @pytest.mark.parametrize("dtype", (np.float16, np.float32, np.float64)) - def test_dtype(self, root, xmin, xmax, dtype): + @pytest.mark.parametrize("dtype", ("float16", "float32", "float64")) + def test_dtype(self, root, xmin, xmax, dtype, xp): # Test that dtypes are preserved + dtype = getattr(xp, dtype) + xp_test = array_namespace(xp.asarray(1.)) - xmin = xmin if xmin is None else dtype(xmin) - xmax = xmax if xmax is None else dtype(xmax) - root = dtype(root) + xmin = xmin if xmin is None else xp.asarray(xmin, dtype=dtype) + xmax = xmax if xmax is None else xp.asarray(xmax, dtype=dtype) + root = xp.asarray(root, dtype=dtype) def f(x, root): - return ((x - root) ** 3).astype(dtype) + return xp_test.astype((x - root) ** 3, dtype) - bracket = np.asarray([-0.01, 0.01], dtype=dtype) + bracket = xp.asarray([-0.01, 0.01], dtype=dtype) res = _bracket_root(f, *bracket, xmin=xmin, xmax=xmax, args=(root,)) - assert np.all(res.success) + assert xp.all(res.success) assert res.xl.dtype == res.xr.dtype == dtype assert res.fl.dtype == res.fr.dtype == dtype - def test_input_validation(self): + def test_input_validation(self, xp): # Test input validation for appropriate error messages message = '`func` must be callable.' @@ -249,10 +268,10 @@ def test_input_validation(self): with pytest.raises(ValueError, match=message): _bracket_root(lambda x: x, -4, 4, factor=0.5) - message = "shape mismatch: objects cannot be broadcast" - # raised by `np.broadcast, but the traceback is readable IMO - with pytest.raises(ValueError, match=message): - _bracket_root(lambda x: x, [-2, -3], [3, 4, 5]) + message = "broadcast" + # raised by `xp.broadcast, but the traceback is readable IMO + with pytest.raises(Exception, match=message): + _bracket_root(lambda x: x, xp.asarray([-2, -3]), xp.asarray([3, 4, 5])) # Consider making this give a more readable error message # with pytest.raises(ValueError, match=message): # _bracket_root(lambda x: [x[0], x[1], x[1]], [-3, -3], [5, 5]) @@ -265,24 +284,24 @@ def test_input_validation(self): with pytest.raises(ValueError, match=message): _bracket_root(lambda x: x, -4, 4, maxiter="shrubbery") - - def test_special_cases(self): + def test_special_cases(self, xp): # Test edge cases and other special cases + xp_test = array_namespace(xp.asarray(1.)) # Test that integers are not passed to `f` # (otherwise this would overflow) def f(x): - assert np.issubdtype(x.dtype, np.floating) + assert xp_test.isdtype(x.dtype, "real floating") return x ** 99 - 1 - res = _bracket_root(f, -7, 5) + res = _bracket_root(f, xp.asarray(-7.), xp.asarray(5.)) assert res.success # Test maxiter = 0. Should do nothing to bracket. def f(x): return x - 10 - bracket = (-3, 5) + bracket = (xp.asarray(-3.), xp.asarray(5.)) res = _bracket_root(f, *bracket, maxiter=0) assert res.xl, res.xr == bracket assert res.nit == 0 @@ -293,9 +312,10 @@ def f(x): def f(x, c): return c*x - 1 - res = _bracket_root(f, -1, 1, args=3) + res = _bracket_root(f, xp.asarray(-1.), xp.asarray(1.), + args=xp.asarray(3.)) assert res.success - assert_allclose(res.fl, f(res.xl, 3)) + xp_assert_close(res.fl, f(res.xl, 3)) # Test other edge cases @@ -305,29 +325,33 @@ def f(x): # 1. root lies within guess of bracket f.count = 0 - _bracket_root(f, -10, 20) - assert_equal(f.count, 2) + _bracket_root(f, xp.asarray(-10), xp.asarray(20)) + assert f.count == 2 # 2. bracket endpoint hits root exactly f.count = 0 - res = _bracket_root(f, 5, 10, factor=2) - bracket = (res.xl, res.xr) - assert_equal(res.nfev, 4) - assert_allclose(bracket, (0, 5), atol=1e-15) + res = _bracket_root(f, xp.asarray(5.), xp.asarray(10.), + factor=2) + + assert res.nfev == 4 + xp_assert_close(res.xl, xp.asarray(0.), atol=1e-15) + xp_assert_close(res.xr, xp.asarray(5.), atol=1e-15) # 3. bracket limit hits root exactly with np.errstate(over='ignore'): - res = _bracket_root(f, 5, 10, xmin=0) - bracket = (res.xl, res.xr) - assert_allclose(bracket[0], 0, atol=1e-15) + res = _bracket_root(f, xp.asarray(5.), xp.asarray(10.), + xmin=0) + xp_assert_close(res.xl, xp.asarray(0.), atol=1e-15) + with np.errstate(over='ignore'): - res = _bracket_root(f, -10, -5, xmax=0) - bracket = (res.xl, res.xr) - assert_allclose(bracket[1], 0, atol=1e-15) + res = _bracket_root(f, xp.asarray(-10.), xp.asarray(-5.), + xmax=0) + xp_assert_close(res.xr, xp.asarray(0.), atol=1e-15) # 4. bracket not within min, max with np.errstate(over='ignore'): - res = _bracket_root(f, 5, 10, xmin=1) + res = _bracket_root(f, xp.asarray(5.), xp.asarray(10.), + xmin=1) assert not res.success