Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for axes keyword to dpnp.matmul #1705

Merged
merged 10 commits into from
Feb 16, 2024
6 changes: 3 additions & 3 deletions dpnp/backend/extensions/blas/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ namespace mkl_blas = oneapi::mkl::blas;
namespace py = pybind11;
namespace type_utils = dpctl::tensor::type_utils;

typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue,
typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue &,
oneapi::mkl::transpose,
oneapi::mkl::transpose,
const std::int64_t,
Expand All @@ -64,7 +64,7 @@ static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types]
[dpctl_td_ns::num_types];

template <typename Tab, typename Tc>
static sycl::event gemm_impl(sycl::queue exec_q,
static sycl::event gemm_impl(sycl::queue &exec_q,
oneapi::mkl::transpose transA,
oneapi::mkl::transpose transB,
const std::int64_t m,
Expand Down Expand Up @@ -130,7 +130,7 @@ static sycl::event gemm_impl(sycl::queue exec_q,
}

std::pair<sycl::event, sycl::event>
gemm(sycl::queue exec_q,
gemm(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
dpctl::tensor::usm_ndarray resultC,
Expand Down
4 changes: 2 additions & 2 deletions dpnp/backend/extensions/blas/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ namespace ext
namespace blas
{
extern std::pair<sycl::event, sycl::event>
gemm(sycl::queue exec_q,
gemm(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
dpctl::tensor::usm_ndarray resultC,
const std::vector<sycl::event> &depends);

extern std::pair<sycl::event, sycl::event>
gemm_batch(sycl::queue exec_q,
gemm_batch(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
dpctl::tensor::usm_ndarray resultC,
Expand Down
6 changes: 3 additions & 3 deletions dpnp/backend/extensions/blas/gemm_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ namespace py = pybind11;
namespace type_utils = dpctl::tensor::type_utils;

typedef sycl::event (*gemm_batch_impl_fn_ptr_t)(
sycl::queue,
sycl::queue &,
const std::int64_t,
const std::int64_t,
const std::int64_t,
Expand All @@ -69,7 +69,7 @@ static gemm_batch_impl_fn_ptr_t
gemm_batch_dispatch_table[dpctl_td_ns::num_types][dpctl_td_ns::num_types];

template <typename Tab, typename Tc>
static sycl::event gemm_batch_impl(sycl::queue exec_q,
static sycl::event gemm_batch_impl(sycl::queue &exec_q,
const std::int64_t m,
const std::int64_t n,
const std::int64_t k,
Expand Down Expand Up @@ -145,7 +145,7 @@ static sycl::event gemm_batch_impl(sycl::queue exec_q,
}

std::pair<sycl::event, sycl::event>
gemm_batch(sycl::queue exec_q,
gemm_batch(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
dpctl::tensor::usm_ndarray resultC,
Expand Down
56 changes: 52 additions & 4 deletions dpnp/dpnp_iface_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,18 +266,53 @@ def matmul(
order="K",
dtype=None,
subok=True,
signature=None,
extobj=None,
axes=None,
axis=None,
):
"""
Matrix product of two arrays.

For full documentation refer to :obj:`numpy.matmul`.

Parameters
----------
x1 : {dpnp_array, usm_ndarray}
First input array.
x2 : {dpnp_array, usm_ndarray}
Second input array.
out : {dpnp.ndarray, usm_ndarray}, optional
Alternative output array in which to place the result. It must have
a shape that matches the signature `(n,k),(k,m)->(n,m)` but the type
(of the calculated values) will be cast if necessary. Default: ``None``.
dtype : dtype, optional
Type to use in computing the matrix product. By default, the returned
array will have a data type that is determined by considering
the type promotion rules and device capabilities.
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
Controls what kind of data casting may occur. Default: ``"same_kind"``.
order : {"C", "F", "A", "K", None}, optional
Memory layout of the newly created output array, if parameter `out` is ``None``.
Default: "K".
axes : list of tuples, optional
A list of tuples with indices of axes the matrix product should operate on.
For instance, for the signature of ``(i,j),(j,k)->(i,k)``, the base elements
are 2d matrices and these are taken to be stored in the two last axes of each
argument. The corresponding axes keyword would be [(-2, -1), (-2, -1), (-2, -1)].
Default: ``None``.

Returns
-------
out : dpnp.ndarray
Returns the matrix product of the inputs.
This is a 0-d array only when both `x1`, `x2` are 1-d vectors.

Limitations
-----------
Input arrays and parameter `out` are supported as either :class:`dpnp.ndarray`
or :class:`dpctl.tensor.usm_ndarray`.
Keyword argument `subok` is currently unsupported.
Input array data types are limited by supported DPNP :ref:`Data types`.
Keyword arguments `subok`, `signature`, `extobj`, and `axis` are
only supported with their default value.
Otherwise ``NotImplementedError`` exception will be raised.

See Also
--------
Expand Down Expand Up @@ -338,6 +373,18 @@ def matmul(
raise NotImplementedError(
"subok keyword argument is only supported by its default value."
)
elif signature is not None:
raise NotImplementedError(
"signature keyword argument is only supported by its default value."
)
elif extobj is not None:
raise NotImplementedError(
"extobj keyword argument is only supported by its default value."
)
elif axis is not None:
raise NotImplementedError(
"axis keyword argument is only supported by its default value."
)
else:
return dpnp_matmul(
x1,
Expand All @@ -346,6 +393,7 @@ def matmul(
casting=casting,
order=order,
dtype=dtype,
axes=axes,
)


Expand Down
147 changes: 129 additions & 18 deletions dpnp/dpnp_utils/dpnp_utils_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import dpctl.tensor as dpt
import dpctl.tensor._tensor_impl as ti
import numpy
from numpy.core.numeric import normalize_axis_tuple

import dpnp
import dpnp.backend.extensions.blas._blas_impl as bi
Expand All @@ -43,7 +44,7 @@ def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue):
If `out` is not ``None`` and its features match the specified `shape`, `dtype,
`usm_type`, and `sycl_queue` and it is C-contiguous or F-contiguous and
does not have any memory overlap with `x1` and `x2`, `out` itself is returned.
If these conditions are not statisfied, an empty array is returned with the
If these conditions are not satisfied, an empty array is returned with the
specified `shape`, `dtype, `usm_type`, and `sycl_queue`.
"""

Expand Down Expand Up @@ -116,21 +117,9 @@ def _gemm_batch_matmul(exec_q, x1, x2, res, x1_is_2D, x2_is_2D, dev_tasks_list):
x2_strides = x2.strides
res_strides = res.strides

# when shape along any particular dimension is 1,
# the stride along that dimension is not a
# meaningful number and is undefined. Here, we
# standardizing strides before continuing,
# setting stride to 0 if the shape along that axis is <=1
if x1_is_2D:
x1_strides = tuple(
str_i if sh_i > 1 else 0
for sh_i, str_i in zip(x1.shape, x1_strides)
)
if x2_is_2D:
x2_strides = tuple(
str_i if sh_i > 1 else 0
for sh_i, str_i in zip(x2.shape, x2_strides)
)
# need to standardize to use in ti._contract_iter2
x1_strides = _standardize_strides(x1_strides, x1_is_2D, x1.shape, x1.ndim)
x2_strides = _standardize_strides(x2_strides, x2_is_2D, x2.shape, x2.ndim)

batch_size = res.shape[:-2][0]
stridea = x1_strides[0]
Expand Down Expand Up @@ -220,6 +209,92 @@ def _op_res_dtype(*arrays, dtype, casting, sycl_queue):
return op_dtype, res_dtype


def _standardize_strides(strides, inherently_2D, shape, ndim):
"""
Standardizing the strides.
vtavana marked this conversation as resolved.
Show resolved Hide resolved

When shape of an array along any particular dimension is 1, the stride
along that dimension is undefined. This functions standardize the strides
in the following way:
For N-D arrays that are inherently 2D (all dimesnsion are one except for two of them),
we use zero as the stride for dimensions equal one.
For other N-D arrays, the non-zero value of strides is calculated and used.

"""

if inherently_2D:
stndrd_strides = tuple(
str_i if sh_i > 1 else 0 for sh_i, str_i in zip(shape, strides)
)
else:
stndrd_strides = [
numpy.prod(shape[i + 1 :]) if strides[i] == 0 else strides[i]
for i in range(ndim - 1)
]
# last dimension
stndrd_strides.append(
1 if strides[ndim - 1] == 0 else strides[ndim - 1]
)
stndrd_strides = tuple(stndrd_strides)

return stndrd_strides


def _validate_axes(x1, x2, axes):
"""Check axes is valid for matmul function."""

def _validate_internal(axes, i, ndim):
if ndim == 1:
iter = 1
if isinstance(axes, int):
axes = (axes,)
elif not isinstance(axes, tuple):
raise TypeError(
f"Axes item {i}: {type(axes)} object cannot be interpreted as an integer."
)

if len(axes) != 1:
raise ValueError(
f"Axes item {i} should be a tuple with a single element, or an integer."
)
else:
iter = 2
if not isinstance(axes, tuple):
raise TypeError(f"Axes item {i} should be a tuple.")
if len(axes) != 2:
raise ValueError(
f"Axes item {i} should be a tuple with 2 elements."
)

for j in range(iter):
if not isinstance(axes[j], int):
raise TypeError(
f"Axes item {i}: {type(axes[j])} object cannot be interpreted as an integer."
)
return axes

if not isinstance(axes, list):
raise TypeError("Axes should be a list.")
else:
if len(axes) != 3:
raise ValueError(
"Axes should be a list of three tuples for inputs and output."
)

axes[0] = _validate_internal(axes[0], 0, x1.ndim)
axes[1] = _validate_internal(axes[1], 1, x2.ndim)

if x1.ndim == 1 and x2.ndim == 1:
if axes[2] != ():
raise TypeError("Axes item 2 should be an empty tuple.")
elif x1.ndim == 1 or x2.ndim == 1:
axes[2] = _validate_internal(axes[2], 2, 1)
else:
axes[2] = _validate_internal(axes[2], 2, 2)

return axes


def dpnp_dot(a, b, /, out=None, *, conjugate=False):
"""
Return the dot product of two arrays.
Expand Down Expand Up @@ -302,6 +377,7 @@ def dpnp_matmul(
casting="same_kind",
order="K",
dtype=None,
axes=None,
):
"""
Return the matrix product of two arrays.
Expand All @@ -327,6 +403,22 @@ def dpnp_matmul(

res_usm_type, exec_q = get_usm_allocations([x1, x2])

if axes is not None:
axes = _validate_axes(x1, x2, axes)

axes_x1, axes_x2, axes_res = axes
axes_x1 = normalize_axis_tuple(axes_x1, x1.ndim, "axis")
axes_x2 = normalize_axis_tuple(axes_x2, x2.ndim, "axis")
# Move the axes that are going to be used in matrix product,
# to the end of "x1" and "x2"
x1 = dpnp.moveaxis(x1, axes_x1, (-2, -1)) if x1.ndim != 1 else x1
x2 = dpnp.moveaxis(x2, axes_x2, (-2, -1)) if x2.ndim != 1 else x2
out_orig = out
if out is not None:
dpnp.check_supported_arrays_type(out)
# out that is passed to the backend should have the correct shape
out = dpnp.moveaxis(out, axes_res, (-2, -1))
vtavana marked this conversation as resolved.
Show resolved Hide resolved

appended_axes = []
if x1_ndim == 1:
x1 = x1[dpnp.newaxis, :]
Expand Down Expand Up @@ -397,9 +489,15 @@ def dpnp_matmul(
x2_shape = x2.shape
res_shape = tuple(tmp_shape) + (x1_shape[-2], x2_shape[-1])

# handling a special case to provide a similar result to NumPy
if out is not None and x1.shape == (1, 0) and x2.shape == (0, 1):
res_shape = (0,)
appended_axes = []

result = _create_result_array(
x1, x2, out, res_shape, gemm_dtype, res_usm_type, exec_q
)

# calculate result
if result.size == 0:
pass
Expand Down Expand Up @@ -471,12 +569,25 @@ def dpnp_matmul(

if gemm_dtype != res_dtype:
result = dpnp.astype(result, res_dtype, copy=False)

if out is None:
if axes is not None:
# Move the result to the appropriate axes of out array
if len(axes_res) == 2:
result = dpnp.moveaxis(result, (-2, -1), axes_res)
elif len(axes_res) == 1:
result = dpnp.moveaxis(result, (-1,), axes_res)
return result
# If `order` was not passed as default
# we need to update it to match the passed `order`.
if order not in ["k", "K"]:
elif order not in ["k", "K"]:
return dpnp.array(result, copy=False, order=order)
else:
return result
else:
return dpnp.get_result_array(result, out, casting=casting)
result = dpnp.get_result_array(result, out, casting=casting)
if axes is not None:
if out is result:
# out and out_orig contain the same data but they have different shape
return out_orig
vtavana marked this conversation as resolved.
Show resolved Hide resolved
return result
Loading
Loading