Skip to content

Commit

Permalink
ENH: optimize: add array api support to rosen and friends (scipy#21778
Browse files Browse the repository at this point in the history
)

* ENH: optimize: add array api support to rosen and friends
  • Loading branch information
j-bowhay authored Nov 2, 2024
1 parent 464f3c7 commit 10807c0
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 20 deletions.
1 change: 1 addition & 0 deletions .github/workflows/array_api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ env:
-t scipy.integrate.tests.test_tanhsinh
-t scipy.integrate.tests.test_cubature
-t scipy.optimize.tests.test_chandrupatla
-t scipy.optimize.tests.test_optimize
-t scipy.stats
-t scipy.ndimage
Expand Down
10 changes: 10 additions & 0 deletions scipy/_lib/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
'xp_atleast_nd', 'xp_copy', 'xp_copysign', 'xp_device',
'xp_moveaxis_to_end', 'xp_ravel', 'xp_real', 'xp_sign', 'xp_size',
'xp_take_along_axis', 'xp_unsupported_param_msg', 'xp_vector_norm',
'xp_create_diagonal'
]


Expand Down Expand Up @@ -645,3 +646,12 @@ def xp_float_to_complex(arr: Array, xp: ModuleType | None = None) -> Array:
arr = xp.astype(arr, xp.complex128)

return arr

def xp_create_diagonal(x: Array, /, *, offset: int = 0,
xp: ModuleType | None = None) -> Array:
xp = array_namespace(x) if xp is None else xp
n = x.shape[0] + abs(offset)
diag = xp.zeros(n**2, dtype=x.dtype)
i = offset if offset >= 0 else abs(offset) * n
diag[i:min(n*(n-offset), diag.shape[0]):n+1] = x
return xp.reshape(diag, (n, n))
17 changes: 16 additions & 1 deletion scipy/_lib/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

from scipy.conftest import array_api_compatible
from scipy._lib._array_api import (
_GLOBAL_CONFIG, array_namespace, _asarray, xp_copy, xp_assert_equal, is_numpy
_GLOBAL_CONFIG, array_namespace, _asarray, xp_copy, xp_assert_equal, is_numpy,
xp_create_diagonal
)
from scipy._lib._array_api_no_0d import xp_assert_equal as xp_assert_equal_no_0d
import scipy._lib.array_api_compat.numpy as np_compat
Expand Down Expand Up @@ -185,3 +186,17 @@ def test_check_scalar_no_0d(self, xp):
# scalars-vs-0d passes (if values match) also with regular python objects
xp_assert_equal_no_0d(0., xp.asarray(0.))
xp_assert_equal_no_0d(42, xp.asarray(42))

@skip_xp_backends('jax.numpy',
reasons=["JAX arrays do not support item assignment"])
@pytest.mark.usefixtures("skip_xp_backends")
@array_api_compatible
@pytest.mark.parametrize('n', range(1, 10))
@pytest.mark.parametrize('offset', range(1, 10))
def test_create_diagonal(self, n, offset, xp):
rng = np.random.default_rng(2347823)
one = xp.asarray(1.)
x = xp.asarray(rng.random(n), dtype=one.dtype)
A = xp_create_diagonal(x, offset=offset, xp=xp)
B = xp.asarray(np.diag(x, offset), dtype=one.dtype)
xp_assert_equal(A, B)
41 changes: 28 additions & 13 deletions scipy/optimize/_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import warnings
import sys
import inspect
from numpy import atleast_1d, eye, argmin, zeros, shape, asarray, sqrt
from numpy import eye, argmin, zeros, shape, asarray, sqrt
import numpy as np
from scipy.linalg import cholesky, issymmetric, LinAlgError
from scipy.sparse.linalg import LinearOperator
Expand All @@ -41,6 +41,8 @@
from scipy._lib._util import (MapWrapper, check_random_state, _RichResult,
_call_callback_maybe_halt)
from scipy.optimize._differentiable_functions import ScalarFunction, FD_METHODS
from scipy._lib._array_api import (array_namespace, xp_atleast_nd,
xp_create_diagonal)


# standard status messages of optimizers
Expand Down Expand Up @@ -358,9 +360,12 @@ def rosen(x):
>>> ax.plot_surface(X, Y, rosen([X, Y]))
>>> plt.show()
"""
x = asarray(x)
r = np.sum(100.0 * (x[1:] - x[:-1]**2.0)**2.0 + (1 - x[:-1])**2.0,
axis=0)
xp = array_namespace(x)
x = xp.asarray(x)
if xp.isdtype(x.dtype, 'integral'):
x = xp.astype(x, xp.asarray(1.).dtype)
r = xp.sum(100.0 * (x[1:] - x[:-1]**2.0)**2.0 + (1 - x[:-1])**2.0,
axis=0, dtype=x.dtype)
return r


Expand Down Expand Up @@ -391,11 +396,14 @@ def rosen_der(x):
array([ -2. , 10.6, 15.6, 13.4, 6.4, -3. , -12.4, -19.4, 62. ])
"""
x = asarray(x)
xp = array_namespace(x)
x = xp.asarray(x)
if xp.isdtype(x.dtype, 'integral'):
x = xp.astype(x, xp.asarray(1.).dtype)
xm = x[1:-1]
xm_m1 = x[:-2]
xm_p1 = x[2:]
der = np.zeros_like(x)
der = xp.zeros_like(x)
der[1:-1] = (200 * (xm - xm_m1**2) -
400 * (xm_p1 - xm**2) * xm - 2 * (1 - xm))
der[0] = -400 * x[0] * (x[1] - x[0]**2) - 2 * (1 - x[0])
Expand Down Expand Up @@ -433,14 +441,17 @@ def rosen_hess(x):
[ 0., 0., -80., 200.]])
"""
x = atleast_1d(x)
H = np.diag(-400 * x[:-1], 1) - np.diag(400 * x[:-1], -1)
diagonal = np.zeros(len(x), dtype=x.dtype)
xp = array_namespace(x)
x = xp_atleast_nd(x, ndim=1, xp=xp)
if xp.isdtype(x.dtype, 'integral'):
x = xp.astype(x, xp.asarray(1.).dtype)
H = (xp_create_diagonal(-400 * x[:-1], offset=1, xp=xp)
- xp_create_diagonal(400 * x[:-1], offset=-1, xp=xp))
diagonal = xp.zeros(x.shape[0], dtype=x.dtype)
diagonal[0] = 1200 * x[0]**2 - 400 * x[1] + 2
diagonal[-1] = 200
diagonal[1:-1] = 202 + 1200 * x[1:-1]**2 - 400 * x[2:]
H = H + np.diag(diagonal)
return H
return H + xp_create_diagonal(diagonal, xp=xp)


def rosen_hess_prod(x, p):
Expand Down Expand Up @@ -474,8 +485,12 @@ def rosen_hess_prod(x, p):
array([ -0., 27., -10., -95., -192., -265., -278., -195., -180.])
"""
x = atleast_1d(x)
Hp = np.zeros(len(x), dtype=x.dtype)
xp = array_namespace(x, p)
x = xp_atleast_nd(x, ndim=1, xp=xp)
if xp.isdtype(x.dtype, 'integral'):
x = xp.astype(x, xp.asarray(1.).dtype)
p = xp.asarray(p, dtype=x.dtype)
Hp = xp.zeros(x.shape[0], dtype=x.dtype)
Hp[0] = (1200 * x[0]**2 - 400 * x[1] + 2) * p[0] - 400 * x[0] * p[1]
Hp[1:-1] = (-400 * x[:-2] * p[:-2] +
(202 + 1200 * x[1:-1]**2 - 400 * x[2:]) * p[1:-1] -
Expand Down
37 changes: 31 additions & 6 deletions scipy/optimize/tests/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@

from scipy.sparse import (coo_matrix, csc_matrix, csr_matrix, coo_array,
csr_array, csc_array)
from scipy.conftest import array_api_compatible
from scipy._lib._array_api_no_0d import xp_assert_equal, array_namespace

skip_xp_backends = pytest.mark.skip_xp_backends


def test_check_grad():
# Verify if check_grad is able to estimate the derivative of the
Expand Down Expand Up @@ -2428,15 +2433,35 @@ def test_powell_output():
assert np.isscalar(res.fun)


@array_api_compatible
class TestRosen:

def test_hess(self):
def test_rosen(self, xp):
# integer input should be promoted to the default floating type
x = xp.asarray([1, 1, 1])
xp_assert_equal(optimize.rosen(x),
xp.asarray(0.))

@skip_xp_backends('jax.numpy',
reasons=["JAX arrays do not support item assignment"])
@pytest.mark.usefixtures("skip_xp_backends")
def test_rosen_der(self, xp):
x = xp.asarray([1, 1, 1, 1])
xp_assert_equal(optimize.rosen_der(x),
xp.zeros_like(x, dtype=xp.asarray(1.).dtype))

@skip_xp_backends('jax.numpy',
reasons=["JAX arrays do not support item assignment"])
@pytest.mark.usefixtures("skip_xp_backends")
def test_hess_prod(self, xp):
one = xp.asarray(1.)
xp_test = array_namespace(one)
# Compare rosen_hess(x) times p with rosen_hess_prod(x,p). See gh-1775.
x = np.array([3, 4, 5])
p = np.array([2, 2, 2])
x = xp.asarray([3, 4, 5])
p = xp.asarray([2, 2, 2])
hp = optimize.rosen_hess_prod(x, p)
dothp = np.dot(optimize.rosen_hess(x), p)
assert_equal(hp, dothp)
p = xp_test.astype(p, one.dtype)
dothp = optimize.rosen_hess(x) @ p
xp_assert_equal(hp, dothp)


def himmelblau(p):
Expand Down

0 comments on commit 10807c0

Please sign in to comment.