Update dpnp.linalg.svd() to run on CUDA (IntelPython#2212)
This PR updates the `dpnp.linalg.svd()` implementation to support
running on CUDA devices.
Since cuSolver gesvd only supports `m >= n`, the previous implementation
crashed with `Segmentation fault (core dumped)` for matrices with `m < n`.

This PR adds a check for `m >= n` and transposes the input array
otherwise.

Passing the transposed array to `oneapi::mkl::lapack::gesvd` also improves
the performance of `dpnp.linalg.svd()`, because reducing a matrix with
`m >= n` to bidiagonal form (inside `lapack::gesvd`) is more efficient.
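
The fix relies on a standard identity: if `A = U S V^T`, then `A^T = V S^T U^T`, so the SVD of a wide matrix can be obtained by factorizing its transpose and swapping the factors. A minimal NumPy-style sketch of this trick (NumPy is used here purely to illustrate the algebra, not the dpnp internals):

```python
import numpy as np

# If A = U @ diag(S) @ Vt, then A.T = Vt.T @ diag(S) @ U.T, i.e. the
# SVD factors of A.T are the swapped, transposed factors of A. So for
# a wide matrix (m < n) one can run gesvd on A.T, which satisfies
# m >= n, and swap the factors back afterwards.
a = np.random.rand(3, 5)  # m = 3 < n = 5

u_t, s, vt_t = np.linalg.svd(a.T, full_matrices=False)
u, vt = vt_t.T, u_t.T  # swap and transpose to recover the factors of A

assert np.allclose(a, u @ np.diag(s) @ vt)
```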
vlad-perevezentsev authored Dec 6, 2024
1 parent e0c9cf1 commit 4875e59
Showing 2 changed files with 78 additions and 53 deletions.
115 changes: 78 additions & 37 deletions dpnp/linalg/dpnp_utils_linalg.py
@@ -475,6 +475,7 @@ def _batched_qr(a, mode="reduced"):
)


# pylint: disable=too-many-locals
def _batched_svd(
a,
uv_type,
@@ -532,29 +533,30 @@ def _batched_svd(
batch_shape_orig,
)

k = min(m, n)
if compute_uv:
if full_matrices:
u_shape = (m, m) + (batch_size,)
vt_shape = (n, n) + (batch_size,)
jobu = ord("A")
jobvt = ord("A")
else:
u_shape = (m, k) + (batch_size,)
vt_shape = (k, n) + (batch_size,)
jobu = ord("S")
jobvt = ord("S")
# Transpose if m < n:
# 1. cuSolver gesvd supports only m >= n
# 2. Reducing a matrix with m >= n to bidiagonal form is more efficient
if m < n:
n, m = a.shape[-2:]
trans_flag = True
else:
u_shape = vt_shape = ()
jobu = ord("N")
jobvt = ord("N")
trans_flag = False

u_shape, vt_shape, s_shape, jobu, jobvt = _get_svd_shapes_and_flags(
m, n, compute_uv, full_matrices, batch_size=batch_size
)

_manager = dpu.SequentialOrderManager[exec_q]
dep_evs = _manager.submitted_events

# Reorder the elements by moving the last two axes of `a` to the front
# to match fortran-like array order which is assumed by gesvd.
a = dpnp.moveaxis(a, (-2, -1), (0, 1))
if trans_flag:
# Transpose axes for cuSolver and to optimize reduction
# to bidiagonal form
a = dpnp.moveaxis(a, (-1, -2), (0, 1))
else:
a = dpnp.moveaxis(a, (-2, -1), (0, 1))

# oneMKL LAPACK gesvd destroys `a` and assumes fortran-like array
# as input.
@@ -583,7 +585,7 @@ def _batched_svd(
sycl_queue=exec_q,
)
s_h = dpnp.empty(
(batch_size,) + (k,),
s_shape,
dtype=s_type,
order="C",
usm_type=usm_type,
@@ -607,16 +609,23 @@ def _batched_svd(
# gesvd call writes `u_h` and `vt_h` in Fortran order;
# reorder the axes to match C order by moving the last axis
# to the front
u = dpnp.moveaxis(u_h, -1, 0)
vt = dpnp.moveaxis(vt_h, -1, 0)
if trans_flag:
# Transpose axes to restore U and V^T for the original matrix
u = dpnp.moveaxis(u_h, (0, -1), (-1, 0))
vt = dpnp.moveaxis(vt_h, (0, -1), (-1, 0))
else:
u = dpnp.moveaxis(u_h, -1, 0)
vt = dpnp.moveaxis(vt_h, -1, 0)

if a_ndim > 3:
u = u.reshape(batch_shape_orig + u.shape[-2:])
vt = vt.reshape(batch_shape_orig + vt.shape[-2:])
# dpnp.moveaxis can make the array non-contiguous if it is not 2D
# Convert to contiguous to align with NumPy
u = dpnp.ascontiguousarray(u)
vt = dpnp.ascontiguousarray(vt)
return u, s, vt
# Swap `u` and `vt` for transposed input to restore correct order
return (vt, s, u) if trans_flag else (u, s, vt)
return s


@@ -759,6 +768,36 @@ def _common_inexact_type(default_dtype, *dtypes):
return dpnp.result_type(*inexact_dtypes)


def _get_svd_shapes_and_flags(m, n, compute_uv, full_matrices, batch_size=None):
"""Return the shapes and flags for SVD computations."""

k = min(m, n)
if compute_uv:
if full_matrices:
u_shape = (m, m)
vt_shape = (n, n)
jobu = ord("A")
jobvt = ord("A")
else:
u_shape = (m, k)
vt_shape = (k, n)
jobu = ord("S")
jobvt = ord("S")
else:
u_shape = vt_shape = ()
jobu = ord("N")
jobvt = ord("N")

s_shape = (k,)
if batch_size is not None:
if compute_uv:
u_shape += (batch_size,)
vt_shape += (batch_size,)
s_shape = (batch_size,) + s_shape

return u_shape, vt_shape, s_shape, jobu, jobvt


def _hermitian_svd(a, compute_uv):
"""
_hermitian_svd(a, compute_uv)
@@ -2695,6 +2734,16 @@ def dpnp_svd(
a, uv_type, s_type, full_matrices, compute_uv, exec_q, usm_type
)

# Transpose if m < n:
# 1. cuSolver gesvd supports only m >= n
# 2. Reducing a matrix with m >= n to bidiagonal form is more efficient
if m < n:
n, m = a.shape
a = a.transpose()
trans_flag = True
else:
trans_flag = False

# oneMKL LAPACK gesvd destroys `a` and assumes fortran-like array as input.
# Allocate 'F' order memory for dpnp arrays to comply with
# these requirements.
@@ -2716,22 +2765,9 @@
)
_manager.add_event_pair(ht_ev, copy_ev)

k = min(m, n)
if compute_uv:
if full_matrices:
u_shape = (m, m)
vt_shape = (n, n)
jobu = ord("A")
jobvt = ord("A")
else:
u_shape = (m, k)
vt_shape = (k, n)
jobu = ord("S")
jobvt = ord("S")
else:
u_shape = vt_shape = ()
jobu = ord("N")
jobvt = ord("N")
u_shape, vt_shape, s_shape, jobu, jobvt = _get_svd_shapes_and_flags(
m, n, compute_uv, full_matrices
)

# oneMKL LAPACK assumes fortran-like array as input.
# Allocate 'F' order memory for dpnp output arrays to comply with
@@ -2746,7 +2782,7 @@
shape=vt_shape,
order="F",
)
s_h = dpnp.empty_like(a_h, shape=(k,), dtype=s_type)
s_h = dpnp.empty_like(a_h, shape=s_shape, dtype=s_type)

ht_ev, gesvd_ev = li._gesvd(
exec_q,
@@ -2761,6 +2797,11 @@
_manager.add_event_pair(ht_ev, gesvd_ev)

if compute_uv:
# Transposing the input matrix swaps the roles of U and Vt:
# For A^T = V S^T U^T, `u_h` becomes V and `vt_h` becomes U^T.
# Transpose and swap them back to restore correct order for A.
if trans_flag:
return vt_h.T, s_h, u_h.T
# gesvd call writes `u_h` and `vt_h` in Fortran order;
# Convert to contiguous to align with NumPy
u_h = dpnp.ascontiguousarray(u_h)
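For reference, a quick illustration of the shapes and LAPACK job flags produced by the new `_get_svd_shapes_and_flags` helper above. This sketch assumes a dpnp build containing this commit; the helper is private, so the import path is shown for illustration only:

```python
from dpnp.linalg.dpnp_utils_linalg import _get_svd_shapes_and_flags

# Reduced SVD of a single 4x3 matrix (m >= n, so k = min(m, n) = 3)
u_shape, vt_shape, s_shape, jobu, jobvt = _get_svd_shapes_and_flags(
    4, 3, compute_uv=True, full_matrices=False
)
assert (u_shape, vt_shape, s_shape) == ((4, 3), (3, 3), (3,))
assert (jobu, jobvt) == (ord("S"), ord("S"))

# Batched case: the batch axis is appended to `u`/`vt` (matching the
# Fortran-order layout gesvd writes into) and prepended to `s`
u_shape, vt_shape, s_shape, _, _ = _get_svd_shapes_and_flags(
    4, 3, compute_uv=True, full_matrices=False, batch_size=10
)
assert (u_shape, vt_shape, s_shape) == ((4, 3, 10), (3, 3, 10), (10, 3))
```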
16 changes: 0 additions & 16 deletions dpnp/tests/third_party/cupy/linalg_tests/test_decomposition.py
@@ -7,7 +7,6 @@
from dpnp.tests.helper import (
has_support_aspect64,
is_cpu_device,
is_win_platform,
)
from dpnp.tests.third_party.cupy import testing
from dpnp.tests.third_party.cupy.testing import _condition
@@ -280,12 +279,6 @@ def test_svd_rank2_empty_array_compute_uv_false(self, xp):
array, full_matrices=self.full_matrices, compute_uv=False
)

# The issue was expected to be resolved once CMPLRLLVM-53771 is available,
# which has to be included in DPC++ 2024.1.0, but problem still exists
# on Windows
@pytest.mark.skipif(
is_cpu_device() and is_win_platform(), reason="SAT-7145"
)
@_condition.repeat(3, 10)
def test_svd_rank3(self):
self.check_usv((2, 3, 4))
@@ -295,9 +288,6 @@ def test_svd_rank3(self):
self.check_usv((2, 4, 3))
self.check_usv((2, 32, 32))

@pytest.mark.skipif(
is_cpu_device() and is_win_platform(), reason="SAT-7145"
)
@_condition.repeat(3, 10)
def test_svd_rank3_loop(self):
# This tests the loop-based batched gesvd on CUDA (_gesvd_batched)
@@ -345,9 +335,6 @@ def test_svd_rank3_empty_array_compute_uv_false2(self, xp):
array, full_matrices=self.full_matrices, compute_uv=False
)

@pytest.mark.skipif(
is_cpu_device() and is_win_platform(), reason="SAT-7145"
)
@_condition.repeat(3, 10)
def test_svd_rank4(self):
self.check_usv((2, 2, 3, 4))
@@ -357,9 +344,6 @@ def test_svd_rank4(self):
self.check_usv((2, 2, 4, 3))
self.check_usv((2, 2, 32, 32))

@pytest.mark.skipif(
is_cpu_device() and is_win_platform(), reason="SAT-7145"
)
@_condition.repeat(3, 10)
def test_svd_rank4_loop(self):
# This tests the loop-based batched gesvd on CUDA (_gesvd_batched)
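Finally, a quick smoke test of the behavior this PR fixes: the wide-matrix case that previously segfaulted with cuSolver gesvd. A minimal sketch, assuming a dpnp build containing this commit and an available SYCL device (e.g. the CUDA backend):

```python
import dpnp as dp

# Wide matrix (m < n): the case that previously crashed on CUDA
a = dp.random.rand(3, 5)

u, s, vt = dp.linalg.svd(a, full_matrices=False)
assert u.shape == (3, 3) and s.shape == (3,) and vt.shape == (3, 5)

# The factors must reconstruct `a` regardless of the internal transpose;
# loose tolerances accommodate single-precision-only devices
assert dp.allclose(a, u @ dp.diag(s) @ vt, rtol=1e-4, atol=1e-5)
```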
