From b83a3781a1fed961bea53c2edfa676d1e4b96968 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 22 Feb 2024 15:28:09 -0700 Subject: [PATCH 1/3] Workaround np.linalg.solve ambiguity NumPy's solve() does not handle the ambiguity around x2 being 1-D vector vs. an n-D stack of matrices in the way that the standard specifies. Namely, x2 should be treated as a 1-D vector iff it is 1-dimensional, and a stack of matrices in all other cases. This workaround is borrowed from array-api-strict. See https://github.com/numpy/numpy/issues/15349 and https://github.com/data-apis/array-api/issues/285. Note that this workaround only works for NumPy. CuPy currently does not support stacked vectors for solve() (see https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43), and the workaround in cupy.array_api.linalg does not seem to actually function. --- array_api_compat/numpy/linalg.py | 48 ++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index fbe37cd7..0861c2f0 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -33,6 +33,53 @@ vector_norm, ) +# Note: unlike np.linalg.solve, the array API solve() only accepts x2 as a +# vector when it is exactly 1-dimensional. All other cases treat x2 as a stack +# of matrices. The np.linalg.solve behavior of allowing stacks of both +# matrices and vectors is ambiguous c.f. +# https://github.com/numpy/numpy/issues/15349 and +# https://github.com/data-apis/array-api/issues/285. + +# To workaround this, the below is the code from np.linalg.solve except +# only calling solve1 in the exactly 1D case. + +# This code is here instead of in common because it is numpy specific. Also +# note that CuPy's solve() does not currently support broadcasting (see +# https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43). +def solve(x1: _np.ndarray, x2: _np.ndarray, /) -> _np.ndarray: + try: + from numpy.linalg._linalg import ( + _makearray, _assert_stacked_2d, _assert_stacked_square, + _commonType, isComplexType, _raise_linalgerror_singular + ) + except ImportError: + from numpy.linalg.linalg import ( + _makearray, _assert_stacked_2d, _assert_stacked_square, + _commonType, isComplexType, _raise_linalgerror_singular + ) + from numpy.linalg import _umath_linalg + + x1, _ = _makearray(x1) + _assert_stacked_2d(x1) + _assert_stacked_square(x1) + x2, wrap = _makearray(x2) + t, result_t = _commonType(x1, x2) + + # This part is different from np.linalg.solve + if x2.ndim == 1: + gufunc = _umath_linalg.solve1 + else: + gufunc = _umath_linalg.solve + + # This does nothing currently but is left in because it will be relevant + # when complex dtype support is added to the spec in 2022. + signature = 'DD->D' if isComplexType(t) else 'dd->d' + with _np.errstate(call=_raise_linalgerror_singular, invalid='call', + over='ignore', divide='ignore', under='ignore'): + r = gufunc(x1, x2, signature=signature) + + return wrap(r.astype(result_t, copy=False)) + __all__ = [] __all__ += _numpy_linalg_all @@ -54,6 +101,7 @@ "pinv", "qr", "slogdet", + "solve", "svd", "svdvals", "tensordot", From 145a2021120759b8952c404889b6882176888dc3 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 22 Feb 2024 16:16:49 -0700 Subject: [PATCH 2/3] Fix missing newline at end of file --- array_api_compat/common/_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index 0708b76a..28a7d644 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -149,4 +149,4 @@ def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarra dtype = xp.float64 elif x.dtype == xp.complex64: dtype = xp.complex128 - return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)) \ No newline at end of file + return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs)) From 730a21459420a29e63b3307a6d0c0c7ab3082344 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 27 Feb 2024 15:51:17 -0700 Subject: [PATCH 3/3] Remove numpy linalg xfails --- numpy-xfails.txt | 6 ------ 1 file changed, 6 deletions(-) diff --git a/numpy-xfails.txt b/numpy-xfails.txt index c541602b..d0be245b 100644 --- a/numpy-xfails.txt +++ b/numpy-xfails.txt @@ -10,12 +10,6 @@ array_api_tests/test_has_names.py::test_has_names[array_method-to_device] array_api_tests/test_has_names.py::test_has_names[array_attribute-device] array_api_tests/test_has_names.py::test_has_names[array_attribute-mT] -# linalg tests require https://github.com/data-apis/array-api-tests/pull/101 -# cleanups. Also some tests are using .mT -array_api_tests/test_linalg.py::test_eigvalsh -array_api_tests/test_linalg.py::test_solve -array_api_tests/test_linalg.py::test_trace - # Array methods and attributes not already on np.ndarray cannot be wrapped array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__] array_api_tests/test_signatures.py::test_array_method_signature[to_device]