From 9edf4f624408d2caef0c5486e4ec9bcf6e4388b1 Mon Sep 17 00:00:00 2001 From: vlad-perevezentsev Date: Thu, 28 Nov 2024 13:13:52 +0100 Subject: [PATCH] Update `dpnp.linalg.solve()` to align NumPy 2.0 (#2198) * Update solve with broadcasting to align numpy 2.0 * Update and add more tests for solve() * Keep only solve() logic for numpy 2.0 compatibility * Update cupy tests for solve() * Align TestSolve with cupy tests * Cover case b.ndim==0 * Add notes for solve() --- dpnp/linalg/dpnp_iface_linalg.py | 46 +++++++++++++++---- dpnp/tests/test_linalg.py | 34 ++++++++++++++ dpnp/tests/test_sycl_queue.py | 30 +++++------- dpnp/tests/test_usm_type.py | 30 ++++++------ .../cupy/linalg_tests/test_solve.py | 21 ++++----- 5 files changed, 109 insertions(+), 52 deletions(-) diff --git a/dpnp/linalg/dpnp_iface_linalg.py b/dpnp/linalg/dpnp_iface_linalg.py index 51b64b4bc37..3a5c1a2abaf 100644 --- a/dpnp/linalg/dpnp_iface_linalg.py +++ b/dpnp/linalg/dpnp_iface_linalg.py @@ -1612,18 +1612,24 @@ def solve(a, b): ---------- a : (..., M, M) {dpnp.ndarray, usm_ndarray} Coefficient matrix. - b : {(…, M,), (…, M, K)} {dpnp.ndarray, usm_ndarray} + b : {(M,), (..., M, K)} {dpnp.ndarray, usm_ndarray} Ordinate or "dependent variable" values. Returns ------- - out : {(…, M,), (…, M, K)} dpnp.ndarray + out : {(..., M,), (..., M, K)} dpnp.ndarray Solution to the system `ax = b`. Returned shape is identical to `b`. See Also -------- :obj:`dpnp.dot` : Returns the dot product of two arrays. + Notes + ----- + The `b` array is only treated as a shape (M,) column vector if it is + exactly 1-dimensional. In all other instances it is treated as a stack + of (M, K) matrices. + Examples -------- >>> import dpnp as dp @@ -1644,14 +1650,38 @@ def solve(a, b): assert_stacked_2d(a) assert_stacked_square(a) - if not ( - a.ndim in [b.ndim, b.ndim + 1] and a.shape[:-1] == b.shape[: a.ndim - 1] - ): - raise dpnp.linalg.LinAlgError( - "a must have (..., M, M) shape and b must have (..., M) " - "or (..., M, K)" + a_shape = a.shape + b_shape = b.shape + b_ndim = b.ndim + + # compatible with numpy>=2.0 + if b_ndim == 0: + raise ValueError("b must have at least one dimension") + if b_ndim == 1: + if a_shape[-1] != b.size: + raise ValueError( + "a must have (..., M, M) shape and b must have (M,) " + "for one-dimensional b" + ) + b = dpnp.broadcast_to(b, a_shape[:-1]) + return dpnp_solve(a, b) + + if a_shape[-1] != b_shape[-2]: + raise ValueError( + "a must have (..., M, M) shape and b must have (..., M, K) shape" ) + # Use dpnp.broadcast_shapes() to align the resulting batch shapes + broadcasted_batch_shape = dpnp.broadcast_shapes(a_shape[:-2], b_shape[:-2]) + + a_broadcasted_shape = broadcasted_batch_shape + a_shape[-2:] + b_broadcasted_shape = broadcasted_batch_shape + b_shape[-2:] + + if a_shape != a_broadcasted_shape: + a = dpnp.broadcast_to(a, a_broadcasted_shape) + if b_shape != b_broadcasted_shape: + b = dpnp.broadcast_to(b, b_broadcasted_shape) + return dpnp_solve(a, b) diff --git a/dpnp/tests/test_linalg.py b/dpnp/tests/test_linalg.py index 76ce7989c8c..a0ce71b8208 100644 --- a/dpnp/tests/test_linalg.py +++ b/dpnp/tests/test_linalg.py @@ -2694,6 +2694,36 @@ def test_solve(self, dtype): assert_allclose(expected, result, rtol=1e-06) + @testing.with_requires("numpy>=2.0") + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + @pytest.mark.parametrize( + "a_shape, b_shape", + [ + ((4, 4), (2, 2, 4, 3)), + ((2, 5, 5), (1, 5, 3)), + ((2, 4, 4), (2, 2, 4, 2)), + ((3, 2, 2), (3, 1, 2, 1)), + ((2, 2, 2, 2, 2), (2,)), + ((2, 2, 2, 2, 2), (2, 3)), + ], + ) + def test_solve_broadcast(self, a_shape, b_shape, dtype): + # Set seed_value=81 to prevent + # random generation of the input singular matrix + a_np = generate_random_numpy_array(a_shape, dtype, seed_value=81) + + # Set seed_value=76 to prevent + # random generation of the input singular matrix + b_np = generate_random_numpy_array(b_shape, dtype, seed_value=76) + + a_dp = inp.array(a_np) + b_dp = inp.array(b_np) + + expected = numpy.linalg.solve(a_np, b_np) + result = inp.linalg.solve(a_dp, b_dp) + + assert_dtype_allclose(result, expected) + @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) def test_solve_nrhs_greater_n(self, dtype): # Test checking the case when nrhs > n for @@ -2800,6 +2830,10 @@ def test_solve_errors(self): inp.linalg.LinAlgError, inp.linalg.solve, a_dp_ndim_1, b_dp ) + # b.ndim == 0 + b_dp_ndim_0 = inp.array(2) + assert_raises(ValueError, inp.linalg.solve, a_dp, b_dp_ndim_0) + class TestSlogdet: @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) diff --git a/dpnp/tests/test_sycl_queue.py b/dpnp/tests/test_sycl_queue.py index 359f99de048..74ec14c1b82 100644 --- a/dpnp/tests/test_sycl_queue.py +++ b/dpnp/tests/test_sycl_queue.py @@ -2392,40 +2392,34 @@ def test_where(device): ids=[device.filter_string for device in valid_devices], ) @pytest.mark.parametrize( - "matrix, vector", + "matrix, rhs", [ ([[1, 2], [3, 5]], numpy.empty((2, 0))), ([[1, 2], [3, 5]], [1, 2]), ( [ - [[1, 1, 1], [0, 2, 5], [2, 5, -1]], - [[3, -1, 1], [1, 2, 3], [2, 3, 1]], - [[1, 4, 1], [1, 2, -2], [4, 1, 2]], + [[1, 1], [0, 2]], + [[3, -1], [1, 2]], + ], + [ + [[6, -4], [9, -6]], + [[15, 1], [15, 1]], ], - [[6, -4, 27], [9, -6, 15], [15, 1, 11]], ), ], ids=[ - "2D_Matrix_Empty_Vector", - "2D_Matrix_1D_Vector", - "3D_Matrix_and_Vectors", + "2D_Matrix_Empty_RHS", + "2D_Matrix_1D_RHS", + "3D_Matrix_and_3D_RHS", ], ) -def test_solve(matrix, vector, device): +def test_solve(matrix, rhs, device): a_np = numpy.array(matrix) - b_np = numpy.array(vector) + b_np = numpy.array(rhs) a_dp = dpnp.array(a_np, device=device) b_dp = dpnp.array(b_np, device=device) - # In numpy 2.0 the broadcast ambiguity has been removed and now - # b is treaded as a single vector if and only if it is 1-dimensional; - # for other cases this signature must be followed - # (..., m, m), (..., m, n) -> (..., m, n) - # https://github.com/numpy/numpy/pull/25914 - if a_dp.ndim > 2 and numpy.lib.NumpyVersion(numpy.__version__) >= "2.0.0": - pytest.skip("SAT-6928") - result = dpnp.linalg.solve(a_dp, b_dp) expected = numpy.linalg.solve(a_np, b_np) assert_dtype_allclose(result, expected) diff --git a/dpnp/tests/test_usm_type.py b/dpnp/tests/test_usm_type.py index d14604be725..8e06639a97c 100644 --- a/dpnp/tests/test_usm_type.py +++ b/dpnp/tests/test_usm_type.py @@ -1285,37 +1285,39 @@ def test_fftshift(self, func, usm_type): "usm_type_matrix", list_of_usm_types, ids=list_of_usm_types ) @pytest.mark.parametrize( - "usm_type_vector", list_of_usm_types, ids=list_of_usm_types + "usm_type_rhs", list_of_usm_types, ids=list_of_usm_types ) @pytest.mark.parametrize( - "matrix, vector", + "matrix, rhs", [ - ([[1, 2], [3, 5]], dp.empty((2, 0))), + ([[1, 2], [3, 5]], numpy.empty((2, 0))), ([[1, 2], [3, 5]], [1, 2]), ( [ - [[1, 1, 1], [0, 2, 5], [2, 5, -1]], - [[3, -1, 1], [1, 2, 3], [2, 3, 1]], - [[1, 4, 1], [1, 2, -2], [4, 1, 2]], + [[1, 1], [0, 2]], + [[3, -1], [1, 2]], + ], + [ + [[6, -4], [9, -6]], + [[15, 1], [15, 1]], ], - [[6, -4, 27], [9, -6, 15], [15, 1, 11]], ), ], ids=[ - "2D_Matrix_Empty_Vector", - "2D_Matrix_1D_Vector", - "3D_Matrix_and_Vectors", + "2D_Matrix_Empty_RHS", + "2D_Matrix_1D_RHS", + "3D_Matrix_and_3D_RHS", ], ) -def test_solve(matrix, vector, usm_type_matrix, usm_type_vector): +def test_solve(matrix, rhs, usm_type_matrix, usm_type_rhs): x = dp.array(matrix, usm_type=usm_type_matrix) - y = dp.array(vector, usm_type=usm_type_vector) + y = dp.array(rhs, usm_type=usm_type_rhs) z = dp.linalg.solve(x, y) assert x.usm_type == usm_type_matrix - assert y.usm_type == usm_type_vector + assert y.usm_type == usm_type_rhs assert z.usm_type == du.get_coerced_usm_type( - [usm_type_matrix, usm_type_vector] + [usm_type_matrix, usm_type_rhs] ) diff --git a/dpnp/tests/third_party/cupy/linalg_tests/test_solve.py b/dpnp/tests/third_party/cupy/linalg_tests/test_solve.py index 697a977a648..5fb6533be33 100644 --- a/dpnp/tests/third_party/cupy/linalg_tests/test_solve.py +++ b/dpnp/tests/third_party/cupy/linalg_tests/test_solve.py @@ -47,6 +47,7 @@ def check_x(self, a_shape, b_shape, xp, dtype): testing.assert_array_equal(b_copy, b) return result + @testing.with_requires("numpy>=2.0") def test_solve(self): self.check_x((4, 4), (4,)) self.check_x((5, 5), (5, 2)) @@ -55,15 +56,9 @@ def test_solve(self): self.check_x((0, 0), (0,)) self.check_x((0, 0), (0, 2)) self.check_x((0, 2, 2), (0, 2, 3)) - # In numpy 2.0 the broadcast ambiguity has been removed and now - # b is treaded as a single vector if and only if it is 1-dimensional; - # for other cases this signature must be followed - # (..., m, m), (..., m, n) -> (..., m, n) - # https://github.com/numpy/numpy/pull/25914 - if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0": - self.check_x((2, 4, 4), (2, 4)) - self.check_x((2, 3, 2, 2), (2, 3, 2)) - self.check_x((0, 2, 2), (0, 2)) + # Allowed since numpy 2 + self.check_x((2, 3, 3), (3,)) + self.check_x((2, 5, 3, 3), (3,)) def check_shape(self, a_shape, b_shape, error_types): for xp, error_type in error_types.items(): @@ -82,6 +77,7 @@ def test_solve_singular_empty(self, xp): # LinAlgError("Singular matrix") is not raised return xp.linalg.solve(a, b) + @testing.with_requires("numpy>=2.0") def test_invalid_shape(self): linalg_errors = { numpy: numpy.linalg.LinAlgError, @@ -96,11 +92,12 @@ def test_invalid_shape(self): self.check_shape((3, 3), (2,), value_errors) self.check_shape((3, 3), (2, 2), value_errors) self.check_shape((3, 3, 4), (3,), linalg_errors) - # Since numpy >= 2.0, this case does not raise an error - if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0": - self.check_shape((2, 3, 3), (3,), value_errors) self.check_shape((3, 3), (0,), value_errors) self.check_shape((0, 3, 4), (3,), linalg_errors) + # Not allowed since numpy 2.0 + self.check_shape((0, 2, 2), (0, 2), value_errors) + self.check_shape((2, 4, 4), (2, 4), value_errors) + self.check_shape((2, 3, 2, 2), (2, 3, 2), value_errors) @testing.parameterize(