Update dpnp.linalg.svd() to run on CUDA #2212

Merged: 5 commits on Dec 6, 2024
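
For context, a minimal sketch of what this change enables, namely calling dpnp.linalg.svd on a CUDA-backed SYCL device. This example is illustrative and not part of the diff; the "cuda:gpu" filter string is an assumption and requires the oneAPI plugin for NVIDIA GPUs to be installed.

    import dpnp

    # Hypothetical device selection; the exact filter string depends on
    # how the CUDA backend is exposed on your system
    a = dpnp.ones((3, 5), device="cuda:gpu")

    u, s, vt = dpnp.linalg.svd(a)      # runs on the CUDA device
    print(u.shape, s.shape, vt.shape)  # (3, 3), (3,), (5, 5)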
115 changes: 78 additions & 37 deletions dpnp/linalg/dpnp_utils_linalg.py
@@ -475,6 +475,7 @@ def _batched_qr(a, mode="reduced"):
)


+# pylint: disable=too-many-locals
def _batched_svd(
a,
uv_type,
@@ -532,29 +533,30 @@ def _batched_svd(
batch_shape_orig,
)

-    k = min(m, n)
-    if compute_uv:
-        if full_matrices:
-            u_shape = (m, m) + (batch_size,)
-            vt_shape = (n, n) + (batch_size,)
-            jobu = ord("A")
-            jobvt = ord("A")
-        else:
-            u_shape = (m, k) + (batch_size,)
-            vt_shape = (k, n) + (batch_size,)
-            jobu = ord("S")
-            jobvt = ord("S")
+    # Transpose if m < n:
+    # 1. cuSolver gesvd supports only m >= n
+    # 2. Reducing a matrix with m >= n to bidiagonal form is more efficient
+    if m < n:
+        n, m = a.shape[-2:]
+        trans_flag = True
     else:
-        u_shape = vt_shape = ()
-        jobu = ord("N")
-        jobvt = ord("N")
+        trans_flag = False
+
+    u_shape, vt_shape, s_shape, jobu, jobvt = _get_svd_shapes_and_flags(
+        m, n, compute_uv, full_matrices, batch_size=batch_size
+    )

_manager = dpu.SequentialOrderManager[exec_q]
dep_evs = _manager.submitted_events

# Reorder the elements by moving the last two axes of `a` to the front
# to match fortran-like array order which is assumed by gesvd.
-    a = dpnp.moveaxis(a, (-2, -1), (0, 1))
+    if trans_flag:
+        # Transpose axes for cuSolver and to optimize reduction
+        # to bidiagonal form
+        a = dpnp.moveaxis(a, (-1, -2), (0, 1))
+    else:
+        a = dpnp.moveaxis(a, (-2, -1), (0, 1))

# oneMKL LAPACK gesvd destroys `a` and assumes fortran-like array
# as input.
@@ -583,7 +585,7 @@ def _batched_svd(
sycl_queue=exec_q,
)
s_h = dpnp.empty(
-        (batch_size,) + (k,),
+        s_shape,
dtype=s_type,
order="C",
usm_type=usm_type,
@@ -607,16 +609,23 @@ def _batched_svd(
# gesvd call writes `u_h` and `vt_h` in Fortran order;
# reorder the axes to match C order by moving the last axis
# to the front
-    u = dpnp.moveaxis(u_h, -1, 0)
-    vt = dpnp.moveaxis(vt_h, -1, 0)
+    if trans_flag:
+        # Transpose axes to restore U and V^T for the original matrix
+        u = dpnp.moveaxis(u_h, (0, -1), (-1, 0))
+        vt = dpnp.moveaxis(vt_h, (0, -1), (-1, 0))
+    else:
+        u = dpnp.moveaxis(u_h, -1, 0)
+        vt = dpnp.moveaxis(vt_h, -1, 0)

if a_ndim > 3:
u = u.reshape(batch_shape_orig + u.shape[-2:])
vt = vt.reshape(batch_shape_orig + vt.shape[-2:])
# dpnp.moveaxis can make the array non-contiguous if it is not 2D
# Convert to contiguous to align with NumPy
u = dpnp.ascontiguousarray(u)
vt = dpnp.ascontiguousarray(vt)
-    return u, s, vt
+    # Swap `u` and `vt` for transposed input to restore correct order
+    return (vt, s, u) if trans_flag else (u, s, vt)
return s
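
As a side note on the axis bookkeeping above, a minimal sketch using NumPy as a stand-in (the moveaxis semantics are the same in dpnp): moving the two matrix axes to the front leaves the batch axis last, so each 2-D problem is laid out the way the Fortran-order gesvd loop consumes it.

    import numpy as np

    batch, m, n = 4, 3, 2
    a = np.arange(batch * m * n, dtype=np.float64).reshape(batch, m, n)

    # Move the matrix axes to the front; the batch axis ends up last
    a_f = np.moveaxis(a, (-2, -1), (0, 1))
    assert a_f.shape == (m, n, batch)
    # slice k along the last axis is the k-th matrix of the batch
    assert np.array_equal(a_f[..., 2], a[2])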


@@ -759,6 +768,36 @@ def _common_inexact_type(default_dtype, *dtypes):
return dpnp.result_type(*inexact_dtypes)


+def _get_svd_shapes_and_flags(m, n, compute_uv, full_matrices, batch_size=None):
+    """Return the shapes and flags for SVD computations."""
+
+    k = min(m, n)
+    if compute_uv:
+        if full_matrices:
+            u_shape = (m, m)
+            vt_shape = (n, n)
+            jobu = ord("A")
+            jobvt = ord("A")
+        else:
+            u_shape = (m, k)
+            vt_shape = (k, n)
+            jobu = ord("S")
+            jobvt = ord("S")
+    else:
+        u_shape = vt_shape = ()
+        jobu = ord("N")
+        jobvt = ord("N")
+
+    s_shape = (k,)
+    if batch_size is not None:
+        if compute_uv:
+            u_shape += (batch_size,)
+            vt_shape += (batch_size,)
+        s_shape = (batch_size,) + s_shape
+
+    return u_shape, vt_shape, s_shape, jobu, jobvt
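
A quick sanity check of the helper's contract, assuming _get_svd_shapes_and_flags is in scope (the expected tuples are derived from the code above): for u and vt the batch axis is appended last to match the Fortran-order gesvd layout, while for s it comes first.

    # k = min(m, n) singular values; jobu/jobvt select full, thin, or no U/Vt
    assert _get_svd_shapes_and_flags(3, 5, True, True) == (
        (3, 3), (5, 5), (3,), ord("A"), ord("A")
    )
    assert _get_svd_shapes_and_flags(3, 5, True, False) == (
        (3, 3), (3, 5), (3,), ord("S"), ord("S")
    )
    assert _get_svd_shapes_and_flags(3, 5, True, False, batch_size=4) == (
        (3, 3, 4), (3, 5, 4), (4, 3), ord("S"), ord("S")
    )
    assert _get_svd_shapes_and_flags(3, 5, False, False) == (
        (), (), (3,), ord("N"), ord("N")
    )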


def _hermitian_svd(a, compute_uv):
"""
_hermitian_svd(a, compute_uv)
Expand Down Expand Up @@ -2695,6 +2734,16 @@ def dpnp_svd(
a, uv_type, s_type, full_matrices, compute_uv, exec_q, usm_type
)

+    # Transpose if m < n:
+    # 1. cuSolver gesvd supports only m >= n
+    # 2. Reducing a matrix with m >= n to bidiagonal form is more efficient
+    if m < n:
+        n, m = a.shape
+        a = a.transpose()
+        trans_flag = True
+    else:
+        trans_flag = False
+
# oneMKL LAPACK gesvd destroys `a` and assumes fortran-like array as input.
# Allocate 'F' order memory for dpnp arrays to comply with
# these requirements.
@@ -2716,22 +2765,9 @@ def dpnp_svd(
)
_manager.add_event_pair(ht_ev, copy_ev)

-    k = min(m, n)
-    if compute_uv:
-        if full_matrices:
-            u_shape = (m, m)
-            vt_shape = (n, n)
-            jobu = ord("A")
-            jobvt = ord("A")
-        else:
-            u_shape = (m, k)
-            vt_shape = (k, n)
-            jobu = ord("S")
-            jobvt = ord("S")
-    else:
-        u_shape = vt_shape = ()
-        jobu = ord("N")
-        jobvt = ord("N")
+    u_shape, vt_shape, s_shape, jobu, jobvt = _get_svd_shapes_and_flags(
+        m, n, compute_uv, full_matrices
+    )

# oneMKL LAPACK assumes fortran-like array as input.
# Allocate 'F' order memory for dpnp output arrays to comply with
@@ -2746,7 +2782,7 @@ def dpnp_svd(
shape=vt_shape,
order="F",
)
-    s_h = dpnp.empty_like(a_h, shape=(k,), dtype=s_type)
+    s_h = dpnp.empty_like(a_h, shape=s_shape, dtype=s_type)

ht_ev, gesvd_ev = li._gesvd(
exec_q,
@@ -2761,6 +2797,11 @@ def dpnp_svd(
_manager.add_event_pair(ht_ev, gesvd_ev)

if compute_uv:
+        # Transposing the input matrix swaps the roles of U and Vt:
+        # for A^T = V S^T U^T, `u_h` becomes V and `vt_h` becomes U^T.
+        # Transpose and swap them back to restore correct order for A.
+        if trans_flag:
+            return vt_h.T, s_h, u_h.T
# gesvd call writes `u_h` and `vt_h` in Fortran order;
# Convert to contiguous to align with NumPy
u_h = dpnp.ascontiguousarray(u_h)
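
The swap in the trans_flag branch follows from A^T = V S^T U^T. A minimal NumPy sketch (standing in for dpnp purely to illustrate the algebra) that factors the transpose and swaps the factors back:

    import numpy as np

    rng = np.random.default_rng(0)
    a = rng.standard_normal((3, 5))  # m < n, the case that gets transposed

    # Factor the transpose, whose shape (n, m) satisfies rows >= cols
    u_t, s, vt_t = np.linalg.svd(a.T, full_matrices=False)

    # Swap and transpose back, mirroring `return vt_h.T, s_h, u_h.T`
    u, vt = vt_t.T, u_t.T
    assert np.allclose((u * s) @ vt, a)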
16 changes: 0 additions & 16 deletions dpnp/tests/third_party/cupy/linalg_tests/test_decomposition.py
@@ -7,7 +7,6 @@
from dpnp.tests.helper import (
has_support_aspect64,
is_cpu_device,
-    is_win_platform,
)
from dpnp.tests.third_party.cupy import testing
from dpnp.tests.third_party.cupy.testing import _condition
Expand Down Expand Up @@ -280,12 +279,6 @@ def test_svd_rank2_empty_array_compute_uv_false(self, xp):
array, full_matrices=self.full_matrices, compute_uv=False
)

-    # The issue was expected to be resolved once CMPLRLLVM-53771 is available,
-    # which has to be included in DPC++ 2024.1.0, but problem still exists
-    # on Windows
-    @pytest.mark.skipif(
-        is_cpu_device() and is_win_platform(), reason="SAT-7145"
-    )
@_condition.repeat(3, 10)
def test_svd_rank3(self):
self.check_usv((2, 3, 4))
@@ -295,9 +288,6 @@ def test_svd_rank3(self):
self.check_usv((2, 4, 3))
self.check_usv((2, 32, 32))

-    @pytest.mark.skipif(
-        is_cpu_device() and is_win_platform(), reason="SAT-7145"
-    )
@_condition.repeat(3, 10)
def test_svd_rank3_loop(self):
# This tests the loop-based batched gesvd on CUDA (_gesvd_batched)
@@ -345,9 +335,6 @@ def test_svd_rank3_empty_array_compute_uv_false2(self, xp):
array, full_matrices=self.full_matrices, compute_uv=False
)

-    @pytest.mark.skipif(
-        is_cpu_device() and is_win_platform(), reason="SAT-7145"
-    )
@_condition.repeat(3, 10)
def test_svd_rank4(self):
self.check_usv((2, 2, 3, 4))
@@ -357,9 +344,6 @@ def test_svd_rank4(self):
self.check_usv((2, 2, 4, 3))
self.check_usv((2, 2, 32, 32))

-    @pytest.mark.skipif(
-        is_cpu_device() and is_win_platform(), reason="SAT-7145"
-    )
@_condition.repeat(3, 10)
def test_svd_rank4_loop(self):
# This tests the loop-based batched gesvd on CUDA (_gesvd_batched)