Skip to content

Commit

Permalink
ENH: optimize.elementwise.bracket_root: add array API support (scip…
Browse files Browse the repository at this point in the history
…y#21920)

* ENH: `optimize.elementwise.bracket_root`: add Array API support

---------

Co-authored-by: Matt Haberland <mhaberla@calpoly.edu>
  • Loading branch information
j-bowhay and mdhaber authored Nov 26, 2024
1 parent 01545db commit a58579e
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 126 deletions.
1 change: 1 addition & 0 deletions .github/workflows/array_api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ env:
-t scipy.differentiate.tests.test_differentiate
-t scipy.integrate.tests.test_tanhsinh
-t scipy.integrate.tests.test_cubature
-t scipy.optimize.tests.test_bracket
-t scipy.optimize.tests.test_chandrupatla
-t scipy.optimize.tests.test_optimize
-t scipy.stats
Expand Down
103 changes: 56 additions & 47 deletions scipy/optimize/_bracket.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import scipy._lib._elementwise_iterative_method as eim
from scipy._lib._util import _RichResult
from scipy._lib._array_api import array_namespace, xp_ravel

_ELIMITS = -1 # used in _bracket_root
_ESTOPONESIDE = 2 # used in _bracket_root
Expand All @@ -13,40 +14,47 @@ def _bracket_root_iv(func, xl0, xr0, xmin, xmax, factor, args, maxiter):
if not np.iterable(args):
args = (args,)

xl0 = np.asarray(xl0)[()]
if not np.issubdtype(xl0.dtype, np.number) or np.iscomplex(xl0).any():
xp = array_namespace(xl0)
xl0 = xp.asarray(xl0)[()]
if (not xp.isdtype(xl0.dtype, "numeric")
or xp.isdtype(xl0.dtype, "complex floating")):
raise ValueError('`xl0` must be numeric and real.')

xr0 = xl0 + 1 if xr0 is None else xr0
xmin = -np.inf if xmin is None else xmin
xmax = np.inf if xmax is None else xmax
xmin = -xp.inf if xmin is None else xmin
xmax = xp.inf if xmax is None else xmax
factor = 2. if factor is None else factor
xl0, xr0, xmin, xmax, factor = np.broadcast_arrays(xl0, xr0, xmin, xmax, factor)
xl0, xr0, xmin, xmax, factor = xp.broadcast_arrays(
xl0, xp.asarray(xr0), xp.asarray(xmin), xp.asarray(xmax), xp.asarray(factor))

if not np.issubdtype(xr0.dtype, np.number) or np.iscomplex(xr0).any():
if (not xp.isdtype(xr0.dtype, "numeric")
or xp.isdtype(xr0.dtype, "complex floating")):
raise ValueError('`xr0` must be numeric and real.')

if not np.issubdtype(xmin.dtype, np.number) or np.iscomplex(xmin).any():
if (not xp.isdtype(xmin.dtype, "numeric")
or xp.isdtype(xmin.dtype, "complex floating")):
raise ValueError('`xmin` must be numeric and real.')

if not np.issubdtype(xmax.dtype, np.number) or np.iscomplex(xmax).any():
if (not xp.isdtype(xmax.dtype, "numeric")
or xp.isdtype(xmax.dtype, "complex floating")):
raise ValueError('`xmax` must be numeric and real.')

if not np.issubdtype(factor.dtype, np.number) or np.iscomplex(factor).any():
if (not xp.isdtype(factor.dtype, "numeric")
or xp.isdtype(factor.dtype, "complex floating")):
raise ValueError('`factor` must be numeric and real.')
if not np.all(factor > 1):
if not xp.all(factor > 1):
raise ValueError('All elements of `factor` must be greater than 1.')

maxiter = np.asarray(maxiter)
maxiter = xp.asarray(maxiter)
message = '`maxiter` must be a non-negative integer.'
if (not np.issubdtype(maxiter.dtype, np.number) or maxiter.shape != tuple()
or np.iscomplex(maxiter)):
if (not xp.isdtype(maxiter.dtype, "numeric") or maxiter.shape != tuple()
or xp.isdtype(maxiter.dtype, "complex floating")):
raise ValueError(message)
maxiter_int = int(maxiter[()])
if not maxiter == maxiter_int or maxiter < 0:
raise ValueError(message)

return func, xl0, xr0, xmin, xmax, factor, args, maxiter
return func, xl0, xr0, xmin, xmax, factor, args, maxiter, xp


def _bracket_root(func, xl0, xr0=None, *, xmin=None, xmax=None, factor=None,
Expand Down Expand Up @@ -152,43 +160,44 @@ def _bracket_root(func, xl0, xr0=None, *, xmin=None, xmax=None, factor=None,

callback = None # works; I just don't want to test it
temp = _bracket_root_iv(func, xl0, xr0, xmin, xmax, factor, args, maxiter)
func, xl0, xr0, xmin, xmax, factor, args, maxiter = temp
func, xl0, xr0, xmin, xmax, factor, args, maxiter, xp = temp

xs = (xl0, xr0)
temp = eim._initialize(func, xs, args)
func, xs, fs, args, shape, dtype, xp = temp # line split for PEP8
xl0, xr0 = xs
xmin = np.broadcast_to(xmin, shape).astype(dtype, copy=False).ravel()
xmax = np.broadcast_to(xmax, shape).astype(dtype, copy=False).ravel()
xmin = xp_ravel(xp.astype(xp.broadcast_to(xmin, shape), dtype, copy=False), xp=xp)
xmax = xp_ravel(xp.astype(xp.broadcast_to(xmax, shape), dtype, copy=False), xp=xp)
invalid_bracket = ~((xmin <= xl0) & (xl0 < xr0) & (xr0 <= xmax))

# The approach is to treat the left and right searches as though they were
# (almost) totally independent one-sided bracket searches. (The interaction
# is considered when checking for termination and preparing the result
# object.)
# `x` is the "moving" end of the bracket
x = np.concatenate(xs)
f = np.concatenate(fs)
invalid_bracket = np.concatenate((invalid_bracket, invalid_bracket))
n = len(x) // 2
x = xp.concat(xs)
f = xp.concat(fs)
invalid_bracket = xp.concat((invalid_bracket, invalid_bracket))
n = x.shape[0] // 2

# `x_last` is the previous location of the moving end of the bracket. If
# the signs of `f` and `f_last` are different, `x` and `x_last` form a
# bracket.
x_last = np.concatenate((x[n:], x[:n]))
f_last = np.concatenate((f[n:], f[:n]))
x_last = xp.concat((x[n:], x[:n]))
f_last = xp.concat((f[n:], f[:n]))
# `x0` is the "fixed" end of the bracket.
x0 = x_last
# We don't need to retain the corresponding function value, since the
# fixed end of the bracket is only needed to compute the new value of the
# moving end; it is never returned.
limit = np.concatenate((xmin, xmax))
limit = xp.concat((xmin, xmax))

factor = np.broadcast_to(factor, shape).astype(dtype, copy=False).ravel()
factor = np.concatenate((factor, factor))
factor = xp_ravel(xp.broadcast_to(factor, shape), xp=xp)
factor = xp.astype(factor, dtype, copy=False)
factor = xp.concat((factor, factor))

active = np.arange(2*n)
args = [np.concatenate((arg, arg)) for arg in args]
active = xp.arange(2*n)
args = [xp.concat((arg, arg)) for arg in args]

# This is needed due to inner workings of `eim._loop`.
# We're abusing it a tiny bit.
Expand All @@ -199,32 +208,32 @@ def _bracket_root(func, xl0, xr0=None, *, xmin=None, xmax=None, factor=None,
# bracket `x0` and the moving end `x` will grow by `factor` each iteration.
# For searches with a limit, the distance between the `limit` and moving
# end of the bracket `x` will shrink by `factor` each iteration.
i = np.isinf(limit)
i = xp.isinf(limit)
ni = ~i
d = np.zeros_like(x)
d = xp.zeros_like(x)
d[i] = x[i] - x0[i]
d[ni] = limit[ni] - x[ni]

status = np.full_like(x, eim._EINPROGRESS, dtype=int) # in progress
status = xp.full_like(x, eim._EINPROGRESS, dtype=xp.int32) # in progress
status[invalid_bracket] = eim._EINPUTERR
nit, nfev = 0, 1 # one function evaluation per side performed above

work = _RichResult(x=x, x0=x0, f=f, limit=limit, factor=factor,
active=active, d=d, x_last=x_last, f_last=f_last,
nit=nit, nfev=nfev, status=status, args=args,
xl=None, xr=None, fl=None, fr=None, n=n)
xl=xp.nan, xr=xp.nan, fl=xp.nan, fr=xp.nan, n=n)
res_work_pairs = [('status', 'status'), ('xl', 'xl'), ('xr', 'xr'),
('nit', 'nit'), ('nfev', 'nfev'), ('fl', 'fl'),
('fr', 'fr'), ('x', 'x'), ('f', 'f'),
('x_last', 'x_last'), ('f_last', 'f_last')]

def pre_func_eval(work):
# Initialize moving end of bracket
x = np.zeros_like(work.x)
x = xp.zeros_like(work.x)

# Unlimited brackets grow by `factor` by increasing distance from fixed
# end to moving end.
i = np.isinf(work.limit) # indices of unlimited brackets
i = xp.isinf(work.limit) # indices of unlimited brackets
work.d[i] *= work.factor[i]
x[i] = work.x0[i] + work.d[i]

Expand All @@ -250,8 +259,8 @@ def check_termination(work):
stop = (work.status == eim._EINPUTERR)

# Condition 1: a valid bracket (or the root itself) has been found
sf = np.sign(work.f)
sf_last = np.sign(work.f_last)
sf = xp.sign(work.f)
sf_last = xp.sign(work.f_last)
i = ((sf_last == -sf) | (sf_last == 0) | (sf == 0)) & ~stop
work.status[i] = eim._ECONVERGED
stop[i] = True
Expand All @@ -275,14 +284,14 @@ def check_termination(work):
# Check whether they are still active.
# To start, we need to find out where in `work.active` they would
# appear if they are indeed there.
j = np.searchsorted(work.active, also_stop)
j = xp.searchsorted(work.active, also_stop)
# If the location exceeds the length of the `work.active`, they are
# not there.
j = j[j < len(work.active)]
j = j[j < work.active.shape[0]]
# Check whether they are still there.
j = j[also_stop == work.active[j]]
# Now convert these to boolean indices to use with `work.status`.
i = np.zeros_like(stop)
i = xp.zeros_like(stop)
i[j] = True # boolean indices of elements that can also stop
i = i & ~stop
work.status[i] = _ESTOPONESIDE
Expand All @@ -294,7 +303,7 @@ def check_termination(work):
stop[i] = True

# Condition 4: non-finite value encountered
i = ~(np.isfinite(work.x) & np.isfinite(work.f)) & ~stop
i = ~(xp.isfinite(work.x) & xp.isfinite(work.f)) & ~stop
work.status[i] = eim._EVALUEERR
stop[i] = True

Expand All @@ -304,7 +313,7 @@ def post_termination_check(work):
pass

def customize_result(res, shape):
n = len(res['x']) // 2
n = res['x'].shape[0] // 2

# To avoid ambiguity, below we refer to `xl0`, the initial left endpoint
# as `a` and `xr0`, the initial right endpoint, as `b`.
Expand Down Expand Up @@ -334,10 +343,10 @@ def customize_result(res, shape):
# has been evaluated. This gives the user some information about what
# interval of the real line has been searched and shows that there is
# no sign change between the two ends.
xl = xal.copy()
fl = fal.copy()
xr = xbr.copy()
fr = fbr.copy()
xl = xp.asarray(xal, copy=True)
fl = xp.asarray(fal, copy=True)
xr = xp.asarray(xbr, copy=True)
fr = xp.asarray(fbr, copy=True)

# `status` indicates whether the bracket is valid or not. If so,
# we want to adjust the bracket we return to be the narrowest possible
Expand Down Expand Up @@ -365,11 +374,11 @@ def customize_result(res, shape):
res['fl'] = fl
res['fr'] = fr

res['nit'] = np.maximum(res['nit'][:n], res['nit'][n:])
res['nit'] = xp.maximum(res['nit'][:n], res['nit'][n:])
res['nfev'] = res['nfev'][:n] + res['nfev'][n:]
# If the status on one side is zero, the status is zero. In any case,
# report the status from one side only.
res['status'] = np.choose(sa == 0, (sb, sa))
res['status'] = xp.where(sa == 0, sa, sb)
res['success'] = (res['status'] == 0)

del res['x']
Expand Down
Loading

0 comments on commit a58579e

Please sign in to comment.