Skip to content

Commit

Permalink
Add parameter out in dpnp.dot() (#1327)
Browse files Browse the repository at this point in the history
  • Loading branch information
antonwolfy authored Mar 5, 2023
1 parent 648612d commit 2224ce2
Show file tree
Hide file tree
Showing 9 changed files with 113 additions and 72 deletions.
35 changes: 24 additions & 11 deletions dpnp/dpnp_algo/dpnp_algo_linearalgebra.pyx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# cython: language_level=3
# -*- coding: utf-8 -*-
# *****************************************************************************
# Copyright (c) 2016-2020, Intel Corporation
# Copyright (c) 2016-2023, Intel Corporation
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -65,8 +65,9 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_2in_1out_matmul_t)(c_dpctl.DPCTLSyclQue
const shape_elem_type *, const shape_elem_type * ,
const c_dpctl.DPCTLEventVectorRef)

cpdef utils.dpnp_descriptor dpnp_dot(utils.dpnp_descriptor in_array1, utils.dpnp_descriptor in_array2):

cpdef utils.dpnp_descriptor dpnp_dot(utils.dpnp_descriptor in_array1,
utils.dpnp_descriptor in_array2,
utils.dpnp_descriptor out=None):
cdef shape_type_c shape1, shape2

shape1 = in_array1.shape
Expand All @@ -78,6 +79,7 @@ cpdef utils.dpnp_descriptor dpnp_dot(utils.dpnp_descriptor in_array1, utils.dpnp

# get the FPTR data structure
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_DOT_EXT, param1_type, param2_type)
cdef utils.dpnp_descriptor result

ndim1 = in_array1.ndim
ndim2 = in_array2.ndim
Expand All @@ -89,7 +91,7 @@ cpdef utils.dpnp_descriptor dpnp_dot(utils.dpnp_descriptor in_array1, utils.dpnp
elif ndim1 == 1 and ndim2 == 1:
result_shape = ()
elif ndim1 == 1: # ndim2 > 1
result_shape = shape2[:-1]
result_shape = shape2[::-2] if ndim2 == 2 else shape2[::2]
elif ndim2 == 1: # ndim1 > 1
result_shape = shape1[:-1]
else:
Expand All @@ -101,13 +103,24 @@ cpdef utils.dpnp_descriptor dpnp_dot(utils.dpnp_descriptor in_array1, utils.dpnp

result_sycl_device, result_usm_type, result_sycl_queue = utils.get_common_usm_allocation(in_array1, in_array2)

# create result array with type given by FPTR data
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape,
kernel_data.return_type,
None,
device=result_sycl_device,
usm_type=result_usm_type,
sycl_queue=result_sycl_queue)
if out is None:
# create result array with type given by FPTR data
result = utils.create_output_descriptor(result_shape,
kernel_data.return_type,
None,
device=result_sycl_device,
usm_type=result_usm_type,
sycl_queue=result_sycl_queue)
else:
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
if out.dtype != result_type:
utils.checker_throw_value_error('dot', 'out.dtype', out.dtype, result_type)
if out.shape != result_shape:
utils.checker_throw_value_error('dot', 'out.shape', out.shape, result_shape)

result = out

utils.get_common_usm_allocation(in_array1, result) # check USM allocation is common

cdef shape_type_c result_strides = utils.strides_to_vector(result.strides, result.shape)
cdef shape_type_c in_array1_shape = in_array1.shape
Expand Down
3 changes: 2 additions & 1 deletion dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,8 @@ def diagonal(input, offset=0, axis1=0, axis2=1):

return dpnp.diagonal(input, offset, axis1, axis2)

# 'dot',
def dot(self, other, out=None):
return dpnp.dot(self, other, out)

@property
def dtype(self):
Expand Down
67 changes: 40 additions & 27 deletions dpnp/dpnp_iface_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@
from dpnp.dpnp_algo import *
from dpnp.dpnp_utils import *
import dpnp
import dpnp.config as config

import numpy
import dpctl.tensor as dpt


__all__ = [
Expand All @@ -62,18 +62,25 @@
]


def dot(x1, x2, **kwargs):
def dot(x1, x2, out=None, **kwargs):
"""
Returns the dot product of `x1` and `x2`.
Dot product of `x1` and `x2`.
For full documentation refer to :obj:`numpy.dot`.
Returns
-------
y : dpnp.ndarray
Returns the dot product of `x1` and `x2`.
If `out` is given, then it is returned.
Limitations
-----------
Parameters ``x1`` and ``x2`` are supported as :obj:`dpnp.ndarray` of the same type.
Keyword arguments ``kwargs`` are currently unsupported.
Otherwise the functions will be executed sequentially on CPU.
Input array data types are limited by supported DPNP :ref:`Data types`.
Parameters `x1` and `x2` are supported as either scalar, :class:`dpnp.ndarray`
or :class:`dpctl.tensor.usm_ndarray`, but both `x1` and `x2` can not be scalars at the same time.
Keyword argument ``kwargs`` is currently unsupported.
Otherwise the functions will be executed sequentially on CPU.
Input array data types are limited by supported DPNP :ref:`Data types`.
See Also
--------
Expand All @@ -82,31 +89,37 @@ def dot(x1, x2, **kwargs):
Examples
--------
>>> import dpnp as np
>>> np.dot(3, 4)
12
>>> a = np.array([1, 2, 3])
>>> b = np.array([1, 2, 3])
>>> np.dot(a, b)
>>> import dpnp as dp
>>> a = dp.array([1, 2, 3])
>>> b = dp.array([1, 2, 3])
>>> dp.dot(a, b)
14
"""

x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False)
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False)
if x1_desc and x2_desc and not kwargs:
# TODO: remove fallback with scalars when muliply backend func will support strides
if(x1_desc.ndim == 0 and x2_desc.strides is not None
or x2_desc.ndim == 0 and x1_desc.strides is not None):
pass
elif (x1_desc.ndim >= 1 and x2_desc.ndim > 1 and x1_desc.shape[-1] != x2_desc.shape[-2]):
pass
elif (x1_desc.ndim > 0 and x2_desc.ndim == 1 and x1_desc.shape[-1] != x2_desc.shape[0]):
pass
else:
return dpnp_dot(x1_desc, x2_desc).get_pyobj()
if kwargs:
pass
elif dpnp.isscalar(x1) and dpnp.isscalar(x2):
# at least either x1 or x2 has to be an array
pass
else:
# get USM type and queue to copy scalar from the host memory into a USM allocation
usm_type, queue = get_usm_allocations([x1, x2]) if dpnp.isscalar(x1) or dpnp.isscalar(x2) else (None, None)

x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_strides=False, copy_when_nondefault_queue=False,
alloc_usm_type=usm_type, alloc_queue=queue)
x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_strides=False, copy_when_nondefault_queue=False,
alloc_usm_type=usm_type, alloc_queue=queue)
if x1_desc and x2_desc:
if out is not None:
if not isinstance(out, (dpnp.ndarray, dpt.usm_ndarray)):
raise TypeError("return array must be of supported array type")
out_desc = dpnp.get_dpnp_descriptor(out, copy_when_nondefault_queue=False)
else:
out_desc = None
return dpnp_dot(x1_desc, x2_desc, out=out_desc).get_pyobj()

return call_origin(numpy.dot, x1, x2, **kwargs)
return call_origin(numpy.dot, x1, x2, out=out, **kwargs)


def einsum(*args, **kwargs):
Expand Down
4 changes: 0 additions & 4 deletions tests/skipped_tests.tbl
Original file line number Diff line number Diff line change
Expand Up @@ -610,10 +610,6 @@ tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumLarge_param_9_{opt
tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumUnaryOperationWithScalar::test_scalar_float
tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumUnaryOperationWithScalar::test_scalar_int
tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test_invalid_sub1
tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_64_{shape=((2,), (2, 4)), trans_a=True, trans_b=True}::test_dot
tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_65_{shape=((2,), (2, 4)), trans_a=True, trans_b=False}::test_dot
tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_66_{shape=((2,), (2, 4)), trans_a=False, trans_b=True}::test_dot
tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_67_{shape=((2,), (2, 4)), trans_a=False, trans_b=False}::test_dot
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_invlarge
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_large
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_of_two
Expand Down
5 changes: 0 additions & 5 deletions tests/skipped_tests_gpu.tbl
Original file line number Diff line number Diff line change
Expand Up @@ -812,10 +812,6 @@ tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumUnaryOperationWith
tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumUnaryOperationWithScalar::test_scalar_int
tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test_invalid_sub1
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_dot
tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_64_{shape=((2,), (2, 4)), trans_a=True, trans_b=True}::test_dot
tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_65_{shape=((2,), (2, 4)), trans_a=True, trans_b=False}::test_dot
tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_66_{shape=((2,), (2, 4)), trans_a=False, trans_b=True}::test_dot
tests/third_party/cupy/linalg_tests/test_product.py::TestDot_param_67_{shape=((2,), (2, 4)), trans_a=False, trans_b=False}::test_dot
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_invlarge
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_large
tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_of_two
Expand All @@ -827,7 +823,6 @@ tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transpose
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot_with_int_axes
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot_with_list_axes
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_tensordot_zero_dim
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_dot_with_out_f_contiguous
tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_multidim_vdot
Expand Down
23 changes: 11 additions & 12 deletions tests/test_dot.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import pytest
from .helper import get_all_dtypes

import dpnp as inp

import numpy
from numpy.testing import (
assert_allclose,
assert_array_equal
)


@pytest.mark.parametrize("type",
[numpy.float64, numpy.float32, numpy.int64, numpy.int32],
ids=['float64', 'float32', 'int64', 'int32'])
@pytest.mark.parametrize("type", get_all_dtypes(no_bool=True, no_complex=True))
def test_dot_ones(type):
n = 10**5
a = numpy.ones(n, dtype=type)
Expand All @@ -17,12 +20,10 @@ def test_dot_ones(type):

result = inp.dot(ia, ib)
expected = numpy.dot(a, b)
numpy.testing.assert_array_equal(expected, result)
assert_array_equal(expected, result)


@pytest.mark.parametrize("type",
[numpy.float64, numpy.float32, numpy.int64, numpy.int32],
ids=['float64', 'float32', 'int64', 'int32'])
@pytest.mark.parametrize("type", get_all_dtypes(no_bool=True, no_complex=True))
def test_dot_arange(type):
n = 10**2
m = 10**3
Expand All @@ -33,12 +34,10 @@ def test_dot_arange(type):

result = inp.dot(ia, ib)
expected = numpy.dot(a, b)
numpy.testing.assert_allclose(expected, result)
assert_allclose(expected, result)


@pytest.mark.parametrize("type",
[numpy.float64, numpy.float32, numpy.int64, numpy.int32],
ids=['float64', 'float32', 'int64', 'int32'])
@pytest.mark.parametrize("type", get_all_dtypes(no_bool=True, no_complex=True))
def test_multi_dot(type):
n = 16
a = inp.reshape(inp.arange(n, dtype=type), (4, 4))
Expand All @@ -53,4 +52,4 @@ def test_multi_dot(type):

result = inp.linalg.multi_dot([a, b, c, d])
expected = numpy.linalg.multi_dot([a1, b1, c1, d1])
numpy.testing.assert_array_equal(expected, result)
assert_array_equal(expected, result)
28 changes: 17 additions & 11 deletions tests/test_sycl_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def test_1in_1out(func, data, device):
x = dpnp.array(data, device=device)
result = getattr(dpnp, func)(x)

numpy.testing.assert_array_equal(result, expected)
assert_array_equal(result, expected)

expected_queue = x.get_array().sycl_queue
result_queue = result.get_array().sycl_queue
Expand All @@ -320,6 +320,9 @@ def test_1in_1out(func, data, device):
pytest.param("divide",
[0., 1., 2., 3., 4.],
[4., 4., 4., 4., 4.]),
pytest.param("dot",
[[0., 1., 2.], [3., 4., 5.]],
[[4., 4.], [4., 4.], [4., 4.]]),
pytest.param("floor_divide",
[1., 2., 3., 4.],
[2.5, 2.5, 2.5, 2.5]),
Expand Down Expand Up @@ -364,7 +367,7 @@ def test_2in_1out(func, data1, data2, device):
x2 = dpnp.array(data2, device=device)
result = getattr(dpnp, func)(x1, x2)

numpy.testing.assert_array_equal(result, expected)
assert_array_equal(result, expected)

assert_sycl_queue_equal(result.sycl_queue, x1.sycl_queue)
assert_sycl_queue_equal(result.sycl_queue, x2.sycl_queue)
Expand Down Expand Up @@ -539,6 +542,9 @@ def test_random_state(func, args, kwargs, device, usm_type):
pytest.param("divide",
[0., 1., 2., 3., 4.],
[4., 4., 4., 4., 4.]),
pytest.param("dot",
[[0., 1., 2.], [3., 4., 5.]],
[[4., 4.], [4., 4.], [4., 4.]]),
pytest.param("floor_divide",
[1., 2., 3., 4.],
[2.5, 2.5, 2.5, 2.5]),
Expand Down Expand Up @@ -571,20 +577,20 @@ def test_random_state(func, args, kwargs, device, usm_type):
def test_out(func, data1, data2, device):
x1_orig = numpy.array(data1)
x2_orig = numpy.array(data2)
expected = numpy.empty(x1_orig.size)
numpy.add(x1_orig, x2_orig, out=expected)
np_out = getattr(numpy, func)(x1_orig, x2_orig)
expected = numpy.empty_like(np_out)
getattr(numpy, func)(x1_orig, x2_orig, out=expected)

x1 = dpnp.array(data1, device=device)
x2 = dpnp.array(data2, device=device)
result = dpnp.empty(x1.size, device=device)
dpnp.add(x1, x2, out=result)
dp_out = getattr(dpnp, func)(x1, x2)
result = dpnp.empty_like(dp_out)
getattr(dpnp, func)(x1, x2, out=result)

numpy.testing.assert_array_equal(result, expected)
assert_array_equal(result, expected)

expected_queue = x1.get_array().sycl_queue
result_queue = result.get_array().sycl_queue

assert_sycl_queue_equal(result_queue, expected_queue)
assert_sycl_queue_equal(result.sycl_queue, x1.sycl_queue)
assert_sycl_queue_equal(result.sycl_queue, x2.sycl_queue)


@pytest.mark.parametrize("device",
Expand Down
19 changes: 19 additions & 0 deletions tests/test_usm_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,22 @@ def test_meshgrid(usm_type_x, usm_type_y):
z = dp.meshgrid(x, y)
assert z[0].usm_type == usm_type_x
assert z[1].usm_type == usm_type_y

@pytest.mark.parametrize(
"func,data1,data2",
[
pytest.param("dot",
[[0., 1., 2.], [3., 4., 5.]],
[[4., 4.], [4., 4.], [4., 4.]]),
],
)
@pytest.mark.parametrize("usm_type_x", list_of_usm_types, ids=list_of_usm_types)
@pytest.mark.parametrize("usm_type_y", list_of_usm_types, ids=list_of_usm_types)
def test_2in_1out(func, data1, data2, usm_type_x, usm_type_y):
x = dp.array(data1, usm_type = usm_type_x)
y = dp.array(data2, usm_type = usm_type_y)
z = getattr(dp, func)(x, y)

assert x.usm_type == usm_type_x
assert y.usm_type == usm_type_y
assert z.usm_type == du.get_coerced_usm_type([usm_type_x, usm_type_y])
1 change: 0 additions & 1 deletion tests/third_party/cupy/linalg_tests/test_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
'trans_a': [True, False],
'trans_b': [True, False],
}))
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
@testing.gpu
class TestDot(unittest.TestCase):

Expand Down

0 comments on commit 2224ce2

Please sign in to comment.