Skip to content

Commit

Permalink
add support for axes keyword to dpnp.matmul (#1705)
Browse files Browse the repository at this point in the history
* update matmul for cupy tests

* address comments

* address more comments

* fix an error

* use a function for error msg

---------

Co-authored-by: Natalia Polina <natalia.polina@intel.com>
Co-authored-by: Anton <100830759+antonwolfy@users.noreply.github.com>
  • Loading branch information
3 people authored Feb 16, 2024
1 parent 2c38ae8 commit 3461932
Show file tree
Hide file tree
Showing 7 changed files with 595 additions and 61 deletions.
6 changes: 3 additions & 3 deletions dpnp/backend/extensions/blas/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ namespace mkl_blas = oneapi::mkl::blas;
namespace py = pybind11;
namespace type_utils = dpctl::tensor::type_utils;

typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue,
typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue &,
oneapi::mkl::transpose,
oneapi::mkl::transpose,
const std::int64_t,
Expand All @@ -64,7 +64,7 @@ static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types]
[dpctl_td_ns::num_types];

template <typename Tab, typename Tc>
static sycl::event gemm_impl(sycl::queue exec_q,
static sycl::event gemm_impl(sycl::queue &exec_q,
oneapi::mkl::transpose transA,
oneapi::mkl::transpose transB,
const std::int64_t m,
Expand Down Expand Up @@ -130,7 +130,7 @@ static sycl::event gemm_impl(sycl::queue exec_q,
}

std::pair<sycl::event, sycl::event>
gemm(sycl::queue exec_q,
gemm(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
dpctl::tensor::usm_ndarray resultC,
Expand Down
4 changes: 2 additions & 2 deletions dpnp/backend/extensions/blas/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ namespace ext
namespace blas
{
extern std::pair<sycl::event, sycl::event>
gemm(sycl::queue exec_q,
gemm(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
dpctl::tensor::usm_ndarray resultC,
const std::vector<sycl::event> &depends);

extern std::pair<sycl::event, sycl::event>
gemm_batch(sycl::queue exec_q,
gemm_batch(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
dpctl::tensor::usm_ndarray resultC,
Expand Down
6 changes: 3 additions & 3 deletions dpnp/backend/extensions/blas/gemm_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ namespace py = pybind11;
namespace type_utils = dpctl::tensor::type_utils;

typedef sycl::event (*gemm_batch_impl_fn_ptr_t)(
sycl::queue,
sycl::queue &,
const std::int64_t,
const std::int64_t,
const std::int64_t,
Expand All @@ -69,7 +69,7 @@ static gemm_batch_impl_fn_ptr_t
gemm_batch_dispatch_table[dpctl_td_ns::num_types][dpctl_td_ns::num_types];

template <typename Tab, typename Tc>
static sycl::event gemm_batch_impl(sycl::queue exec_q,
static sycl::event gemm_batch_impl(sycl::queue &exec_q,
const std::int64_t m,
const std::int64_t n,
const std::int64_t k,
Expand Down Expand Up @@ -145,7 +145,7 @@ static sycl::event gemm_batch_impl(sycl::queue exec_q,
}

std::pair<sycl::event, sycl::event>
gemm_batch(sycl::queue exec_q,
gemm_batch(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
dpctl::tensor::usm_ndarray resultC,
Expand Down
56 changes: 52 additions & 4 deletions dpnp/dpnp_iface_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,18 +266,53 @@ def matmul(
order="K",
dtype=None,
subok=True,
signature=None,
extobj=None,
axes=None,
axis=None,
):
"""
Matrix product of two arrays.
For full documentation refer to :obj:`numpy.matmul`.
Parameters
----------
x1 : {dpnp_array, usm_ndarray}
First input array.
x2 : {dpnp_array, usm_ndarray}
Second input array.
out : {dpnp.ndarray, usm_ndarray}, optional
Alternative output array in which to place the result. It must have
a shape that matches the signature `(n,k),(k,m)->(n,m)` but the type
(of the calculated values) will be cast if necessary. Default: ``None``.
dtype : dtype, optional
Type to use in computing the matrix product. By default, the returned
array will have data type that is determined by considering
Promotion Type Rule and device capabilities.
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
Controls what kind of data casting may occur. Default: ``"same_kind"``.
order : {"C", "F", "A", "K", None}, optional
Memory layout of the newly output array, if parameter `out` is ``None``.
Default: "K".
axes : list of tuples, optional
A list of tuples with indices of axes the matrix product should operate on.
For instance, for the signature of ``(i,j),(j,k)->(i,k)``, the base elements
are 2d matrices and these are taken to be stored in the two last axes of each
argument. The corresponding axes keyword would be [(-2, -1), (-2, -1), (-2, -1)].
Default: ``None``.
Returns
-------
out : dpnp.ndarray
Returns the matrix product of the inputs.
This is a 0-d array only when both `x1`, `x2` are 1-d vectors.
Limitations
-----------
Input arrays and parameter `out` are supported as either :class:`dpnp.ndarray`
or :class:`dpctl.tensor.usm_ndarray`.
Keyword argument `subok` is currently unsupported.
Input array data types are limited by supported DPNP :ref:`Data types`.
Keyword arguments `subok`, `signature`, `extobj`, and `axis` are
only supported with their default value.
Otherwise ``NotImplementedError`` exception will be raised.
See Also
--------
Expand Down Expand Up @@ -338,6 +373,18 @@ def matmul(
raise NotImplementedError(
"subok keyword argument is only supported by its default value."
)
elif signature is not None:
raise NotImplementedError(
"signature keyword argument is only supported by its default value."
)
elif extobj is not None:
raise NotImplementedError(
"extobj keyword argument is only supported by its default value."
)
elif axis is not None:
raise NotImplementedError(
"axis keyword argument is only supported by its default value."
)
else:
return dpnp_matmul(
x1,
Expand All @@ -346,6 +393,7 @@ def matmul(
casting=casting,
order=order,
dtype=dtype,
axes=axes,
)


Expand Down
Loading

0 comments on commit 3461932

Please sign in to comment.