add support for axes keyword to dpnp.matmul #1705

Merged
merged 10 commits on Feb 16, 2024
6 changes: 3 additions & 3 deletions dpnp/backend/extensions/blas/gemm.cpp
@@ -46,7 +46,7 @@ namespace mkl_blas = oneapi::mkl::blas;
namespace py = pybind11;
namespace type_utils = dpctl::tensor::type_utils;

typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue,
typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue &,
oneapi::mkl::transpose,
oneapi::mkl::transpose,
const std::int64_t,
@@ -64,7 +64,7 @@ static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types]
[dpctl_td_ns::num_types];

template <typename Tab, typename Tc>
static sycl::event gemm_impl(sycl::queue exec_q,
static sycl::event gemm_impl(sycl::queue &exec_q,
oneapi::mkl::transpose transA,
oneapi::mkl::transpose transB,
const std::int64_t m,
@@ -130,7 +130,7 @@ static sycl::event gemm_impl(sycl::queue exec_q,
}

std::pair<sycl::event, sycl::event>
gemm(sycl::queue exec_q,
gemm(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
dpctl::tensor::usm_ndarray resultC,
4 changes: 2 additions & 2 deletions dpnp/backend/extensions/blas/gemm.hpp
@@ -39,14 +39,14 @@ namespace ext
namespace blas
{
extern std::pair<sycl::event, sycl::event>
gemm(sycl::queue exec_q,
gemm(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
dpctl::tensor::usm_ndarray resultC,
const std::vector<sycl::event> &depends);

extern std::pair<sycl::event, sycl::event>
gemm_batch(sycl::queue exec_q,
gemm_batch(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
dpctl::tensor::usm_ndarray resultC,
6 changes: 3 additions & 3 deletions dpnp/backend/extensions/blas/gemm_batch.cpp
@@ -47,7 +47,7 @@ namespace py = pybind11;
namespace type_utils = dpctl::tensor::type_utils;

typedef sycl::event (*gemm_batch_impl_fn_ptr_t)(
sycl::queue,
sycl::queue &,
const std::int64_t,
const std::int64_t,
const std::int64_t,
@@ -69,7 +69,7 @@ static gemm_batch_impl_fn_ptr_t
gemm_batch_dispatch_table[dpctl_td_ns::num_types][dpctl_td_ns::num_types];

template <typename Tab, typename Tc>
static sycl::event gemm_batch_impl(sycl::queue exec_q,
static sycl::event gemm_batch_impl(sycl::queue &exec_q,
const std::int64_t m,
const std::int64_t n,
const std::int64_t k,
@@ -145,7 +145,7 @@ static sycl::event gemm_batch_impl(sycl::queue exec_q,
}

std::pair<sycl::event, sycl::event>
gemm_batch(sycl::queue exec_q,
gemm_batch(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
dpctl::tensor::usm_ndarray resultC,
56 changes: 52 additions & 4 deletions dpnp/dpnp_iface_linearalgebra.py
@@ -266,18 +266,53 @@ def matmul(
order="K",
dtype=None,
subok=True,
signature=None,
extobj=None,
axes=None,
axis=None,
):
"""
Matrix product of two arrays.

For full documentation refer to :obj:`numpy.matmul`.

Parameters
----------
x1 : {dpnp_array, usm_ndarray}
First input array.
x2 : {dpnp_array, usm_ndarray}
Second input array.
out : {dpnp.ndarray, usm_ndarray}, optional
Alternative output array in which to place the result. It must have
a shape that matches the signature `(n,k),(k,m)->(n,m)` but the type
(of the calculated values) will be cast if necessary. Default: ``None``.
dtype : dtype, optional
Type to use in computing the matrix product. By default, the returned
array will have a data type that is determined by considering the
Promotion Type Rule and device capabilities.
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
Controls what kind of data casting may occur. Default: ``"same_kind"``.
order : {"C", "F", "A", "K", None}, optional
Memory layout of the newly created output array, if parameter `out` is ``None``.
Default: ``"K"``.
axes : list of tuples, optional
A list of tuples with indices of axes the matrix product should operate on.
For instance, for the signature of ``(i,j),(j,k)->(i,k)``, the base elements
are 2-D matrices and these are taken to be stored in the last two axes of each
argument. The corresponding `axes` keyword would be ``[(-2, -1), (-2, -1), (-2, -1)]``.
Default: ``None``.

Returns
-------
out : dpnp.ndarray
Returns the matrix product of the inputs.
This is a 0-d array only when both `x1` and `x2` are 1-d vectors.

Limitations
-----------
Input arrays and parameter `out` are supported as either :class:`dpnp.ndarray`
or :class:`dpctl.tensor.usm_ndarray`.
Keyword argument `subok` is currently unsupported.
Input array data types are limited by supported DPNP :ref:`Data types`.
Keyword arguments `subok`, `signature`, `extobj`, and `axis` are
only supported with their default values.
Otherwise, a ``NotImplementedError`` exception will be raised.

See Also
--------
@@ -338,6 +373,18 @@ def matmul(
raise NotImplementedError(
"subok keyword argument is only supported by its default value."
)
elif signature is not None:
raise NotImplementedError(
"signature keyword argument is only supported by its default value."
)
elif extobj is not None:
raise NotImplementedError(
"extobj keyword argument is only supported by its default value."
)
elif axis is not None:
raise NotImplementedError(
"axis keyword argument is only supported by its default value."
)
else:
return dpnp_matmul(
x1,
@@ -346,6 +393,7 @@
casting=casting,
order=order,
dtype=dtype,
axes=axes,
)
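
The snippet below is a usage sketch, not part of the diff: it exercises the new `axes` keyword added by this PR. The `ones` inputs and shapes are illustrative, and the non-default `axes` call assumes dpnp follows NumPy's gufunc `axes` semantics described in the docstring above.

```python
# Illustrative sketch (not from the PR): exercising dpnp.matmul's new `axes` keyword.
import dpnp

# A stack of four 2x3 matrices times a stack of four 3x5 matrices.
a = dpnp.ones((4, 2, 3))
b = dpnp.ones((4, 3, 5))

# Default behaviour: the matrices live in the last two axes of each operand.
c = dpnp.matmul(a, b)  # shape (4, 2, 5)

# The same call spelled out with `axes`: one tuple per operand plus one for
# the result, here simply the last two axes of each.
c_axes = dpnp.matmul(a, b, axes=[(-2, -1), (-2, -1), (-2, -1)])
assert c.shape == c_axes.shape == (4, 2, 5)

# Matrices stored in the first two axes, with the batch axis kept last
# (assumes NumPy-compatible `axes` semantics).
a2 = dpnp.ones((2, 3, 4))
b2 = dpnp.ones((3, 5, 4))
d = dpnp.matmul(a2, b2, axes=[(0, 1), (0, 1), (0, 1)])  # shape (2, 5, 4)

# Per the validation added in this PR, the other ufunc-style keywords must
# keep their defaults; e.g. passing `axis` raises NotImplementedError.
try:
    dpnp.matmul(a, b, axis=0)
except NotImplementedError:
    pass
```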

