Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update dpnp.linalg.solve() to align NumPy 2.0 #2198

Merged
merged 14 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 38 additions & 8 deletions dpnp/linalg/dpnp_iface_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1612,18 +1612,24 @@ def solve(a, b):
----------
a : (..., M, M) {dpnp.ndarray, usm_ndarray}
Coefficient matrix.
b : {(…, M,), (, M, K)} {dpnp.ndarray, usm_ndarray}
b : {(M,), (..., M, K)} {dpnp.ndarray, usm_ndarray}
Ordinate or "dependent variable" values.
Returns
-------
out : {(, M,), (, M, K)} dpnp.ndarray
out : {(..., M,), (..., M, K)} dpnp.ndarray
Solution to the system `ax = b`. Returned shape is identical to `b`.
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
See Also
--------
:obj:`dpnp.dot` : Returns the dot product of two arrays.
Notes
-----
The `b` array is only treated as a shape (M,) column vector if it is
exactly 1-dimensional. In all other instances it is treated as a stack
of (M, K) matrices.
Examples
--------
>>> import dpnp as dp
Expand All @@ -1644,14 +1650,38 @@ def solve(a, b):
assert_stacked_2d(a)
assert_stacked_square(a)

if not (
a.ndim in [b.ndim, b.ndim + 1] and a.shape[:-1] == b.shape[: a.ndim - 1]
):
raise dpnp.linalg.LinAlgError(
"a must have (..., M, M) shape and b must have (..., M) "
"or (..., M, K)"
a_shape = a.shape
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
b_shape = b.shape
b_ndim = b.ndim

# compatible with numpy>=2.0
if b_ndim == 0:
raise ValueError("b must have at least one dimension")
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
if b_ndim == 1:
if a_shape[-1] != b.size:
raise ValueError(
"a must have (..., M, M) shape and b must have (M,) "
"for one-dimensional b"
)
b = dpnp.broadcast_to(b, a_shape[:-1])
return dpnp_solve(a, b)

if a_shape[-1] != b_shape[-2]:
raise ValueError(
"a must have (..., M, M) shape and b must have (..., M, K) shape"
)

# Use dpnp.broadcast_shapes() to align the resulting batch shapes
broadcasted_batch_shape = dpnp.broadcast_shapes(a_shape[:-2], b_shape[:-2])

a_broadcasted_shape = broadcasted_batch_shape + a_shape[-2:]
b_broadcasted_shape = broadcasted_batch_shape + b_shape[-2:]

if a_shape != a_broadcasted_shape:
a = dpnp.broadcast_to(a, a_broadcasted_shape)
if b_shape != b_broadcasted_shape:
b = dpnp.broadcast_to(b, b_broadcasted_shape)

return dpnp_solve(a, b)


Expand Down
34 changes: 34 additions & 0 deletions dpnp/tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2694,6 +2694,36 @@ def test_solve(self, dtype):

assert_allclose(expected, result, rtol=1e-06)

@testing.with_requires("numpy>=2.0")
@pytest.mark.parametrize("dtype", get_float_complex_dtypes())
@pytest.mark.parametrize(
"a_shape, b_shape",
[
((4, 4), (2, 2, 4, 3)),
((2, 5, 5), (1, 5, 3)),
((2, 4, 4), (2, 2, 4, 2)),
((3, 2, 2), (3, 1, 2, 1)),
((2, 2, 2, 2, 2), (2,)),
((2, 2, 2, 2, 2), (2, 3)),
],
)
def test_solve_broadcast(self, a_shape, b_shape, dtype):
# Set seed_value=81 to prevent
# random generation of the input singular matrix
a_np = generate_random_numpy_array(a_shape, dtype, seed_value=81)

# Set seed_value=76 to prevent
# random generation of the input singular matrix
b_np = generate_random_numpy_array(b_shape, dtype, seed_value=76)

a_dp = inp.array(a_np)
b_dp = inp.array(b_np)

expected = numpy.linalg.solve(a_np, b_np)
result = inp.linalg.solve(a_dp, b_dp)

assert_dtype_allclose(result, expected)

@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
def test_solve_nrhs_greater_n(self, dtype):
# Test checking the case when nrhs > n for
Expand Down Expand Up @@ -2800,6 +2830,10 @@ def test_solve_errors(self):
inp.linalg.LinAlgError, inp.linalg.solve, a_dp_ndim_1, b_dp
)

# b.ndim == 0
b_dp_ndim_0 = inp.array(2)
assert_raises(ValueError, inp.linalg.solve, a_dp, b_dp_ndim_0)


class TestSlogdet:
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
Expand Down
30 changes: 12 additions & 18 deletions dpnp/tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -2392,40 +2392,34 @@ def test_where(device):
ids=[device.filter_string for device in valid_devices],
)
@pytest.mark.parametrize(
"matrix, vector",
"matrix, rhs",
[
([[1, 2], [3, 5]], numpy.empty((2, 0))),
([[1, 2], [3, 5]], [1, 2]),
(
[
[[1, 1, 1], [0, 2, 5], [2, 5, -1]],
[[3, -1, 1], [1, 2, 3], [2, 3, 1]],
[[1, 4, 1], [1, 2, -2], [4, 1, 2]],
[[1, 1], [0, 2]],
[[3, -1], [1, 2]],
],
[
[[6, -4], [9, -6]],
[[15, 1], [15, 1]],
],
[[6, -4, 27], [9, -6, 15], [15, 1, 11]],
),
],
ids=[
"2D_Matrix_Empty_Vector",
"2D_Matrix_1D_Vector",
"3D_Matrix_and_Vectors",
"2D_Matrix_Empty_RHS",
"2D_Matrix_1D_RHS",
"3D_Matrix_and_3D_RHS",
],
)
def test_solve(matrix, vector, device):
def test_solve(matrix, rhs, device):
a_np = numpy.array(matrix)
b_np = numpy.array(vector)
b_np = numpy.array(rhs)

a_dp = dpnp.array(a_np, device=device)
b_dp = dpnp.array(b_np, device=device)

# In numpy 2.0 the broadcast ambiguity has been removed and now
# b is treaded as a single vector if and only if it is 1-dimensional;
# for other cases this signature must be followed
# (..., m, m), (..., m, n) -> (..., m, n)
# https://github.com/numpy/numpy/pull/25914
if a_dp.ndim > 2 and numpy.lib.NumpyVersion(numpy.__version__) >= "2.0.0":
pytest.skip("SAT-6928")

result = dpnp.linalg.solve(a_dp, b_dp)
expected = numpy.linalg.solve(a_np, b_np)
assert_dtype_allclose(result, expected)
Expand Down
30 changes: 16 additions & 14 deletions dpnp/tests/test_usm_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -1285,37 +1285,39 @@ def test_fftshift(self, func, usm_type):
"usm_type_matrix", list_of_usm_types, ids=list_of_usm_types
)
@pytest.mark.parametrize(
"usm_type_vector", list_of_usm_types, ids=list_of_usm_types
"usm_type_rhs", list_of_usm_types, ids=list_of_usm_types
)
@pytest.mark.parametrize(
"matrix, vector",
"matrix, rhs",
[
([[1, 2], [3, 5]], dp.empty((2, 0))),
([[1, 2], [3, 5]], numpy.empty((2, 0))),
([[1, 2], [3, 5]], [1, 2]),
(
[
[[1, 1, 1], [0, 2, 5], [2, 5, -1]],
[[3, -1, 1], [1, 2, 3], [2, 3, 1]],
[[1, 4, 1], [1, 2, -2], [4, 1, 2]],
[[1, 1], [0, 2]],
[[3, -1], [1, 2]],
],
[
[[6, -4], [9, -6]],
[[15, 1], [15, 1]],
],
[[6, -4, 27], [9, -6, 15], [15, 1, 11]],
),
],
ids=[
"2D_Matrix_Empty_Vector",
"2D_Matrix_1D_Vector",
"3D_Matrix_and_Vectors",
"2D_Matrix_Empty_RHS",
"2D_Matrix_1D_RHS",
"3D_Matrix_and_3D_RHS",
],
)
def test_solve(matrix, vector, usm_type_matrix, usm_type_vector):
def test_solve(matrix, rhs, usm_type_matrix, usm_type_rhs):
x = dp.array(matrix, usm_type=usm_type_matrix)
y = dp.array(vector, usm_type=usm_type_vector)
y = dp.array(rhs, usm_type=usm_type_rhs)
z = dp.linalg.solve(x, y)

assert x.usm_type == usm_type_matrix
assert y.usm_type == usm_type_vector
assert y.usm_type == usm_type_rhs
assert z.usm_type == du.get_coerced_usm_type(
[usm_type_matrix, usm_type_vector]
[usm_type_matrix, usm_type_rhs]
)


Expand Down
21 changes: 9 additions & 12 deletions dpnp/tests/third_party/cupy/linalg_tests/test_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def check_x(self, a_shape, b_shape, xp, dtype):
testing.assert_array_equal(b_copy, b)
return result

@testing.with_requires("numpy>=2.0")
def test_solve(self):
self.check_x((4, 4), (4,))
self.check_x((5, 5), (5, 2))
Expand All @@ -55,15 +56,9 @@ def test_solve(self):
self.check_x((0, 0), (0,))
self.check_x((0, 0), (0, 2))
self.check_x((0, 2, 2), (0, 2, 3))
# In numpy 2.0 the broadcast ambiguity has been removed and now
# b is treaded as a single vector if and only if it is 1-dimensional;
# for other cases this signature must be followed
# (..., m, m), (..., m, n) -> (..., m, n)
# https://github.com/numpy/numpy/pull/25914
if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0":
self.check_x((2, 4, 4), (2, 4))
self.check_x((2, 3, 2, 2), (2, 3, 2))
self.check_x((0, 2, 2), (0, 2))
# Allowed since numpy 2
self.check_x((2, 3, 3), (3,))
self.check_x((2, 5, 3, 3), (3,))

def check_shape(self, a_shape, b_shape, error_types):
for xp, error_type in error_types.items():
Expand All @@ -82,6 +77,7 @@ def test_solve_singular_empty(self, xp):
# LinAlgError("Singular matrix") is not raised
return xp.linalg.solve(a, b)

@testing.with_requires("numpy>=2.0")
def test_invalid_shape(self):
linalg_errors = {
numpy: numpy.linalg.LinAlgError,
Expand All @@ -96,11 +92,12 @@ def test_invalid_shape(self):
self.check_shape((3, 3), (2,), value_errors)
self.check_shape((3, 3), (2, 2), value_errors)
self.check_shape((3, 3, 4), (3,), linalg_errors)
# Since numpy >= 2.0, this case does not raise an error
if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0":
self.check_shape((2, 3, 3), (3,), value_errors)
self.check_shape((3, 3), (0,), value_errors)
self.check_shape((0, 3, 4), (3,), linalg_errors)
# Not allowed since numpy 2.0
self.check_shape((0, 2, 2), (0, 2), value_errors)
self.check_shape((2, 4, 4), (2, 4), value_errors)
self.check_shape((2, 3, 2, 2), (2, 3, 2), value_errors)


@testing.parameterize(
Expand Down
Loading