Skip to content

Commit

Permalink
Add implementation of dpnp.linalg.svdvals() (#2094)
Browse files Browse the repository at this point in the history
* Add implementation of dpnp.linalg.svdvals

* Add TestSvdvals to test_linalg.py

* Add dpnp.linalg.svdvals to docs

* Update doc/reference/linalg.rst

---------

Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com>
  • Loading branch information
vlad-perevezentsev and antonwolfy authored Oct 9, 2024
1 parent c4677e9 commit 7effabc
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 0 deletions.
2 changes: 2 additions & 0 deletions doc/reference/linalg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ Decompositions
:nosignatures:

dpnp.linalg.cholesky
dpnp.linalg.outer
dpnp.linalg.qr
dpnp.linalg.svd
dpnp.linalg.svdvals

Matrix eigenvalues
------------------
Expand Down
57 changes: 57 additions & 0 deletions dpnp/linalg/dpnp_iface_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
"qr",
"solve",
"svd",
"svdvals",
"slogdet",
"tensorinv",
"tensorsolve",
Expand Down Expand Up @@ -1315,6 +1316,62 @@ def svd(a, full_matrices=True, compute_uv=True, hermitian=False):
return dpnp_svd(a, full_matrices, compute_uv, hermitian)


def svdvals(x, /):
"""
Returns the singular values of a matrix (or a stack of matrices) `x`.
When `x` is a stack of matrices, the function will compute
the singular values for each matrix in the stack.
Calling ``dpnp.linalg.svdvals(x)`` to get singular values is the same as
``dpnp.linalg.svd(x, compute_uv=False, hermitian=False)``.
For full documentation refer to :obj:`numpy.linalg.svdvals`.
Parameters
----------
x : (..., M, N) {dpnp.ndarray, usm_ndarray}
Input array with ``x.ndim >= 2`` and whose last two dimensions
form matrices on which to perform singular value decomposition.
Returns
-------
out : (..., K) dpnp.ndarray
Vector(s) of singular values of length K, where K = min(M, N).
See Also
--------
:obj:`dpnp.linalg.svd` : Compute the singular value decomposition.
Examples
--------
>>> import dpnp as np
>>> a = np.array([[3, 0], [0, 4]])
>>> np.linalg.svdvals(a)
array([4., 3.])
This is equivalent to calling:
>>> np.linalg.svd(a, compute_uv=False, hermitian=False)
array([4., 3.])
Stack of matrices:
>>> b = np.array([[[6, 0], [0, 8]], [[9, 0], [0, 12]]])
>>> np.linalg.svdvals(b)
array([[ 8., 6.],
[12., 9.]])
"""

dpnp.check_supported_arrays_type(x)
assert_stacked_2d(x)

return dpnp_svd(x, full_matrices=True, compute_uv=False, hermitian=False)


def slogdet(a):
"""
Compute the sign and (natural) logarithm of the determinant of an array.
Expand Down
44 changes: 44 additions & 0 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2993,6 +2993,50 @@ def test_svd_errors(self):
assert_raises(inp.linalg.LinAlgError, inp.linalg.svd, a_dp_ndim_1)


# numpy.linalg.svdvals() is available since numpy >= 2.0
@testing.with_requires("numpy>=2.0")
class TestSvdvals:
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
@pytest.mark.parametrize(
"shape",
[(3, 5), (4, 2), (2, 3, 3), (3, 5, 2)],
ids=["(3,5)", "(4,2)", "(2,3,3)", "(3,5,2)"],
)
def test_svdvals(self, dtype, shape):
a = numpy.arange(numpy.prod(shape), dtype=dtype).reshape(shape)
dp_a = inp.array(a)

expected = numpy.linalg.svdvals(a)
result = inp.linalg.svdvals(dp_a)

assert_dtype_allclose(result, expected)

@pytest.mark.parametrize(
"shape",
[(0, 0), (1, 0, 0), (0, 2, 2)],
ids=["(0,0)", "(1,0,0)", "(0,2,2)"],
)
def test_svdvals_empty(self, shape):
a = generate_random_numpy_array(shape, inp.default_float_type())
dp_a = inp.array(a)

expected = numpy.linalg.svdvals(a)
result = inp.linalg.svdvals(dp_a)

assert_dtype_allclose(result, expected)

def test_svdvals_errors(self):
a_dp = inp.array([[1, 2], [3, 4]], dtype="float32")

# unsupported type
a_np = inp.asnumpy(a_dp)
assert_raises(TypeError, inp.linalg.svdvals, a_np)

# a.ndim < 2
a_dp_ndim_1 = a_dp.flatten()
assert_raises(inp.linalg.LinAlgError, inp.linalg.svdvals, a_dp_ndim_1)


class TestPinv:
def get_tol(self, dtype):
tol = 1e-06
Expand Down

0 comments on commit 7effabc

Please sign in to comment.