Skip to content

Commit

Permalink
updating tests - Part1 (IntelPython#2210)
Browse files Browse the repository at this point in the history
This is part 1 of a series of PRs in which the tests are refactored. In
this PR, `test_linalg.py`, `test_product.py`, `test_statistics.py`,
`test_fft.py`, and `test_sort.py` are updated.
  • Loading branch information
vtavana authored Dec 8, 2024
1 parent 4607833 commit 6dc39f9
Show file tree
Hide file tree
Showing 9 changed files with 1,111 additions and 1,280 deletions.
63 changes: 2 additions & 61 deletions dpnp/dpnp_iface_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@
"""


import numpy
from dpctl.tensor._numpy_helper import normalize_axis_tuple

import dpnp

Expand All @@ -48,6 +46,7 @@
dpnp_dot,
dpnp_kron,
dpnp_matmul,
dpnp_tensordot,
dpnp_vecdot,
)

Expand Down Expand Up @@ -1047,65 +1046,7 @@ def tensordot(a, b, axes=2):
# TODO: use specific scalar-vector kernel
return dpnp.multiply(a, b)

try:
iter(axes)
except Exception as e: # pylint: disable=broad-exception-caught
if not isinstance(axes, int):
raise TypeError("Axes must be an integer.") from e
if axes < 0:
raise ValueError("Axes must be a non-negative integer.") from e
axes_a = tuple(range(-axes, 0))
axes_b = tuple(range(0, axes))
else:
if len(axes) != 2:
raise ValueError("Axes must consist of two sequences.")

axes_a, axes_b = axes
axes_a = (axes_a,) if dpnp.isscalar(axes_a) else axes_a
axes_b = (axes_b,) if dpnp.isscalar(axes_b) else axes_b

if len(axes_a) != len(axes_b):
raise ValueError("Axes length mismatch.")

# Make the axes non-negative
a_ndim = a.ndim
b_ndim = b.ndim
axes_a = normalize_axis_tuple(axes_a, a_ndim, "axis_a")
axes_b = normalize_axis_tuple(axes_b, b_ndim, "axis_b")

if a.ndim == 0 or b.ndim == 0:
# TODO: use specific scalar-vector kernel
return dpnp.multiply(a, b)

a_shape = a.shape
b_shape = b.shape
for axis_a, axis_b in zip(axes_a, axes_b):
if a_shape[axis_a] != b_shape[axis_b]:
raise ValueError(
"shape of input arrays is not similar at requested axes."
)

# Move the axes to sum over, to the end of "a"
not_in = tuple(k for k in range(a_ndim) if k not in axes_a)
newaxes_a = not_in + axes_a
n1 = int(numpy.prod([a_shape[ax] for ax in not_in]))
n2 = int(numpy.prod([a_shape[ax] for ax in axes_a]))
newshape_a = (n1, n2)
olda = [a_shape[axis] for axis in not_in]

# Move the axes to sum over, to the front of "b"
not_in = tuple(k for k in range(b_ndim) if k not in axes_b)
newaxes_b = tuple(axes_b + not_in)
n1 = int(numpy.prod([b_shape[ax] for ax in axes_b]))
n2 = int(numpy.prod([b_shape[ax] for ax in not_in]))
newshape_b = (n1, n2)
oldb = [b_shape[axis] for axis in not_in]

at = dpnp.transpose(a, newaxes_a).reshape(newshape_a)
bt = dpnp.transpose(b, newaxes_b).reshape(newshape_b)
res = dpnp.matmul(at, bt)

return res.reshape(olda + oldb)
return dpnp_tensordot(a, b, axes=axes)


def vdot(a, b):
Expand Down
17 changes: 11 additions & 6 deletions dpnp/dpnp_iface_sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,18 @@ def _wrap_sort_argsort(

if order is not None:
raise NotImplementedError(
"order keyword argument is only supported with its default value."
)
if kind is not None and stable is not None:
raise ValueError(
"`kind` and `stable` parameters can't be provided at the same time."
" Use only one of them."
"`order` keyword argument is only supported with its default value."
)
if stable is not None:
if stable not in [True, False]:
raise ValueError(
"`stable` parameter should be None, True, or False."
)
if kind is not None:
raise ValueError(
"`kind` and `stable` parameters can't be provided at"
" the same time. Use only one of them."
)

usm_a = dpnp.get_usm_ndarray(a)
if axis is None:
Expand Down
73 changes: 72 additions & 1 deletion dpnp/dpnp_utils/dpnp_utils_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,14 @@
from dpnp.dpnp_array import dpnp_array
from dpnp.dpnp_utils import get_usm_allocations

__all__ = ["dpnp_cross", "dpnp_dot", "dpnp_kron", "dpnp_matmul", "dpnp_vecdot"]
__all__ = [
"dpnp_cross",
"dpnp_dot",
"dpnp_kron",
"dpnp_matmul",
"dpnp_tensordot",
"dpnp_vecdot",
]


def _compute_res_dtype(*arrays, sycl_queue, dtype=None, casting="no"):
Expand Down Expand Up @@ -974,6 +981,70 @@ def dpnp_matmul(
return result


def dpnp_tensordot(a, b, axes=2):
"""Tensor dot product of two arrays."""

try:
iter(axes)
except Exception as e: # pylint: disable=broad-exception-caught
if not isinstance(axes, int):
raise TypeError("Axes must be an integer.") from e
if axes < 0:
raise ValueError("Axes must be a non-negative integer.") from e
axes_a = tuple(range(-axes, 0))
axes_b = tuple(range(0, axes))
else:
if len(axes) != 2:
raise ValueError("Axes must consist of two sequences.")

axes_a, axes_b = axes
axes_a = (axes_a,) if dpnp.isscalar(axes_a) else axes_a
axes_b = (axes_b,) if dpnp.isscalar(axes_b) else axes_b

if len(axes_a) != len(axes_b):
raise ValueError("Axes length mismatch.")

# Make the axes non-negative
a_ndim = a.ndim
b_ndim = b.ndim
axes_a = normalize_axis_tuple(axes_a, a_ndim, "axis_a")
axes_b = normalize_axis_tuple(axes_b, b_ndim, "axis_b")

if a.ndim == 0 or b.ndim == 0:
# TODO: use specific scalar-vector kernel
return dpnp.multiply(a, b)

a_shape = a.shape
b_shape = b.shape
for axis_a, axis_b in zip(axes_a, axes_b):
if a_shape[axis_a] != b_shape[axis_b]:
raise ValueError(
"shape of input arrays is not similar at requested axes."
)

# Move the axes to sum over, to the end of "a"
not_in = tuple(k for k in range(a_ndim) if k not in axes_a)
newaxes_a = not_in + axes_a
n1 = int(numpy.prod([a_shape[ax] for ax in not_in]))
n2 = int(numpy.prod([a_shape[ax] for ax in axes_a]))
newshape_a = (n1, n2)
olda = [a_shape[axis] for axis in not_in]

# Move the axes to sum over, to the front of "b"
not_in = tuple(k for k in range(b_ndim) if k not in axes_b)
newaxes_b = tuple(axes_b + not_in)
n1 = int(numpy.prod([b_shape[ax] for ax in axes_b]))
n2 = int(numpy.prod([b_shape[ax] for ax in not_in]))
newshape_b = (n1, n2)
oldb = [b_shape[axis] for axis in not_in]

at = dpnp.transpose(a, newaxes_a).reshape(newshape_a)
bt = dpnp.transpose(b, newaxes_b).reshape(newshape_b)
res = dpnp.matmul(at, bt)

return res.reshape(olda + oldb)


def dpnp_vecdot(
x1,
x2,
Expand Down
Loading

0 comments on commit 6dc39f9

Please sign in to comment.