Skip to content

Commit

Permalink
Update dpnp.linalg.solve() to align NumPy 2.0 (#2198)
Browse files Browse the repository at this point in the history
* Update solve with broadcasting to align numpy 2.0

* Update and add more tests for solve()

* Keep only solve() logic for numpy 2.0 compatibility

* Update cupy tests for solve()

* Align TestSolve with cupy tests

* Cover case b.ndim==0

* Add notes for solve()
  • Loading branch information
vlad-perevezentsev authored and vtavana committed Dec 2, 2024
1 parent 319f564 commit 9edf4f6
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 52 deletions.
46 changes: 38 additions & 8 deletions dpnp/linalg/dpnp_iface_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1612,18 +1612,24 @@ def solve(a, b):
----------
a : (..., M, M) {dpnp.ndarray, usm_ndarray}
Coefficient matrix.
b : {(…, M,), (, M, K)} {dpnp.ndarray, usm_ndarray}
b : {(M,), (..., M, K)} {dpnp.ndarray, usm_ndarray}
Ordinate or "dependent variable" values.
Returns
-------
out : {(, M,), (, M, K)} dpnp.ndarray
out : {(..., M,), (..., M, K)} dpnp.ndarray
Solution to the system `ax = b`. Returned shape is identical to `b`.
See Also
--------
:obj:`dpnp.dot` : Returns the dot product of two arrays.
Notes
-----
The `b` array is only treated as a shape (M,) column vector if it is
exactly 1-dimensional. In all other instances it is treated as a stack
of (M, K) matrices.
Examples
--------
>>> import dpnp as dp
Expand All @@ -1644,14 +1650,38 @@ def solve(a, b):
assert_stacked_2d(a)
assert_stacked_square(a)

if not (
a.ndim in [b.ndim, b.ndim + 1] and a.shape[:-1] == b.shape[: a.ndim - 1]
):
raise dpnp.linalg.LinAlgError(
"a must have (..., M, M) shape and b must have (..., M) "
"or (..., M, K)"
a_shape = a.shape
b_shape = b.shape
b_ndim = b.ndim

# compatible with numpy>=2.0
if b_ndim == 0:
raise ValueError("b must have at least one dimension")
if b_ndim == 1:
if a_shape[-1] != b.size:
raise ValueError(
"a must have (..., M, M) shape and b must have (M,) "
"for one-dimensional b"
)
b = dpnp.broadcast_to(b, a_shape[:-1])
return dpnp_solve(a, b)

if a_shape[-1] != b_shape[-2]:
raise ValueError(
"a must have (..., M, M) shape and b must have (..., M, K) shape"
)

# Use dpnp.broadcast_shapes() to align the resulting batch shapes
broadcasted_batch_shape = dpnp.broadcast_shapes(a_shape[:-2], b_shape[:-2])

a_broadcasted_shape = broadcasted_batch_shape + a_shape[-2:]
b_broadcasted_shape = broadcasted_batch_shape + b_shape[-2:]

if a_shape != a_broadcasted_shape:
a = dpnp.broadcast_to(a, a_broadcasted_shape)
if b_shape != b_broadcasted_shape:
b = dpnp.broadcast_to(b, b_broadcasted_shape)

return dpnp_solve(a, b)


Expand Down
34 changes: 34 additions & 0 deletions dpnp/tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2694,6 +2694,36 @@ def test_solve(self, dtype):

assert_allclose(expected, result, rtol=1e-06)

@testing.with_requires("numpy>=2.0")
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
@pytest.mark.parametrize(
"a_shape, b_shape",
[
((4, 4), (2, 2, 4, 3)),
((2, 5, 5), (1, 5, 3)),
((2, 4, 4), (2, 2, 4, 2)),
((3, 2, 2), (3, 1, 2, 1)),
((2, 2, 2, 2, 2), (2,)),
((2, 2, 2, 2, 2), (2, 3)),
],
)
def test_solve_broadcast(self, a_shape, b_shape, dtype):
# Set seed_value=81 to prevent
# random generation of the input singular matrix
a_np = generate_random_numpy_array(a_shape, dtype, seed_value=81)

# Set seed_value=76 to prevent
# random generation of the input singular matrix
b_np = generate_random_numpy_array(b_shape, dtype, seed_value=76)

a_dp = inp.array(a_np)
b_dp = inp.array(b_np)

expected = numpy.linalg.solve(a_np, b_np)
result = inp.linalg.solve(a_dp, b_dp)

assert_dtype_allclose(result, expected)

@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
def test_solve_nrhs_greater_n(self, dtype):
# Test checking the case when nrhs > n for
Expand Down Expand Up @@ -2800,6 +2830,10 @@ def test_solve_errors(self):
inp.linalg.LinAlgError, inp.linalg.solve, a_dp_ndim_1, b_dp
)

# b.ndim == 0
b_dp_ndim_0 = inp.array(2)
assert_raises(ValueError, inp.linalg.solve, a_dp, b_dp_ndim_0)


class TestSlogdet:
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
Expand Down
30 changes: 12 additions & 18 deletions dpnp/tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -2392,40 +2392,34 @@ def test_where(device):
ids=[device.filter_string for device in valid_devices],
)
@pytest.mark.parametrize(
"matrix, vector",
"matrix, rhs",
[
([[1, 2], [3, 5]], numpy.empty((2, 0))),
([[1, 2], [3, 5]], [1, 2]),
(
[
[[1, 1, 1], [0, 2, 5], [2, 5, -1]],
[[3, -1, 1], [1, 2, 3], [2, 3, 1]],
[[1, 4, 1], [1, 2, -2], [4, 1, 2]],
[[1, 1], [0, 2]],
[[3, -1], [1, 2]],
],
[
[[6, -4], [9, -6]],
[[15, 1], [15, 1]],
],
[[6, -4, 27], [9, -6, 15], [15, 1, 11]],
),
],
ids=[
"2D_Matrix_Empty_Vector",
"2D_Matrix_1D_Vector",
"3D_Matrix_and_Vectors",
"2D_Matrix_Empty_RHS",
"2D_Matrix_1D_RHS",
"3D_Matrix_and_3D_RHS",
],
)
def test_solve(matrix, vector, device):
def test_solve(matrix, rhs, device):
a_np = numpy.array(matrix)
b_np = numpy.array(vector)
b_np = numpy.array(rhs)

a_dp = dpnp.array(a_np, device=device)
b_dp = dpnp.array(b_np, device=device)

# In numpy 2.0 the broadcast ambiguity has been removed and now
# b is treaded as a single vector if and only if it is 1-dimensional;
# for other cases this signature must be followed
# (..., m, m), (..., m, n) -> (..., m, n)
# https://github.com/numpy/numpy/pull/25914
if a_dp.ndim > 2 and numpy.lib.NumpyVersion(numpy.__version__) >= "2.0.0":
pytest.skip("SAT-6928")

result = dpnp.linalg.solve(a_dp, b_dp)
expected = numpy.linalg.solve(a_np, b_np)
assert_dtype_allclose(result, expected)
Expand Down
30 changes: 16 additions & 14 deletions dpnp/tests/test_usm_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -1285,37 +1285,39 @@ def test_fftshift(self, func, usm_type):
"usm_type_matrix", list_of_usm_types, ids=list_of_usm_types
)
@pytest.mark.parametrize(
"usm_type_vector", list_of_usm_types, ids=list_of_usm_types
"usm_type_rhs", list_of_usm_types, ids=list_of_usm_types
)
@pytest.mark.parametrize(
"matrix, vector",
"matrix, rhs",
[
([[1, 2], [3, 5]], dp.empty((2, 0))),
([[1, 2], [3, 5]], numpy.empty((2, 0))),
([[1, 2], [3, 5]], [1, 2]),
(
[
[[1, 1, 1], [0, 2, 5], [2, 5, -1]],
[[3, -1, 1], [1, 2, 3], [2, 3, 1]],
[[1, 4, 1], [1, 2, -2], [4, 1, 2]],
[[1, 1], [0, 2]],
[[3, -1], [1, 2]],
],
[
[[6, -4], [9, -6]],
[[15, 1], [15, 1]],
],
[[6, -4, 27], [9, -6, 15], [15, 1, 11]],
),
],
ids=[
"2D_Matrix_Empty_Vector",
"2D_Matrix_1D_Vector",
"3D_Matrix_and_Vectors",
"2D_Matrix_Empty_RHS",
"2D_Matrix_1D_RHS",
"3D_Matrix_and_3D_RHS",
],
)
def test_solve(matrix, vector, usm_type_matrix, usm_type_vector):
def test_solve(matrix, rhs, usm_type_matrix, usm_type_rhs):
x = dp.array(matrix, usm_type=usm_type_matrix)
y = dp.array(vector, usm_type=usm_type_vector)
y = dp.array(rhs, usm_type=usm_type_rhs)
z = dp.linalg.solve(x, y)

assert x.usm_type == usm_type_matrix
assert y.usm_type == usm_type_vector
assert y.usm_type == usm_type_rhs
assert z.usm_type == du.get_coerced_usm_type(
[usm_type_matrix, usm_type_vector]
[usm_type_matrix, usm_type_rhs]
)


Expand Down
21 changes: 9 additions & 12 deletions dpnp/tests/third_party/cupy/linalg_tests/test_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def check_x(self, a_shape, b_shape, xp, dtype):
testing.assert_array_equal(b_copy, b)
return result

@testing.with_requires("numpy>=2.0")
def test_solve(self):
self.check_x((4, 4), (4,))
self.check_x((5, 5), (5, 2))
Expand All @@ -55,15 +56,9 @@ def test_solve(self):
self.check_x((0, 0), (0,))
self.check_x((0, 0), (0, 2))
self.check_x((0, 2, 2), (0, 2, 3))
# In numpy 2.0 the broadcast ambiguity has been removed and now
# b is treaded as a single vector if and only if it is 1-dimensional;
# for other cases this signature must be followed
# (..., m, m), (..., m, n) -> (..., m, n)
# https://github.com/numpy/numpy/pull/25914
if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0":
self.check_x((2, 4, 4), (2, 4))
self.check_x((2, 3, 2, 2), (2, 3, 2))
self.check_x((0, 2, 2), (0, 2))
# Allowed since numpy 2
self.check_x((2, 3, 3), (3,))
self.check_x((2, 5, 3, 3), (3,))

def check_shape(self, a_shape, b_shape, error_types):
for xp, error_type in error_types.items():
Expand All @@ -82,6 +77,7 @@ def test_solve_singular_empty(self, xp):
# LinAlgError("Singular matrix") is not raised
return xp.linalg.solve(a, b)

@testing.with_requires("numpy>=2.0")
def test_invalid_shape(self):
linalg_errors = {
numpy: numpy.linalg.LinAlgError,
Expand All @@ -96,11 +92,12 @@ def test_invalid_shape(self):
self.check_shape((3, 3), (2,), value_errors)
self.check_shape((3, 3), (2, 2), value_errors)
self.check_shape((3, 3, 4), (3,), linalg_errors)
# Since numpy >= 2.0, this case does not raise an error
if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0":
self.check_shape((2, 3, 3), (3,), value_errors)
self.check_shape((3, 3), (0,), value_errors)
self.check_shape((0, 3, 4), (3,), linalg_errors)
# Not allowed since numpy 2.0
self.check_shape((0, 2, 2), (0, 2), value_errors)
self.check_shape((2, 4, 4), (2, 4), value_errors)
self.check_shape((2, 3, 2, 2), (2, 3, 2), value_errors)


@testing.parameterize(
Expand Down

0 comments on commit 9edf4f6

Please sign in to comment.