From 10807c0b1fa1f5ff57bc7866a88e738e4d58aee8 Mon Sep 17 00:00:00 2001 From: Jake Bowhay <60778417+j-bowhay@users.noreply.github.com> Date: Fri, 1 Nov 2024 18:08:52 -0600 Subject: [PATCH] ENH: optimize: add array api support to `rosen` and friends (#21778) * ENH: optimize: add array api support to rosen and friends --- .github/workflows/array_api.yml | 1 + scipy/_lib/_array_api.py | 10 +++++++ scipy/_lib/tests/test_array_api.py | 17 ++++++++++- scipy/optimize/_optimize.py | 41 ++++++++++++++++++--------- scipy/optimize/tests/test_optimize.py | 37 ++++++++++++++++++++---- 5 files changed, 86 insertions(+), 20 deletions(-) diff --git a/.github/workflows/array_api.yml b/.github/workflows/array_api.yml index 587968e33d1b..2a472d8cba2f 100644 --- a/.github/workflows/array_api.yml +++ b/.github/workflows/array_api.yml @@ -27,6 +27,7 @@ env: -t scipy.integrate.tests.test_tanhsinh -t scipy.integrate.tests.test_cubature -t scipy.optimize.tests.test_chandrupatla + -t scipy.optimize.tests.test_optimize -t scipy.stats -t scipy.ndimage diff --git a/scipy/_lib/_array_api.py b/scipy/_lib/_array_api.py index 57a0c13c8f28..799464cf4d86 100644 --- a/scipy/_lib/_array_api.py +++ b/scipy/_lib/_array_api.py @@ -34,6 +34,7 @@ 'xp_atleast_nd', 'xp_copy', 'xp_copysign', 'xp_device', 'xp_moveaxis_to_end', 'xp_ravel', 'xp_real', 'xp_sign', 'xp_size', 'xp_take_along_axis', 'xp_unsupported_param_msg', 'xp_vector_norm', + 'xp_create_diagonal' ] @@ -645,3 +646,12 @@ def xp_float_to_complex(arr: Array, xp: ModuleType | None = None) -> Array: arr = xp.astype(arr, xp.complex128) return arr + +def xp_create_diagonal(x: Array, /, *, offset: int = 0, + xp: ModuleType | None = None) -> Array: + xp = array_namespace(x) if xp is None else xp + n = x.shape[0] + abs(offset) + diag = xp.zeros(n**2, dtype=x.dtype) + i = offset if offset >= 0 else abs(offset) * n + diag[i:min(n*(n-offset), diag.shape[0]):n+1] = x + return xp.reshape(diag, (n, n)) diff --git a/scipy/_lib/tests/test_array_api.py b/scipy/_lib/tests/test_array_api.py index f1eaca57bdd7..f532d7a7c22e 100644 --- a/scipy/_lib/tests/test_array_api.py +++ b/scipy/_lib/tests/test_array_api.py @@ -3,7 +3,8 @@ from scipy.conftest import array_api_compatible from scipy._lib._array_api import ( - _GLOBAL_CONFIG, array_namespace, _asarray, xp_copy, xp_assert_equal, is_numpy + _GLOBAL_CONFIG, array_namespace, _asarray, xp_copy, xp_assert_equal, is_numpy, + xp_create_diagonal ) from scipy._lib._array_api_no_0d import xp_assert_equal as xp_assert_equal_no_0d import scipy._lib.array_api_compat.numpy as np_compat @@ -185,3 +186,17 @@ def test_check_scalar_no_0d(self, xp): # scalars-vs-0d passes (if values match) also with regular python objects xp_assert_equal_no_0d(0., xp.asarray(0.)) xp_assert_equal_no_0d(42, xp.asarray(42)) + + @skip_xp_backends('jax.numpy', + reasons=["JAX arrays do not support item assignment"]) + @pytest.mark.usefixtures("skip_xp_backends") + @array_api_compatible + @pytest.mark.parametrize('n', range(1, 10)) + @pytest.mark.parametrize('offset', range(1, 10)) + def test_create_diagonal(self, n, offset, xp): + rng = np.random.default_rng(2347823) + one = xp.asarray(1.) + x = xp.asarray(rng.random(n), dtype=one.dtype) + A = xp_create_diagonal(x, offset=offset, xp=xp) + B = xp.asarray(np.diag(x, offset), dtype=one.dtype) + xp_assert_equal(A, B) diff --git a/scipy/optimize/_optimize.py b/scipy/optimize/_optimize.py index 8fefe870efc4..e4d77ec79692 100644 --- a/scipy/optimize/_optimize.py +++ b/scipy/optimize/_optimize.py @@ -29,7 +29,7 @@ import warnings import sys import inspect -from numpy import atleast_1d, eye, argmin, zeros, shape, asarray, sqrt +from numpy import eye, argmin, zeros, shape, asarray, sqrt import numpy as np from scipy.linalg import cholesky, issymmetric, LinAlgError from scipy.sparse.linalg import LinearOperator @@ -41,6 +41,8 @@ from scipy._lib._util import (MapWrapper, check_random_state, _RichResult, _call_callback_maybe_halt) from scipy.optimize._differentiable_functions import ScalarFunction, FD_METHODS +from scipy._lib._array_api import (array_namespace, xp_atleast_nd, + xp_create_diagonal) # standard status messages of optimizers @@ -358,9 +360,12 @@ def rosen(x): >>> ax.plot_surface(X, Y, rosen([X, Y])) >>> plt.show() """ - x = asarray(x) - r = np.sum(100.0 * (x[1:] - x[:-1]**2.0)**2.0 + (1 - x[:-1])**2.0, - axis=0) + xp = array_namespace(x) + x = xp.asarray(x) + if xp.isdtype(x.dtype, 'integral'): + x = xp.astype(x, xp.asarray(1.).dtype) + r = xp.sum(100.0 * (x[1:] - x[:-1]**2.0)**2.0 + (1 - x[:-1])**2.0, + axis=0, dtype=x.dtype) return r @@ -391,11 +396,14 @@ def rosen_der(x): array([ -2. , 10.6, 15.6, 13.4, 6.4, -3. , -12.4, -19.4, 62. ]) """ - x = asarray(x) + xp = array_namespace(x) + x = xp.asarray(x) + if xp.isdtype(x.dtype, 'integral'): + x = xp.astype(x, xp.asarray(1.).dtype) xm = x[1:-1] xm_m1 = x[:-2] xm_p1 = x[2:] - der = np.zeros_like(x) + der = xp.zeros_like(x) der[1:-1] = (200 * (xm - xm_m1**2) - 400 * (xm_p1 - xm**2) * xm - 2 * (1 - xm)) der[0] = -400 * x[0] * (x[1] - x[0]**2) - 2 * (1 - x[0]) @@ -433,14 +441,17 @@ def rosen_hess(x): [ 0., 0., -80., 200.]]) """ - x = atleast_1d(x) - H = np.diag(-400 * x[:-1], 1) - np.diag(400 * x[:-1], -1) - diagonal = np.zeros(len(x), dtype=x.dtype) + xp = array_namespace(x) + x = xp_atleast_nd(x, ndim=1, xp=xp) + if xp.isdtype(x.dtype, 'integral'): + x = xp.astype(x, xp.asarray(1.).dtype) + H = (xp_create_diagonal(-400 * x[:-1], offset=1, xp=xp) + - xp_create_diagonal(400 * x[:-1], offset=-1, xp=xp)) + diagonal = xp.zeros(x.shape[0], dtype=x.dtype) diagonal[0] = 1200 * x[0]**2 - 400 * x[1] + 2 diagonal[-1] = 200 diagonal[1:-1] = 202 + 1200 * x[1:-1]**2 - 400 * x[2:] - H = H + np.diag(diagonal) - return H + return H + xp_create_diagonal(diagonal, xp=xp) def rosen_hess_prod(x, p): @@ -474,8 +485,12 @@ def rosen_hess_prod(x, p): array([ -0., 27., -10., -95., -192., -265., -278., -195., -180.]) """ - x = atleast_1d(x) - Hp = np.zeros(len(x), dtype=x.dtype) + xp = array_namespace(x, p) + x = xp_atleast_nd(x, ndim=1, xp=xp) + if xp.isdtype(x.dtype, 'integral'): + x = xp.astype(x, xp.asarray(1.).dtype) + p = xp.asarray(p, dtype=x.dtype) + Hp = xp.zeros(x.shape[0], dtype=x.dtype) Hp[0] = (1200 * x[0]**2 - 400 * x[1] + 2) * p[0] - 400 * x[0] * p[1] Hp[1:-1] = (-400 * x[:-2] * p[:-2] + (202 + 1200 * x[1:-1]**2 - 400 * x[2:]) * p[1:-1] - diff --git a/scipy/optimize/tests/test_optimize.py b/scipy/optimize/tests/test_optimize.py index f3c5b85cf93d..7eaaf7ff8bc9 100644 --- a/scipy/optimize/tests/test_optimize.py +++ b/scipy/optimize/tests/test_optimize.py @@ -32,6 +32,11 @@ from scipy.sparse import (coo_matrix, csc_matrix, csr_matrix, coo_array, csr_array, csc_array) +from scipy.conftest import array_api_compatible +from scipy._lib._array_api_no_0d import xp_assert_equal, array_namespace + +skip_xp_backends = pytest.mark.skip_xp_backends + def test_check_grad(): # Verify if check_grad is able to estimate the derivative of the @@ -2428,15 +2433,35 @@ def test_powell_output(): assert np.isscalar(res.fun) +@array_api_compatible class TestRosen: - - def test_hess(self): + def test_rosen(self, xp): + # integer input should be promoted to the default floating type + x = xp.asarray([1, 1, 1]) + xp_assert_equal(optimize.rosen(x), + xp.asarray(0.)) + + @skip_xp_backends('jax.numpy', + reasons=["JAX arrays do not support item assignment"]) + @pytest.mark.usefixtures("skip_xp_backends") + def test_rosen_der(self, xp): + x = xp.asarray([1, 1, 1, 1]) + xp_assert_equal(optimize.rosen_der(x), + xp.zeros_like(x, dtype=xp.asarray(1.).dtype)) + + @skip_xp_backends('jax.numpy', + reasons=["JAX arrays do not support item assignment"]) + @pytest.mark.usefixtures("skip_xp_backends") + def test_hess_prod(self, xp): + one = xp.asarray(1.) + xp_test = array_namespace(one) # Compare rosen_hess(x) times p with rosen_hess_prod(x,p). See gh-1775. - x = np.array([3, 4, 5]) - p = np.array([2, 2, 2]) + x = xp.asarray([3, 4, 5]) + p = xp.asarray([2, 2, 2]) hp = optimize.rosen_hess_prod(x, p) - dothp = np.dot(optimize.rosen_hess(x), p) - assert_equal(hp, dothp) + p = xp_test.astype(p, one.dtype) + dothp = optimize.rosen_hess(x) @ p + xp_assert_equal(hp, dothp) def himmelblau(p):