Update dpnp.linalg.svd() to run on CUDA #2212

Merged · 5 commits · Dec 6, 2024
115 changes: 78 additions & 37 deletions dpnp/linalg/dpnp_utils_linalg.py
@@ -475,6 +475,7 @@ def _batched_qr(a, mode="reduced"):
     )
 
 
+# pylint: disable=too-many-locals
 def _batched_svd(
     a,
     uv_type,
@@ -532,29 +533,30 @@ def _batched_svd(
         batch_shape_orig,
     )
 
-    k = min(m, n)
-    if compute_uv:
-        if full_matrices:
-            u_shape = (m, m) + (batch_size,)
-            vt_shape = (n, n) + (batch_size,)
-            jobu = ord("A")
-            jobvt = ord("A")
-        else:
-            u_shape = (m, k) + (batch_size,)
-            vt_shape = (k, n) + (batch_size,)
-            jobu = ord("S")
-            jobvt = ord("S")
+    # Transpose if m < n:
+    # 1. cuSolver gesvd supports only m >= n
+    # 2. Reducing a matrix with m >= n to bidiagonal form is more efficient
+    if m < n:
+        n, m = a.shape[-2:]
+        trans_flag = True
     else:
-        u_shape = vt_shape = ()
-        jobu = ord("N")
-        jobvt = ord("N")
+        trans_flag = False
+
+    u_shape, vt_shape, s_shape, jobu, jobvt = _get_svd_shapes_and_flags(
+        m, n, compute_uv, full_matrices, batch_size=batch_size
+    )
 
     _manager = dpu.SequentialOrderManager[exec_q]
     dep_evs = _manager.submitted_events
 
     # Reorder the elements by moving the last two axes of `a` to the front
     # to match fortran-like array order which is assumed by gesvd.
-    a = dpnp.moveaxis(a, (-2, -1), (0, 1))
+    if trans_flag:
+        # Transpose axes for cuSolver and to optimize reduction
+        # to bidiagonal form
+        a = dpnp.moveaxis(a, (-1, -2), (0, 1))
+    else:
+        a = dpnp.moveaxis(a, (-2, -1), (0, 1))
 
     # oneMKL LAPACK gesvd destroys `a` and assumes fortran-like array
     # as input.
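As an aside, a NumPy-only sketch (not the dpnp internals; `np.asfortranarray` stands in for the 'F'-order copy made below) of why the matrix axes are moved to the front: in an 'F'-order array of shape `(m, n, batch)`, each slice along the last axis is one contiguous column-major matrix, which is the layout a LAPACK-style gesvd expects.

```python
import numpy as np

a = np.arange(24).reshape(2, 3, 4)        # batch of two 3x4 matrices, C order
moved = np.moveaxis(a, (-2, -1), (0, 1))  # shape (3, 4, 2): matrix axes first
f = np.asfortranarray(moved)              # column-major copy, batch axis last
assert f[..., 0].flags.f_contiguous       # each slice is one F-order matrix
assert np.array_equal(f[..., 0], a[0])
```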
@@ -583,7 +585,7 @@ def _batched_svd(
         sycl_queue=exec_q,
     )
     s_h = dpnp.empty(
-        (batch_size,) + (k,),
+        s_shape,
         dtype=s_type,
         order="C",
         usm_type=usm_type,
@@ -607,16 +609,23 @@ def _batched_svd(
         # gesvd call writes `u_h` and `vt_h` in Fortran order;
         # reorder the axes to match C order by moving the last axis
         # to the front
-        u = dpnp.moveaxis(u_h, -1, 0)
-        vt = dpnp.moveaxis(vt_h, -1, 0)
+        if trans_flag:
+            # Transpose axes to restore U and V^T for the original matrix
+            u = dpnp.moveaxis(u_h, (0, -1), (-1, 0))
+            vt = dpnp.moveaxis(vt_h, (0, -1), (-1, 0))
+        else:
+            u = dpnp.moveaxis(u_h, -1, 0)
+            vt = dpnp.moveaxis(vt_h, -1, 0)
 
         if a_ndim > 3:
             u = u.reshape(batch_shape_orig + u.shape[-2:])
             vt = vt.reshape(batch_shape_orig + vt.shape[-2:])
         # dpnp.moveaxis can make the array non-contiguous if it is not 2D
         # Convert to contiguous to align with NumPy
         u = dpnp.ascontiguousarray(u)
         vt = dpnp.ascontiguousarray(vt)
-        return u, s, vt
+        # Swap `u` and `vt` for transposed input to restore correct order
+        return (vt, s, u) if trans_flag else (u, s, vt)
     return s
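The `trans_flag` handling above amounts to a standard identity, sketched here with plain NumPy under the assumption that the underlying driver (like cuSolver's gesvd) only accepts m >= n: since A = U S V^T implies A^T = V S^T U^T, decomposing the transposed batch and then swapping and transposing the factors recovers the SVD of the original batch.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((5, 2, 4))  # batch of 2x4 matrices, m < n

# Decompose the transposed batch (now 4x2 matrices, so m >= n holds)
u_t, s, vt_t = np.linalg.svd(a.transpose(0, 2, 1), full_matrices=False)

# Swap roles and transpose back: U = Vt'^T, Vt = U'^T
u = vt_t.transpose(0, 2, 1)
vt = u_t.transpose(0, 2, 1)

# The swapped factors reconstruct the original batch
assert np.allclose(u * s[:, None, :] @ vt, a)
```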


@@ -759,6 +768,36 @@ def _common_inexact_type(default_dtype, *dtypes):
     return dpnp.result_type(*inexact_dtypes)
 
 
+def _get_svd_shapes_and_flags(m, n, compute_uv, full_matrices, batch_size=None):
+    """Return the shapes and flags for SVD computations."""
+
+    k = min(m, n)
+    if compute_uv:
+        if full_matrices:
+            u_shape = (m, m)
+            vt_shape = (n, n)
+            jobu = ord("A")
+            jobvt = ord("A")
+        else:
+            u_shape = (m, k)
+            vt_shape = (k, n)
+            jobu = ord("S")
+            jobvt = ord("S")
+    else:
+        u_shape = vt_shape = ()
+        jobu = ord("N")
+        jobvt = ord("N")
+
+    s_shape = (k,)
+    if batch_size is not None:
+        if compute_uv:
+            u_shape += (batch_size,)
+            vt_shape += (batch_size,)
+        s_shape = (batch_size,) + s_shape
+
+    return u_shape, vt_shape, s_shape, jobu, jobvt
+
+
 def _hermitian_svd(a, compute_uv):
     """
     _hermitian_svd(a, compute_uv)
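For illustration, the new helper's outputs for a hypothetical batch of ten 4x3 matrices with `full_matrices=False` (expected values worked out from the code above; note the batch axis goes last in `u_shape`/`vt_shape` but first in `s_shape`):

```python
u_shape, vt_shape, s_shape, jobu, jobvt = _get_svd_shapes_and_flags(
    4, 3, compute_uv=True, full_matrices=False, batch_size=10
)
assert u_shape == (4, 3, 10)   # (m, k) + (batch_size,)
assert vt_shape == (3, 3, 10)  # (k, n) + (batch_size,)
assert s_shape == (10, 3)      # (batch_size,) + (k,)
assert jobu == jobvt == ord("S")
```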
@@ -2695,6 +2734,16 @@ def dpnp_svd(
         a, uv_type, s_type, full_matrices, compute_uv, exec_q, usm_type
     )
 
+    # Transpose if m < n:
+    # 1. cuSolver gesvd supports only m >= n
+    # 2. Reducing a matrix with m >= n to bidiagonal form is more efficient
+    if m < n:
+        n, m = a.shape
+        a = a.transpose()
+        trans_flag = True
+    else:
+        trans_flag = False
+
     # oneMKL LAPACK gesvd destroys `a` and assumes fortran-like array as input.
     # Allocate 'F' order memory for dpnp arrays to comply with
     # these requirements.
@@ -2716,22 +2765,9 @@ def dpnp_svd(
     )
     _manager.add_event_pair(ht_ev, copy_ev)
 
-    k = min(m, n)
-    if compute_uv:
-        if full_matrices:
-            u_shape = (m, m)
-            vt_shape = (n, n)
-            jobu = ord("A")
-            jobvt = ord("A")
-        else:
-            u_shape = (m, k)
-            vt_shape = (k, n)
-            jobu = ord("S")
-            jobvt = ord("S")
-    else:
-        u_shape = vt_shape = ()
-        jobu = ord("N")
-        jobvt = ord("N")
+    u_shape, vt_shape, s_shape, jobu, jobvt = _get_svd_shapes_and_flags(
+        m, n, compute_uv, full_matrices
+    )
 
     # oneMKL LAPACK assumes fortran-like array as input.
     # Allocate 'F' order memory for dpnp output arrays to comply with
@@ -2746,7 +2782,7 @@ def dpnp_svd(
         shape=vt_shape,
         order="F",
     )
-    s_h = dpnp.empty_like(a_h, shape=(k,), dtype=s_type)
+    s_h = dpnp.empty_like(a_h, shape=s_shape, dtype=s_type)
 
     ht_ev, gesvd_ev = li._gesvd(
         exec_q,
@@ -2761,6 +2797,11 @@ def dpnp_svd(
     _manager.add_event_pair(ht_ev, gesvd_ev)
 
     if compute_uv:
+        # Transposing the input matrix swaps the roles of U and Vt:
+        # For A^T = V S^T U^T, `u_h` becomes V and `vt_h` becomes U^T.
+        # Transpose and swap them back to restore correct order for A.
+        if trans_flag:
+            return vt_h.T, s_h, u_h.T
         # gesvd call writes `u_h` and `vt_h` in Fortran order;
         # Convert to contiguous to align with NumPy
         u_h = dpnp.ascontiguousarray(u_h)
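The early return mirrors the same swap in the non-batched path; a minimal NumPy sketch of `return vt_h.T, s_h, u_h.T` for a 2x4 input (m < n):

```python
import numpy as np

a = np.arange(8, dtype=np.float64).reshape(2, 4)  # m = 2 < n = 4

# SVD of the transpose, as the m < n branch above arranges
u_h, s_h, vt_h = np.linalg.svd(a.T, full_matrices=True)

# Same swap as `return vt_h.T, s_h, u_h.T`
u, s, vt = vt_h.T, s_h, u_h.T

# u is (2, 2) and vt is (4, 4): the full SVD factors of the original `a`
assert np.allclose(u * s @ vt[:2], a)
```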
16 changes: 0 additions & 16 deletions dpnp/tests/third_party/cupy/linalg_tests/test_decomposition.py
@@ -7,7 +7,6 @@
 from dpnp.tests.helper import (
     has_support_aspect64,
     is_cpu_device,
-    is_win_platform,
 )
 from dpnp.tests.third_party.cupy import testing
 from dpnp.tests.third_party.cupy.testing import _condition
@@ -280,12 +279,6 @@ def test_svd_rank2_empty_array_compute_uv_false(self, xp):
             array, full_matrices=self.full_matrices, compute_uv=False
         )
 
-    # The issue was expected to be resolved once CMPLRLLVM-53771 is available,
-    # which has to be included in DPC++ 2024.1.0, but problem still exists
-    # on Windows
-    @pytest.mark.skipif(
-        is_cpu_device() and is_win_platform(), reason="SAT-7145"
-    )
     @_condition.repeat(3, 10)
     def test_svd_rank3(self):
         self.check_usv((2, 3, 4))
@@ -295,9 +288,6 @@ def test_svd_rank3(self):
         self.check_usv((2, 4, 3))
         self.check_usv((2, 32, 32))
 
-    @pytest.mark.skipif(
-        is_cpu_device() and is_win_platform(), reason="SAT-7145"
-    )
     @_condition.repeat(3, 10)
     def test_svd_rank3_loop(self):
         # This tests the loop-based batched gesvd on CUDA (_gesvd_batched)
@@ -345,9 +335,6 @@ def test_svd_rank3_empty_array_compute_uv_false2(self, xp):
             array, full_matrices=self.full_matrices, compute_uv=False
         )
 
-    @pytest.mark.skipif(
-        is_cpu_device() and is_win_platform(), reason="SAT-7145"
-    )
     @_condition.repeat(3, 10)
     def test_svd_rank4(self):
         self.check_usv((2, 2, 3, 4))
@@ -357,9 +344,6 @@ def test_svd_rank4(self):
         self.check_usv((2, 2, 4, 3))
         self.check_usv((2, 2, 32, 32))
 
-    @pytest.mark.skipif(
-        is_cpu_device() and is_win_platform(), reason="SAT-7145"
-    )
     @_condition.repeat(3, 10)
     def test_svd_rank4_loop(self):
         # This tests the loop-based batched gesvd on CUDA (_gesvd_batched)