Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for axes keyword to dpnp.matmul #1705

Merged
merged 10 commits into from
Feb 16, 2024
6 changes: 3 additions & 3 deletions dpnp/backend/extensions/blas/gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ namespace mkl_blas = oneapi::mkl::blas;
namespace py = pybind11;
namespace type_utils = dpctl::tensor::type_utils;

typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue,
typedef sycl::event (*gemm_impl_fn_ptr_t)(sycl::queue &,
oneapi::mkl::transpose,
oneapi::mkl::transpose,
const std::int64_t,
Expand All @@ -64,7 +64,7 @@ static gemm_impl_fn_ptr_t gemm_dispatch_table[dpctl_td_ns::num_types]
[dpctl_td_ns::num_types];

template <typename Tab, typename Tc>
static sycl::event gemm_impl(sycl::queue exec_q,
static sycl::event gemm_impl(sycl::queue &exec_q,
oneapi::mkl::transpose transA,
oneapi::mkl::transpose transB,
const std::int64_t m,
Expand Down Expand Up @@ -130,7 +130,7 @@ static sycl::event gemm_impl(sycl::queue exec_q,
}

std::pair<sycl::event, sycl::event>
gemm(sycl::queue exec_q,
gemm(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
dpctl::tensor::usm_ndarray resultC,
Expand Down
4 changes: 2 additions & 2 deletions dpnp/backend/extensions/blas/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@ namespace ext
namespace blas
{
extern std::pair<sycl::event, sycl::event>
gemm(sycl::queue exec_q,
gemm(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
dpctl::tensor::usm_ndarray resultC,
const std::vector<sycl::event> &depends);

extern std::pair<sycl::event, sycl::event>
gemm_batch(sycl::queue exec_q,
gemm_batch(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
dpctl::tensor::usm_ndarray resultC,
Expand Down
6 changes: 3 additions & 3 deletions dpnp/backend/extensions/blas/gemm_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ namespace py = pybind11;
namespace type_utils = dpctl::tensor::type_utils;

typedef sycl::event (*gemm_batch_impl_fn_ptr_t)(
sycl::queue,
sycl::queue &,
const std::int64_t,
const std::int64_t,
const std::int64_t,
Expand All @@ -69,7 +69,7 @@ static gemm_batch_impl_fn_ptr_t
gemm_batch_dispatch_table[dpctl_td_ns::num_types][dpctl_td_ns::num_types];

template <typename Tab, typename Tc>
static sycl::event gemm_batch_impl(sycl::queue exec_q,
static sycl::event gemm_batch_impl(sycl::queue &exec_q,
const std::int64_t m,
const std::int64_t n,
const std::int64_t k,
Expand Down Expand Up @@ -145,7 +145,7 @@ static sycl::event gemm_batch_impl(sycl::queue exec_q,
}

std::pair<sycl::event, sycl::event>
gemm_batch(sycl::queue exec_q,
gemm_batch(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray matrixA,
dpctl::tensor::usm_ndarray matrixB,
dpctl::tensor::usm_ndarray resultC,
Expand Down
97 changes: 92 additions & 5 deletions dpnp/dpnp_iface_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,18 +266,53 @@ def matmul(
order="K",
dtype=None,
subok=True,
signature=None,
extobj=None,
axes=None,
axis=None,
):
"""
Matrix product of two arrays.

For full documentation refer to :obj:`numpy.matmul`.

Parameters
----------
x1 : {dpnp_array, usm_ndarray}
First input array.
x2 : {dpnp_array, usm_ndarray}
Second input array.
out : {dpnp.ndarray, usm_ndarray}, optional
Alternative output array in which to place the result. It must have
a shape that matches the signature `(n,k),(k,m)->(n,m)` but the type
(of the calculated values) will be cast if necessary. Default: ``None``.
dtype : dtype, optional
Type to use in computing the matrix product. By default, the returned
array will have data type that is determined by considering
Promotion Type Rule and device capabilities.
casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
Controls what kind of data casting may occur. Default: ``"same_kind"``.
order : {"C", "F", "A", "K", None}, optional
Memory layout of the newly output array, if parameter `out` is ``None``.
Default: "K".
axes : list of tuples, optional
A list of tuples with indices of axes the matrix product should operate on.
For instance, for the signature of ``(i,j),(j,k)->(i,k)``, the base elements
are 2d matrices and these are taken to be stored in the two last axes of each
argument. The corresponding axes keyword would be [(-2, -1), (-2, -1), (-2, -1)].
Default: ``None``.

Returns
-------
out : dpnp.ndarray
Returns the matrix product of the inputs.
This is a 0-d array only when both `x1`, `x2` are 1-d vectors.

Limitations
-----------
Input arrays and parameter `out` are supported as either :class:`dpnp.ndarray`
or :class:`dpctl.tensor.usm_ndarray`.
Keyword argument `subok` is currently unsupported.
Input array data types are limited by supported DPNP :ref:`Data types`.
Keyword arguments `subok`, `signature`, `extobj`, and `axis` are
only supported with their default value.
Otherwise ``NotImplementedError`` exception will be raised.

See Also
--------
Expand Down Expand Up @@ -338,15 +373,67 @@ def matmul(
raise NotImplementedError(
"subok keyword argument is only supported by its default value."
)
elif signature is not None:
raise NotImplementedError(
"signature keyword argument is only supported by its default value."
)
elif extobj is not None:
raise NotImplementedError(
"extobj keyword argument is only supported by its default value."
)
elif axis is not None:
raise NotImplementedError(
"axis keyword argument is only supported by its default value."
)
else:
return dpnp_matmul(
if axes is not None:
if not isinstance(axes, list):
raise TypeError("Axes should be a list.")
else:
if len(axes) != 3:
raise ValueError(
"Axes should be a list of three tuples for inputs and output."
)

for i in range(3):
if not isinstance(axes[i], tuple):
raise TypeError(f"Axes item {i} should be a tuple.")
vtavana marked this conversation as resolved.
Show resolved Hide resolved
if len(axes[i]) != 2:
raise ValueError(
f"Axes item {i} should be a tuple with 2 elements."
)

for j in range(2):
if not isinstance(axes[i][j], int):
raise TypeError("Axes must be an integer.")

axes_x1, axes_x2, axes_res = axes
# Move the axes that are going to be used in matrix product,
# to the end of "x1" and "x2"
x1 = dpnp.moveaxis(x1, axes_x1, (-2, -1))
x2 = dpnp.moveaxis(x2, axes_x2, (-2, -1))
out_orig = out
if out is not None:
dpnp.check_supported_arrays_type(x1, x2)
vtavana marked this conversation as resolved.
Show resolved Hide resolved
# out that is passed to the backend should have the correct shape
out = dpnp.moveaxis(out, axes_res, (-2, -1))

result = dpnp_matmul(
x1,
x2,
out=out,
casting=casting,
order=order,
dtype=dtype,
)
if axes is not None:
if out is result:
# out and out_orig contain the same data but they have different shape
return out_orig
# Move the result to the appropriate axes of out array
result = dpnp.moveaxis(result, (-2, -1), axes_res)

return result


def outer(x1, x2, out=None):
Expand Down
49 changes: 34 additions & 15 deletions dpnp/dpnp_utils/dpnp_utils_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,21 +116,9 @@ def _gemm_batch_matmul(exec_q, x1, x2, res, x1_is_2D, x2_is_2D, dev_tasks_list):
x2_strides = x2.strides
res_strides = res.strides

# when shape along any particular dimension is 1,
# the stride along that dimension is not a
# meaningful number and is undefined. Here, we
# standardizing strides before continuing,
# setting stride to 0 if the shape along that axis is <=1
if x1_is_2D:
x1_strides = tuple(
str_i if sh_i > 1 else 0
for sh_i, str_i in zip(x1.shape, x1_strides)
)
if x2_is_2D:
x2_strides = tuple(
str_i if sh_i > 1 else 0
for sh_i, str_i in zip(x2.shape, x2_strides)
)
# need to standardize to use in ti._contract_iter2
x1_strides = _standardize_strides(x1_strides, x1_is_2D, x1.shape, x1.ndim)
x2_strides = _standardize_strides(x2_strides, x2_is_2D, x2.shape, x2.ndim)

batch_size = res.shape[:-2][0]
stridea = x1_strides[0]
Expand Down Expand Up @@ -220,6 +208,37 @@ def _op_res_dtype(*arrays, dtype, casting, sycl_queue):
return op_dtype, res_dtype


def _standardize_strides(strides, inherently_2D, shape, ndim):
"""
Standardizing the strides.
vtavana marked this conversation as resolved.
Show resolved Hide resolved

When shape of an array along any particular dimension is 1, the stride
along that dimension is undefined. This functions standardize the strides
in the following way:
For N-D arrays that are inherently 2D (all dimesnsion are one except for two of them),
we use zero as the stride for dimensions equal one.
For other N-D arrays, the non-zero value of strides is calculated and used.

"""

if inherently_2D:
stndrd_strides = tuple(
str_i if sh_i > 1 else 0 for sh_i, str_i in zip(shape, strides)
)
else:
stndrd_strides = [
numpy.prod(shape[i + 1 :]) if strides[i] == 0 else strides[i]
for i in range(ndim - 1)
]
# last dimension
stndrd_strides.append(
1 if strides[ndim - 1] == 0 else strides[ndim - 1]
)
stndrd_strides = tuple(stndrd_strides)

return stndrd_strides


def dpnp_dot(a, b, /, out=None, *, conjugate=False):
"""
Return the dot product of two arrays.
Expand Down
96 changes: 95 additions & 1 deletion tests/test_mathematical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2634,6 +2634,57 @@ def test_matmul_dtype(self, dtype, shape_pair):
expected = numpy.matmul(a1, a2, dtype=dtype)
assert_dtype_allclose(result, expected)

@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
@pytest.mark.parametrize(
"axes",
[
[(-3, -1), (0, 2), (-2, -3)],
[(3, 1), (2, 0), (3, 1)],
[(3, 1), (2, 0), (0, 1)],
],
)
def test_matmul_axes(self, dtype, axes):
a = numpy.array(
numpy.random.uniform(-10, 10, 120), dtype=dtype
).reshape(2, 5, 3, 4)
b = numpy.array(
numpy.random.uniform(-10, 10, 120), dtype=dtype
).reshape(4, 2, 5, 3)
ia = dpnp.array(a)
ib = dpnp.array(b)

result = dpnp.matmul(ia, ib, axes=axes)
print(result.shape)
expected = numpy.matmul(a, b, axes=axes)
# TODO: investigate the effect of factor, see SAT-6700
assert_dtype_allclose(result, expected, factor=24)

@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
@pytest.mark.parametrize(
"axes, out_shape",
[
([(-3, -1), (0, 2), (-2, -3)], (2, 5, 5, 3)),
([(3, 1), (2, 0), (3, 1)], (2, 4, 3, 4)),
([(3, 1), (2, 0), (1, 2)], (2, 4, 4, 3)),
],
)
def test_matmul_axes_out(self, dtype, axes, out_shape):
a = numpy.array(
numpy.random.uniform(-10, 10, 120), dtype=dtype
).reshape(2, 5, 3, 4)
b = numpy.array(
numpy.random.uniform(-10, 10, 120), dtype=dtype
).reshape(4, 2, 5, 3)
ia = dpnp.array(a)
ib = dpnp.array(b)

out_dp = dpnp.empty(out_shape, dtype=dtype)
result = dpnp.matmul(ia, ib, axes=axes, out=out_dp)
assert result is out_dp
expected = numpy.matmul(a, b, axes=axes)
# TODO: investigate the effect of factor, see SAT-6700
assert_dtype_allclose(result, expected, factor=24)

@pytest.mark.parametrize("dtype1", get_all_dtypes(no_bool=True))
@pytest.mark.parametrize(
"dtype2", get_all_dtypes(no_bool=True, no_none=True)
Expand Down Expand Up @@ -2822,9 +2873,52 @@ def test_matmul_casting(self):
with pytest.raises(TypeError):
dpnp.matmul(a1, a2, out=res, casting="safe")

def test_matmul_subok(self):
def test_matmul_not_implemented(self):
a1 = dpnp.arange(2 * 4).reshape(2, 4)
a2 = dpnp.arange(4 * 3).reshape(4, 3)

with pytest.raises(NotImplementedError):
dpnp.matmul(a1, a2, subok=False)

with pytest.raises(NotImplementedError):
dpnp.matmul(
a1, a2, signature=(dpnp.float32, dpnp.float32, dpnp.float32)
)

def custom_error_callback(err):
print("Custom error callback triggered with error:", err)

with pytest.raises(NotImplementedError):
dpnp.matmul(a1, a2, extobj=[32, 1, custom_error_callback])

with pytest.raises(NotImplementedError):
dpnp.matmul(a1, a2, axis=2)

def test_matmul_axes(self):
a1 = dpnp.arange(120).reshape(2, 5, 3, 4)
a2 = dpnp.arange(120).reshape(4, 2, 5, 3)

# axes must be a list
axes = ((3, 1), (2, 0), (0, 1))
with pytest.raises(TypeError):
dpnp.matmul(a1, a2, axes=axes)

# axes must be be a list of three tuples
axes = [(3, 1), (2, 0)]
with pytest.raises(ValueError):
dpnp.matmul(a1, a2, axes=axes)

# axes items should be a tuple
axes = [(3, 1), (2, 0), [0, 1]]
with pytest.raises(TypeError):
dpnp.matmul(a1, a2, axes=axes)

# axes items should be a tuple with 2 elements
axes = [(3, 1), (2, 0), (0, 1, 2)]
with pytest.raises(ValueError):
dpnp.matmul(a1, a2, axes=axes)

# axes must be an integer
axes = [(3, 1), (2, 0), (0.0, 1)]
with pytest.raises(TypeError):
dpnp.matmul(a1, a2, axes=axes)
Loading
Loading