From 95fdbeccfd0da8747b846a844664ec949bcf856d Mon Sep 17 00:00:00 2001 From: vtavana <120411540+vtavana@users.noreply.github.com> Date: Wed, 21 Feb 2024 03:03:55 -0600 Subject: [PATCH] update `dpnp.cross` (#1715) * update dpnp.cross * rename test file in github actions * address comments * use direct call of multiply etc from dpctl to have async calculation * add braodcasting --- .github/workflows/conda-package.yml | 2 +- dpnp/backend/include/dpnp_iface_fptr.hpp | 2 - .../kernels/dpnp_krnl_mathematical.cpp | 45 --- dpnp/dpnp_algo/dpnp_algo.pxd | 2 - dpnp/dpnp_algo/dpnp_algo.pyx | 71 ---- dpnp/dpnp_algo/dpnp_algo_mathematical.pxi | 9 - dpnp/dpnp_iface_mathematical.py | 166 +++++++-- dpnp/dpnp_utils/dpnp_utils_linearalgebra.py | 158 ++++++++- tests/test_mathematical.py | 28 -- tests/{test_dot.py => test_product.py} | 334 ++++++++++++++---- tests/test_usm_type.py | 1 + .../cupy/linalg_tests/test_product.py | 7 +- 12 files changed, 560 insertions(+), 265 deletions(-) rename tests/{test_dot.py => test_product.py} (66%) diff --git a/.github/workflows/conda-package.yml b/.github/workflows/conda-package.yml index b90bb316639..9bd38965ab5 100644 --- a/.github/workflows/conda-package.yml +++ b/.github/workflows/conda-package.yml @@ -18,7 +18,6 @@ env: TEST_SCOPE: >- test_arraycreation.py test_amin_amax.py - test_dot.py test_dparray.py test_copy.py test_fft.py @@ -26,6 +25,7 @@ env: test_logic.py test_manipulation.py test_mathematical.py + test_product.py test_random_state.py test_sort.py test_special.py diff --git a/dpnp/backend/include/dpnp_iface_fptr.hpp b/dpnp/backend/include/dpnp_iface_fptr.hpp index 24b01f5ff11..9660d290e5c 100644 --- a/dpnp/backend/include/dpnp_iface_fptr.hpp +++ b/dpnp/backend/include/dpnp_iface_fptr.hpp @@ -105,8 +105,6 @@ enum class DPNPFuncName : size_t DPNP_FN_COUNT_NONZERO, /**< Used in numpy.count_nonzero() impl */ DPNP_FN_COV, /**< Used in numpy.cov() impl */ DPNP_FN_CROSS, /**< Used in numpy.cross() impl */ - DPNP_FN_CROSS_EXT, /**< Used in numpy.cross() impl, requires extra - parameters */ DPNP_FN_CUMPROD, /**< Used in numpy.cumprod() impl */ DPNP_FN_CUMPROD_EXT, /**< Used in numpy.cumprod() impl, requires extra parameters */ diff --git a/dpnp/backend/kernels/dpnp_krnl_mathematical.cpp b/dpnp/backend/kernels/dpnp_krnl_mathematical.cpp index d80ccfa186e..b3552763c82 100644 --- a/dpnp/backend/kernels/dpnp_krnl_mathematical.cpp +++ b/dpnp/backend/kernels/dpnp_krnl_mathematical.cpp @@ -311,23 +311,6 @@ void (*dpnp_cross_default_c)(void *, const size_t *) = dpnp_cross_c<_DataType_output, _DataType_input1, _DataType_input2>; -template -DPCTLSyclEventRef (*dpnp_cross_ext_c)(DPCTLSyclQueueRef, - void *, - const void *, - const size_t, - const shape_elem_type *, - const size_t, - const void *, - const size_t, - const shape_elem_type *, - const size_t, - const size_t *, - const DPCTLEventVectorRef) = - dpnp_cross_c<_DataType_output, _DataType_input1, _DataType_input2>; - template class dpnp_cumprod_c_kernel; @@ -1116,31 +1099,6 @@ DPCTLSyclEventRef (*dpnp_trapz_ext_c)(DPCTLSyclQueueRef, const DPCTLEventVectorRef) = dpnp_trapz_c<_DataType_input1, _DataType_input2, _DataType_output>; -template -static void func_map_elemwise_2arg_3type_core(func_map_t &fmap) -{ - ((fmap[DPNPFuncName::DPNP_FN_CROSS_EXT][FT1][FTs] = - {get_floating_res_type(), - (void *)dpnp_cross_ext_c< - func_type_map_t::find_type()>, - func_type_map_t::find_type, - func_type_map_t::find_type>, - get_floating_res_type(), - (void *)dpnp_cross_ext_c< - func_type_map_t::find_type()>, - func_type_map_t::find_type, - 
func_type_map_t::find_type>}), - ...); -} - -template -static void func_map_elemwise_2arg_3type_helper(func_map_t &fmap) -{ - ((func_map_elemwise_2arg_3type_core(fmap)), ...); -} - void func_map_init_mathematical(func_map_t &fmap) { fmap[DPNPFuncName::DPNP_FN_ABSOLUTE][eft_INT][eft_INT] = { @@ -1402,8 +1360,5 @@ void func_map_init_mathematical(func_map_t &fmap) fmap[DPNPFuncName::DPNP_FN_TRAPZ_EXT][eft_DBL][eft_DBL] = { eft_DBL, (void *)dpnp_trapz_ext_c}; - func_map_elemwise_2arg_3type_helper( - fmap); - return; } diff --git a/dpnp/dpnp_algo/dpnp_algo.pxd b/dpnp/dpnp_algo/dpnp_algo.pxd index 3ad23b08fbe..67db6a07f75 100644 --- a/dpnp/dpnp_algo/dpnp_algo.pxd +++ b/dpnp/dpnp_algo/dpnp_algo.pxd @@ -42,8 +42,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na DPNP_FN_COPY_EXT DPNP_FN_CORRELATE DPNP_FN_CORRELATE_EXT - DPNP_FN_CROSS - DPNP_FN_CROSS_EXT DPNP_FN_CUMPROD DPNP_FN_CUMPROD_EXT DPNP_FN_CUMSUM diff --git a/dpnp/dpnp_algo/dpnp_algo.pyx b/dpnp/dpnp_algo/dpnp_algo.pyx index fadba02a032..9cbdaf3f1df 100644 --- a/dpnp/dpnp_algo/dpnp_algo.pyx +++ b/dpnp/dpnp_algo/dpnp_algo.pyx @@ -278,77 +278,6 @@ cdef utils.dpnp_descriptor call_fptr_1in_1out_strides(DPNPFuncName fptr_name, return result -cdef utils.dpnp_descriptor call_fptr_2in_1out(DPNPFuncName fptr_name, - utils.dpnp_descriptor x1_obj, - utils.dpnp_descriptor x2_obj, - object dtype=None, - utils.dpnp_descriptor out=None, - object where=True, - func_name=None): - - # Convert type (x1_obj.dtype) to C enum DPNPFuncType - cdef DPNPFuncType x1_c_type = dpnp_dtype_to_DPNPFuncType(x1_obj.dtype) - cdef DPNPFuncType x2_c_type = dpnp_dtype_to_DPNPFuncType(x2_obj.dtype) - - # get the FPTR data structure - cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(fptr_name, x1_c_type, x2_c_type) - - result_sycl_device, result_usm_type, result_sycl_queue = utils.get_common_usm_allocation(x1_obj, x2_obj) - - # get FPTR function and return type - cdef (DPNPFuncType, void *) ret_type_and_func = utils.get_ret_type_and_func(kernel_data, - result_sycl_device.has_aspect_fp64) - cdef DPNPFuncType return_type = ret_type_and_func[0] - cdef fptr_2in_1out_t func = < fptr_2in_1out_t > ret_type_and_func[1] - - result_type = dpnp_DPNPFuncType_to_dtype( < size_t > return_type) - - # Create result array - cdef shape_type_c x1_shape = x1_obj.shape - cdef shape_type_c x2_shape = x2_obj.shape - cdef shape_type_c result_shape = utils.get_common_shape(x1_shape, x2_shape) - cdef utils.dpnp_descriptor result - - if out is None: - """ Create result array with type given by FPTR data """ - result = utils.create_output_descriptor(result_shape, - return_type, - None, - device=result_sycl_device, - usm_type=result_usm_type, - sycl_queue=result_sycl_queue) - else: - if out.dtype != result_type: - utils.checker_throw_value_error(func_name, 'out.dtype', out.dtype, result_type) - if out.shape != result_shape: - utils.checker_throw_value_error(func_name, 'out.shape', out.shape, result_shape) - - result = out - - utils.get_common_usm_allocation(x1_obj, result) # check USM allocation is common - - cdef c_dpctl.SyclQueue q = result_sycl_queue - cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref() - - """ Call FPTR function """ - cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref, - result.get_data(), - x1_obj.get_data(), - x1_obj.size, - x1_shape.data(), - x1_shape.size(), - x2_obj.get_data(), - x2_obj.size, - x2_shape.data(), - x2_shape.size(), - NULL, - NULL) # dep_events_ref) - - with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref) - 
c_dpctl.DPCTLEvent_Delete(event_ref) - - return result - cdef utils.dpnp_descriptor call_fptr_2in_1out_strides(DPNPFuncName fptr_name, utils.dpnp_descriptor x1_obj, utils.dpnp_descriptor x2_obj, diff --git a/dpnp/dpnp_algo/dpnp_algo_mathematical.pxi b/dpnp/dpnp_algo/dpnp_algo_mathematical.pxi index 85f51e52eee..f4aa8873056 100644 --- a/dpnp/dpnp_algo/dpnp_algo_mathematical.pxi +++ b/dpnp/dpnp_algo/dpnp_algo_mathematical.pxi @@ -36,7 +36,6 @@ and the rest of the library # NO IMPORTs here. All imports must be placed into main "dpnp_algo.pyx" file __all__ += [ - "dpnp_cross", "dpnp_cumprod", "dpnp_cumsum", "dpnp_ediff1d", @@ -60,14 +59,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*ftpr_custom_trapz_2in_1out_with_2size_t)(c_d const c_dpctl.DPCTLEventVectorRef) -cpdef utils.dpnp_descriptor dpnp_cross(utils.dpnp_descriptor x1_obj, - utils.dpnp_descriptor x2_obj, - object dtype=None, - utils.dpnp_descriptor out=None, - object where=True): - return call_fptr_2in_1out(DPNP_FN_CROSS_EXT, x1_obj, x2_obj, dtype, out, where) - - cpdef utils.dpnp_descriptor dpnp_cumprod(utils.dpnp_descriptor x1): # instead of x1.shape, (x1.size, ) is passed to the function # due to the following: diff --git a/dpnp/dpnp_iface_mathematical.py b/dpnp/dpnp_iface_mathematical.py index 7ec846f770c..ace2db7cc08 100644 --- a/dpnp/dpnp_iface_mathematical.py +++ b/dpnp/dpnp_iface_mathematical.py @@ -48,6 +48,7 @@ import dpnp from dpnp.dpnp_array import dpnp_array +from dpnp.dpnp_utils import get_usm_allocations from .dpnp_algo import * from .dpnp_algo.dpnp_elementwise_common import ( @@ -78,6 +79,7 @@ dpnp_trunc, ) from .dpnp_utils import * +from .dpnp_utils.dpnp_utils_linearalgebra import dpnp_cross __all__ = [ "abs", @@ -667,52 +669,152 @@ def copysign( ) -def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None): +def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None): """ Return the cross product of two (arrays of) vectors. For full documentation refer to :obj:`numpy.cross`. - Limitations - ----------- - Parameters `x1` and `x2` are supported as :class:`dpnp.ndarray`. - Keyword argument `kwargs` is currently unsupported. - Sizes of input arrays are limited by `x1.size == 3 and x2.size == 3`. - Shapes of input arrays are limited by `x1.shape == (3,) and x2.shape == (3,)`. - Otherwise the function will be executed sequentially on CPU. - Input array data types are limited by supported DPNP :ref:`Data types`. + Parameters + ---------- + a : {dpnp.ndarray, usm_ndarray} + First input array. + b : {dpnp.ndarray, usm_ndarray} + Second input array. + axisa : int, optional + Axis of `a` that defines the vector(s). By default, the last axis. + axisb : int, optional + Axis of `b` that defines the vector(s). By default, the last axis. + axisc : int, optional + Axis of `c` containing the cross product vector(s). Ignored if + both input vectors have dimension 2, as the return is scalar. + By default, the last axis. + axis : {int, None}, optional + If defined, the axis of `a`, `b` and `c` that defines the vector(s) + and cross product(s). Overrides `axisa`, `axisb` and `axisc`. + + Returns + ------- + out : dpnp.ndarray + Vector cross product(s). + + See Also + -------- + :obj:`dpnp.inner` : Inner product. + :obj:`dpnp.outer` : Outer product. Examples -------- + Vector cross-product. 
+ >>> import dpnp as np - >>> x = [1, 2, 3] - >>> y = [4, 5, 6] - >>> result = np.cross(x, y) - >>> [x for x in result] - [-3, 6, -3] + >>> x = np.array([1, 2, 3]) + >>> y = np.array([4, 5, 6]) + >>> np.cross(x, y) + array([-3, 6, -3]) + + One vector with dimension 2. + + >>> x = np.array([1, 2]) + >>> y = np.array([4, 5, 6]) + >>> np.cross(x, y) + array([12, -6, -3]) + + Equivalently: + + >>> x = np.array([1, 2, 0]) + >>> y = np.array([4, 5, 6]) + >>> np.cross(x, y) + array([12, -6, -3]) + + Both vectors with dimension 2. + + >>> x = np.array([1, 2]) + >>> y = np.array([4, 5]) + >>> np.cross(x, y) + array(-3) + + Multiple vector cross-products. Note that the direction of the cross + product vector is defined by the *right-hand rule*. + + >>> x = np.array([[1, 2, 3], [4, 5, 6]]) + >>> y = np.array([[4, 5, 6], [1, 2, 3]]) + >>> np.cross(x, y) + array([[-3, 6, -3], + [ 3, -6, 3]]) + + The orientation of `c` can be changed using the `axisc` keyword. + + >>> np.cross(x, y, axisc=0) + array([[-3, 3], + [ 6, -6], + [-3, 3]]) + + Change the vector definition of `x` and `y` using `axisa` and `axisb`. + + >>> x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> y = np.array([[7, 8, 9], [4, 5, 6], [1, 2, 3]]) + >>> np.cross(x, y) + array([[ -6, 12, -6], + [ 0, 0, 0], + [ 6, -12, 6]]) + >>> np.cross(x, y, axisa=0, axisb=0) + array([[-24, 48, -24], + [-30, 60, -30], + [-36, 72, -36]]) """ - x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False) - x2_desc = dpnp.get_dpnp_descriptor(x2, copy_when_nondefault_queue=False) + if axis is not None: + if not isinstance(axis, int): + raise TypeError(f"axis should be an integer but got, {type(axis)}.") + axisa, axisb, axisc = (axis,) * 3 + dpnp.check_supported_arrays_type(a, b) + # Check axisa and axisb are within bounds + axisa = normalize_axis_index(axisa, a.ndim, msg_prefix="axisa") + axisb = normalize_axis_index(axisb, b.ndim, msg_prefix="axisb") + + # Move working axis to the end of the shape + a = dpnp.moveaxis(a, axisa, -1) + b = dpnp.moveaxis(b, axisb, -1) + if a.shape[-1] not in (2, 3) or b.shape[-1] not in (2, 3): + raise ValueError( + "Incompatible vector dimensions for cross product\n" + "(the dimension of vector used in cross product must be 2 or 3)" + ) - if x1_desc and x2_desc: - if x1_desc.size != 3 or x2_desc.size != 3: - pass - elif x1_desc.shape != (3,) or x2_desc.shape != (3,): - pass - elif axisa != -1: - pass - elif axisb != -1: - pass - elif axisc != -1: - pass - elif axis is not None: - pass - else: - return dpnp_cross(x1_desc, x2_desc).get_pyobj() + # Modify the shape of input arrays if necessary + a_shape = a.shape + b_shape = b.shape + # TODO: replace with dpnp.broadcast_shapes once implemented + res_shape = numpy.broadcast_shapes(a_shape[:-1], b_shape[:-1]) + if a_shape[:-1] != res_shape: + a = dpnp.broadcast_to(a, res_shape + (a_shape[-1],)) + a_shape = a.shape + if b_shape[:-1] != res_shape: + b = dpnp.broadcast_to(b, res_shape + (b_shape[-1],)) + b_shape = b.shape + + if a_shape[-1] == 3 or b_shape[-1] == 3: + res_shape += (3,) + # Check axisc is within bounds + axisc = normalize_axis_index(axisc, len(res_shape), msg_prefix="axisc") + # Create the output array + dtype = dpnp.result_type(a, b) + res_usm_type, exec_q = get_usm_allocations([a, b]) + cp = dpnp.empty( + res_shape, dtype=dtype, sycl_queue=exec_q, usm_type=res_usm_type + ) - return call_origin(numpy.cross, x1, x2, axisa, axisb, axisc, axis) + # recast arrays as dtype + a = a.astype(dtype, copy=False) + b = b.astype(dtype, copy=False) + + cp = 
dpnp_cross(a, b, cp, exec_q) + if a_shape[-1] == 2 and b_shape[-1] == 2: + return cp + else: + return dpnp.moveaxis(cp, -1, axisc) def cumprod(x1, **kwargs): diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index bb960b7f4b4..ff7b7a05972 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -25,6 +25,7 @@ import dpctl import dpctl.tensor as dpt +import dpctl.tensor._tensor_elementwise_impl as tei import dpctl.tensor._tensor_impl as ti import numpy from numpy.core.numeric import normalize_axis_tuple @@ -34,7 +35,7 @@ from dpnp.dpnp_array import dpnp_array from dpnp.dpnp_utils import get_usm_allocations -__all__ = ["dpnp_dot", "dpnp_matmul"] +__all__ = ["dpnp_cross", "dpnp_dot", "dpnp_matmul"] def _create_result_array(x1, x2, out, shape, dtype, usm_type, sycl_queue): @@ -320,6 +321,161 @@ def _validate_internal(axes, i, ndim): return axes +def dpnp_cross(a, b, cp, exec_q): + """Return the cross product of two (arrays of) vectors.""" + + # create local aliases for readability + a0 = a[..., 0] + a1 = a[..., 1] + if a.shape[-1] == 3: + a2 = a[..., 2] + b0 = b[..., 0] + b1 = b[..., 1] + if b.shape[-1] == 3: + b2 = b[..., 2] + if cp.ndim != 0 and cp.shape[-1] == 3: + cp0 = cp[..., 0] + cp1 = cp[..., 1] + cp2 = cp[..., 2] + + host_events = [] + if a.shape[-1] == 2: + if b.shape[-1] == 2: + # a0 * b1 - a1 * b0 + cp_usm = dpnp.get_usm_ndarray(cp) + ht_ev1, dev_ev1 = tei._multiply( + dpnp.get_usm_ndarray(a0), + dpnp.get_usm_ndarray(b1), + cp_usm, + exec_q, + ) + host_events.append(ht_ev1) + tmp = dpt.empty_like(cp_usm) + ht_ev2, dev_ev2 = tei._multiply( + dpnp.get_usm_ndarray(a1), dpnp.get_usm_ndarray(b0), tmp, exec_q + ) + host_events.append(ht_ev2) + ht_ev3, _ = tei._subtract_inplace( + cp_usm, tmp, exec_q, [dev_ev1, dev_ev2] + ) + host_events.append(ht_ev3) + else: + assert b.shape[-1] == 3 + # cp0 = a1 * b2 - 0 (a2 = 0) + # cp1 = 0 - a0 * b2 (a2 = 0) + # cp2 = a0 * b1 - a1 * b0 + cp1_usm = dpnp.get_usm_ndarray(cp1) + cp2_usm = dpnp.get_usm_ndarray(cp2) + a1_usm = dpnp.get_usm_ndarray(a1) + b2_usm = dpnp.get_usm_ndarray(b2) + ht_ev1, _ = tei._multiply( + a1_usm, b2_usm, dpnp.get_usm_ndarray(cp0), exec_q + ) + host_events.append(ht_ev1) + ht_ev2, dev_ev2 = tei._multiply( + dpnp.get_usm_ndarray(a0), b2_usm, cp1_usm, exec_q + ) + host_events.append(ht_ev2) + ht_ev3, _ = tei._negative(cp1_usm, cp1_usm, exec_q, [dev_ev2]) + host_events.append(ht_ev3) + ht_ev4, dev_ev4 = tei._multiply( + dpnp.get_usm_ndarray(a0), + dpnp.get_usm_ndarray(b1), + cp2_usm, + exec_q, + ) + host_events.append(ht_ev4) + tmp = dpt.empty_like(cp2_usm) + ht_ev5, dev_ev5 = tei._multiply( + a1_usm, dpnp.get_usm_ndarray(b0), tmp, exec_q + ) + host_events.append(ht_ev5) + ht_ev6, _ = tei._subtract_inplace( + cp2_usm, tmp, exec_q, [dev_ev4, dev_ev5] + ) + host_events.append(ht_ev6) + else: + assert a.shape[-1] == 3 + if b.shape[-1] == 3: + # cp0 = a1 * b2 - a2 * b1 + # cp1 = a2 * b0 - a0 * b2 + # cp2 = a0 * b1 - a1 * b0 + cp0_usm = dpnp.get_usm_ndarray(cp0) + cp1_usm = dpnp.get_usm_ndarray(cp1) + cp2_usm = dpnp.get_usm_ndarray(cp2) + a0_usm = dpnp.get_usm_ndarray(a0) + a1_usm = dpnp.get_usm_ndarray(a1) + a2_usm = dpnp.get_usm_ndarray(a2) + b0_usm = dpnp.get_usm_ndarray(b0) + b1_usm = dpnp.get_usm_ndarray(b1) + b2_usm = dpnp.get_usm_ndarray(b2) + ht_ev1, dev_ev1 = tei._multiply(a1_usm, b2_usm, cp0_usm, exec_q) + host_events.append(ht_ev1) + tmp = dpt.empty_like(cp0_usm) + ht_ev2, dev_ev2 = tei._multiply(a2_usm, b1_usm, tmp, 
exec_q) + host_events.append(ht_ev2) + ht_ev3, dev_ev3 = tei._subtract_inplace( + cp0_usm, tmp, exec_q, [dev_ev1, dev_ev2] + ) + host_events.append(ht_ev3) + ht_ev4, dev_ev4 = tei._multiply(a2_usm, b0_usm, cp1_usm, exec_q) + host_events.append(ht_ev4) + ht_ev5, dev_ev5 = tei._multiply( + a0_usm, b2_usm, tmp, exec_q, [dev_ev3] + ) + host_events.append(ht_ev5) + ht_ev6, dev_ev6 = tei._subtract_inplace( + cp1_usm, tmp, exec_q, [dev_ev4, dev_ev5] + ) + host_events.append(ht_ev6) + ht_ev7, dev_ev7 = tei._multiply(a0_usm, b1_usm, cp2_usm, exec_q) + host_events.append(ht_ev7) + ht_ev8, dev_ev8 = tei._multiply( + a1_usm, b0_usm, tmp, exec_q, [dev_ev6] + ) + host_events.append(ht_ev8) + ht_ev9, _ = tei._subtract_inplace( + cp2_usm, tmp, exec_q, [dev_ev7, dev_ev8] + ) + host_events.append(ht_ev9) + else: + assert b.shape[-1] == 2 + # cp0 = 0 - a2 * b1 (b2 = 0) + # cp1 = a2 * b0 - 0 (b2 = 0) + # cp2 = a0 * b1 - a1 * b0 + cp0_usm = dpnp.get_usm_ndarray(cp0) + cp2_usm = dpnp.get_usm_ndarray(cp2) + a2_usm = dpnp.get_usm_ndarray(a2) + b1_usm = dpnp.get_usm_ndarray(b1) + ht_ev1, dev_ev1 = tei._multiply(a2_usm, b1_usm, cp0_usm, exec_q) + host_events.append(ht_ev1) + ht_ev2, _ = tei._negative(cp0_usm, cp0_usm, exec_q, [dev_ev1]) + host_events.append(ht_ev2) + ht_ev3, _ = tei._multiply( + a2_usm, + dpnp.get_usm_ndarray(b0), + dpnp.get_usm_ndarray(cp1), + exec_q, + ) + host_events.append(ht_ev3) + ht_ev4, dev_ev4 = tei._multiply( + dpnp.get_usm_ndarray(a0), b1_usm, cp2_usm, exec_q + ) + host_events.append(ht_ev4) + tmp = dpt.empty_like(cp2_usm) + ht_ev5, dev_ev5 = tei._multiply( + dpnp.get_usm_ndarray(a1), dpnp.get_usm_ndarray(b0), tmp, exec_q + ) + host_events.append(ht_ev5) + ht_ev6, _ = tei._subtract_inplace( + cp2_usm, tmp, exec_q, [dev_ev4, dev_ev5] + ) + host_events.append(ht_ev6) + + dpctl.SyclEvent.wait_for(host_events) + return cp + + def dpnp_dot(a, b, /, out=None, *, conjugate=False): """ Return the dot product of two arrays. 
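As an aside on the dpnp_cross implementation above: each call into dpctl.tensor._tensor_elementwise_impl returns a (host_event, device_event) pair; the device events are passed as dependencies to order kernels, while the host events are collected and waited on once at the end, so independent kernels can overlap where the device allows. A minimal sketch of that pattern for a single cross-product component, assuming a0, a1, b0, b1 and cp are usm_ndarray views already allocated on exec_q:

    import dpctl
    import dpctl.tensor as dpt
    import dpctl.tensor._tensor_elementwise_impl as tei

    def _cross_component_sketch(a0, a1, b0, b1, cp, exec_q):
        # Computes cp = a0 * b1 - a1 * b0 asynchronously.
        host_events = []
        # cp = a0 * b1
        ht_ev1, dev_ev1 = tei._multiply(a0, b1, cp, exec_q)
        host_events.append(ht_ev1)
        # tmp = a1 * b0 (independent of the first multiply)
        tmp = dpt.empty_like(cp)
        ht_ev2, dev_ev2 = tei._multiply(a1, b0, tmp, exec_q)
        host_events.append(ht_ev2)
        # cp -= tmp, ordered after both multiplies via their device events
        ht_ev3, _ = tei._subtract_inplace(cp, tmp, exec_q, [dev_ev1, dev_ev2])
        host_events.append(ht_ev3)
        # Single synchronization point for all host tasks
        dpctl.SyclEvent.wait_for(host_events)
        return cp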
diff --git a/tests/test_mathematical.py b/tests/test_mathematical.py index cbbab19333c..16c2375a803 100644 --- a/tests/test_mathematical.py +++ b/tests/test_mathematical.py @@ -1191,34 +1191,6 @@ def test_trapz_with_dx_params(self, y_array, dx): assert_array_equal(expected, result) -@pytest.mark.usefixtures("allow_fall_back_on_numpy") -class TestCross: - @pytest.mark.parametrize("axis", [None, 0], ids=["None", "0"]) - @pytest.mark.parametrize("axisc", [-1, 0], ids=["-1", "0"]) - @pytest.mark.parametrize("axisb", [-1, 0], ids=["-1", "0"]) - @pytest.mark.parametrize("axisa", [-1, 0], ids=["-1", "0"]) - @pytest.mark.parametrize( - "x1", - [[1, 2, 3], [1.0, 2.5, 6.0], [2, 4, 6]], - ids=["[1, 2, 3]", "[1., 2.5, 6.]", "[2, 4, 6]"], - ) - @pytest.mark.parametrize( - "x2", - [[4, 5, 6], [1.0, 5.0, 2.0], [6, 4, 3]], - ids=["[4, 5, 6]", "[1., 5., 2.]", "[6, 4, 3]"], - ) - def test_cross_3x3(self, x1, x2, axisa, axisb, axisc, axis): - np_x1 = numpy.array(x1) - dpnp_x1 = dpnp.array(x1) - - np_x2 = numpy.array(x2) - dpnp_x2 = dpnp.array(x2) - - result = dpnp.cross(dpnp_x1, dpnp_x2, axisa, axisb, axisc, axis) - expected = numpy.cross(np_x1, np_x2, axisa, axisb, axisc, axis) - assert_array_equal(expected, result) - - class TestGradient: @pytest.mark.parametrize( "array", [[2, 3, 6, 8, 4, 9], [3.0, 4.0, 7.5, 9.0], [2, 6, 8, 10]] diff --git a/tests/test_dot.py b/tests/test_product.py similarity index 66% rename from tests/test_dot.py rename to tests/test_product.py index e9065e071c0..2b45e825714 100644 --- a/tests/test_dot.py +++ b/tests/test_product.py @@ -8,6 +8,191 @@ from .helper import assert_dtype_allclose, get_all_dtypes, get_complex_dtypes +class TestCross: + def setup_method(self): + numpy.random.seed(42) + + @pytest.mark.parametrize("axis", [None, 0], ids=["None", "0"]) + @pytest.mark.parametrize("axisc", [-1, 0], ids=["-1", "0"]) + @pytest.mark.parametrize("axisb", [-1, 0], ids=["-1", "0"]) + @pytest.mark.parametrize("axisa", [-1, 0], ids=["-1", "0"]) + @pytest.mark.parametrize( + "x1", + [[1, 2, 3], [1.0, 2.5, 6.0], [2, 4, 6]], + ids=["[1, 2, 3]", "[1., 2.5, 6.]", "[2, 4, 6]"], + ) + @pytest.mark.parametrize( + "x2", + [[4, 5, 6], [1.0, 5.0, 2.0], [6, 4, 3]], + ids=["[4, 5, 6]", "[1., 5., 2.]", "[6, 4, 3]"], + ) + def test_cross_3x3(self, x1, x2, axisa, axisb, axisc, axis): + np_x1 = numpy.array(x1) + dpnp_x1 = dpnp.array(x1) + + np_x2 = numpy.array(x2) + dpnp_x2 = dpnp.array(x2) + + result = dpnp.cross(dpnp_x1, dpnp_x2, axisa, axisb, axisc, axis) + expected = numpy.cross(np_x1, np_x2, axisa, axisb, axisc, axis) + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize( + "dtype", get_all_dtypes(no_bool=True, no_complex=True) + ) + @pytest.mark.parametrize( + "shape1, shape2, axis_a, axis_b, axis_c", + [ + ((4, 2, 3, 5), (2, 4, 3, 5), 1, 0, -2), + ((2, 2, 4, 5), (2, 4, 3, 5), 1, 2, -1), + ((2, 3, 4, 5), (2, 4, 2, 5), 1, 2, -1), + ((2, 3, 4, 5), (2, 4, 3, 5), 1, 2, -1), + ((2, 3, 4, 5), (2, 4, 3, 5), -3, -2, 0), + ], + ) + def test_cross(self, dtype, shape1, shape2, axis_a, axis_b, axis_c): + a = numpy.array( + numpy.random.uniform(-5, 5, numpy.prod(shape1)), dtype=dtype + ).reshape(shape1) + b = numpy.array( + numpy.random.uniform(-5, 5, numpy.prod(shape2)), dtype=dtype + ).reshape(shape2) + ia = dpnp.array(a) + ib = dpnp.array(b) + + result = dpnp.cross(ia, ib, axis_a, axis_b, axis_c) + expected = numpy.cross(a, b, axis_a, axis_b, axis_c) + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize("dtype", get_complex_dtypes()) + @pytest.mark.parametrize( + "shape1, 
shape2, axis_a, axis_b, axis_c", + [ + ((4, 2, 3, 5), (2, 4, 3, 5), 1, 0, -2), + ((2, 2, 4, 5), (2, 4, 3, 5), 1, 2, -1), + ((2, 3, 4, 5), (2, 4, 2, 5), 1, 2, -1), + ((2, 3, 4, 5), (2, 4, 3, 5), 1, 2, -1), + ((2, 3, 4, 5), (2, 4, 3, 5), -3, -2, 0), + ], + ) + def test_cross_complex(self, dtype, shape1, shape2, axis_a, axis_b, axis_c): + x11 = numpy.random.uniform(-5, 5, numpy.prod(shape1)) + x12 = numpy.random.uniform(-5, 5, numpy.prod(shape1)) + x21 = numpy.random.uniform(-5, 5, numpy.prod(shape2)) + x22 = numpy.random.uniform(-5, 5, numpy.prod(shape2)) + a = numpy.array(x11 + 1j * x12, dtype=dtype).reshape(shape1) + b = numpy.array(x21 + 1j * x22, dtype=dtype).reshape(shape2) + ia = dpnp.array(a) + ib = dpnp.array(b) + + result = dpnp.cross(ia, ib, axis_a, axis_b, axis_c) + expected = numpy.cross(a, b, axis_a, axis_b, axis_c) + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) + @pytest.mark.parametrize( + "shape1, shape2, axis", + [ + ((2, 3, 4, 5), (2, 3, 4, 5), 0), + ((2, 3, 4, 5), (2, 3, 4, 5), 1), + ], + ) + def test_cross_axis(self, dtype, shape1, shape2, axis): + a = numpy.array( + numpy.random.uniform(-5, 5, numpy.prod(shape1)), dtype=dtype + ).reshape(shape1) + b = numpy.array( + numpy.random.uniform(-5, 5, numpy.prod(shape2)), dtype=dtype + ).reshape(shape2) + ia = dpnp.array(a) + ib = dpnp.array(b) + + result = dpnp.cross(ia, ib, axis=axis) + expected = numpy.cross(a, b, axis=axis) + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize("dtype1", get_all_dtypes(no_bool=True)) + @pytest.mark.parametrize("dtype2", get_all_dtypes(no_bool=True)) + def test_cross_input_dtype_matrix(self, dtype1, dtype2): + a = numpy.array(numpy.random.uniform(-5, 5, 3), dtype=dtype1) + b = numpy.array(numpy.random.uniform(-5, 5, 3), dtype=dtype2) + ia = dpnp.array(a) + ib = dpnp.array(b) + + result = dpnp.cross(ia, ib) + expected = numpy.cross(a, b) + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize( + "dtype", get_all_dtypes(no_bool=True, no_complex=True) + ) + @pytest.mark.parametrize( + "shape1, shape2, axis_a, axis_b, axis_c", + [ + ((4, 2, 1, 5), (2, 4, 3, 5), 1, 0, -2), + ((2, 2, 4, 5), (2, 4, 3, 1), 1, 2, -1), + ((2, 3, 4, 1), (2, 4, 2, 5), 1, 2, -1), + ((1, 3, 4, 5), (2, 4, 3, 5), 1, 2, -1), + ((2, 3, 4, 5), (1, 1, 3, 1), -3, -2, 0), + ], + ) + def test_cross_broadcast( + self, dtype, shape1, shape2, axis_a, axis_b, axis_c + ): + a = numpy.array( + numpy.random.uniform(-5, 5, numpy.prod(shape1)), dtype=dtype + ).reshape(shape1) + b = numpy.array( + numpy.random.uniform(-5, 5, numpy.prod(shape2)), dtype=dtype + ).reshape(shape2) + ia = dpnp.array(a) + ib = dpnp.array(b) + + result = dpnp.cross(ia, ib, axis_a, axis_b, axis_c) + expected = numpy.cross(a, b, axis_a, axis_b, axis_c) + assert_dtype_allclose(result, expected) + + @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) + def test_cross_strided(self, dtype): + a = numpy.arange(1, 10, dtype=dtype) + b = numpy.arange(1, 10, dtype=dtype) + ia = dpnp.array(a) + ib = dpnp.array(b) + + result = dpnp.cross(ia[::3], ib[::3]) + expected = numpy.cross(a[::3], b[::3]) + assert_dtype_allclose(result, expected) + + a = numpy.arange(1, 4, dtype=dtype) + b = numpy.arange(1, 4, dtype=dtype) + ia = dpnp.array(a) + ib = dpnp.array(b) + result = dpnp.cross(ia, ib[::-1]) + expected = numpy.cross(a, b[::-1]) + assert_dtype_allclose(result, expected) + + a = numpy.arange(1, 7, dtype=dtype) + b = numpy.arange(1, 7, dtype=dtype) + ia = dpnp.array(a) + ib = 
dpnp.array(b) + result = dpnp.cross(ia[::-2], ib[::-2]) + expected = numpy.cross(a[::-2], b[::-2]) + assert_dtype_allclose(result, expected) + + def test_cross_error(self): + a = dpnp.arange(3) + b = dpnp.arange(4) + # Incompatible vector dimensions + with pytest.raises(ValueError): + dpnp.cross(a, b) + + a = dpnp.arange(3) + b = dpnp.arange(4) + # axis should be an integer + with pytest.raises(TypeError): + dpnp.cross(a, b, axis=0.0) + + class TestDot: def setup_method(self): numpy.random.seed(42) @@ -49,17 +234,17 @@ def test_dot_scalar(self, dtype): @pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True)) @pytest.mark.parametrize( - "array_info", + "shape_pair", [ - (1, 10, (), (10,)), - (10, 1, (10,), ()), - (1, 1, (), ()), - (10, 10, (10,), (10,)), - (12, 6, (4, 3), (3, 2)), - (12, 3, (4, 3), (3,)), - (60, 3, (5, 4, 3), (3,)), - (4, 8, (4,), (4, 2)), - (60, 48, (5, 3, 4), (6, 4, 2)), + ((), (10,)), + ((10,), ()), + ((), ()), + ((10,), (10,)), + ((4, 3), (3, 2)), + ((4, 3), (3,)), + ((5, 4, 3), (3,)), + ((4,), (4, 2)), + ((5, 3, 4), (6, 4, 2)), ], ids=[ "0d_1d", @@ -73,8 +258,10 @@ def test_dot_scalar(self, dtype): "3d_3d", ], ) - def test_dot(self, dtype, array_info): - size1, size2, shape1, shape2 = array_info + def test_dot(self, dtype, shape_pair): + shape1, shape2 = shape_pair + size1 = numpy.prod(shape1, dtype=int) + size2 = numpy.prod(shape2, dtype=int) a = numpy.array( numpy.random.uniform(-5, 5, size1), dtype=dtype ).reshape(shape1) @@ -90,17 +277,17 @@ def test_dot(self, dtype, array_info): @pytest.mark.parametrize("dtype", get_complex_dtypes()) @pytest.mark.parametrize( - "array_info", + "shape_pair", [ - (1, 10, (), (10,)), - (10, 1, (10,), ()), - (1, 1, (), ()), - (10, 10, (10,), (10,)), - (12, 6, (4, 3), (3, 2)), - (12, 3, (4, 3), (3,)), - (60, 3, (5, 4, 3), (3,)), - (4, 8, (4,), (4, 2)), - (60, 48, (5, 3, 4), (6, 4, 2)), + ((), (10,)), + ((10,), ()), + ((), ()), + ((10,), (10,)), + ((4, 3), (3, 2)), + ((4, 3), (3,)), + ((5, 4, 3), (3,)), + ((4,), (4, 2)), + ((5, 3, 4), (6, 4, 2)), ], ids=[ "0d_1d", @@ -114,8 +301,10 @@ def test_dot(self, dtype, array_info): "3d_3d", ], ) - def test_dot_complex(self, dtype, array_info): - size1, size2, shape1, shape2 = array_info + def test_dot_complex(self, dtype, shape_pair): + shape1, shape2 = shape_pair + size1 = numpy.prod(shape1, dtype=int) + size2 = numpy.prod(shape2, dtype=int) x11 = numpy.random.uniform(-5, 5, size1) x12 = numpy.random.uniform(-5, 5, size1) x21 = numpy.random.uniform(-5, 5, size2) @@ -131,17 +320,17 @@ def test_dot_complex(self, dtype, array_info): @pytest.mark.parametrize("dtype", get_all_dtypes()) @pytest.mark.parametrize( - "array_info", + "shape_pair", [ - (1, 10, (), (10,)), - (10, 1, (10,), ()), - (1, 1, (), ()), - (10, 10, (10,), (10,)), - (12, 6, (4, 3), (3, 2)), - (12, 3, (4, 3), (3,)), - (60, 3, (5, 4, 3), (3,)), - (4, 8, (4,), (4, 2)), - (60, 48, (5, 3, 4), (6, 4, 2)), + ((), (10,)), + ((10,), ()), + ((), ()), + ((10,), (10,)), + ((4, 3), (3, 2)), + ((4, 3), (3,)), + ((5, 4, 3), (3,)), + ((4,), (4, 2)), + ((5, 3, 4), (6, 4, 2)), ], ids=[ "0d_1d", @@ -155,8 +344,10 @@ def test_dot_complex(self, dtype, array_info): "3d_3d", ], ) - def test_dot_ndarray(self, dtype, array_info): - size1, size2, shape1, shape2 = array_info + def test_dot_ndarray(self, dtype, shape_pair): + shape1, shape2 = shape_pair + size1 = numpy.prod(shape1, dtype=int) + size2 = numpy.prod(shape2, dtype=int) a = numpy.array( numpy.random.uniform(-5, 5, size1), dtype=dtype ).reshape(shape1) @@ -210,17 +401,17 @@ def 
test_dot_out_scalar(self, dtype): @pytest.mark.parametrize("dtype", get_all_dtypes()) @pytest.mark.parametrize( - "array_info", + "shape_pair", [ - (1, 10, (), (10,), (10,)), - (10, 1, (10,), (), (10,)), - (1, 1, (), (), ()), - (10, 10, (10,), (10,), ()), - (12, 6, (4, 3), (3, 2), (4, 2)), - (12, 3, (4, 3), (3,), (4,)), - (60, 3, (5, 4, 3), (3,), (5, 4)), - (4, 8, (4,), (4, 2), (2,)), - (60, 48, (5, 3, 4), (6, 4, 2), (5, 3, 6, 2)), + ((), (10,), (10,)), + ((10,), (), (10,)), + ((), (), ()), + ((10,), (10,), ()), + ((4, 3), (3, 2), (4, 2)), + ((4, 3), (3,), (4,)), + ((5, 4, 3), (3,), (5, 4)), + ((4,), (4, 2), (2,)), + ((5, 3, 4), (6, 4, 2), (5, 3, 6, 2)), ], ids=[ "0d_1d", @@ -234,8 +425,10 @@ def test_dot_out_scalar(self, dtype): "3d_3d", ], ) - def test_dot_out(self, dtype, array_info): - size1, size2, shape1, shape2, out_shape = array_info + def test_dot_out(self, dtype, shape_pair): + shape1, shape2, out_shape = shape_pair + size1 = numpy.prod(shape1, dtype=int) + size2 = numpy.prod(shape2, dtype=int) a = numpy.array( numpy.random.uniform(-5, 5, size1), dtype=dtype ).reshape(shape1) @@ -428,6 +621,7 @@ def test_tensordot_axes(self, dtype, axes): ia = dpnp.array(a) ib = dpnp.array(b) + print(a.dtype, ia.dtype) result = dpnp.tensordot(ia, ib, axes=axes) expected = numpy.tensordot(a, b, axes=axes) assert_dtype_allclose(result, expected) @@ -521,15 +715,15 @@ def test_vdot_scalar(self, dtype): @pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True)) @pytest.mark.parametrize( - "array_info", + "shape_pair", [ - (1, 1, (), ()), - (10, 10, (10,), (10,)), - (12, 12, (4, 3), (3, 4)), - (12, 12, (4, 3), (12,)), - (60, 60, (5, 4, 3), (60,)), - (8, 8, (8,), (4, 2)), - (60, 60, (5, 3, 4), (3, 4, 5)), + ((), ()), + ((10,), (10,)), + ((4, 3), (3, 4)), + ((4, 3), (12,)), + ((5, 4, 3), (60,)), + ((8,), (4, 2)), + ((5, 3, 4), (3, 4, 5)), ], ids=[ "0d_0d", @@ -541,8 +735,10 @@ def test_vdot_scalar(self, dtype): "3d_3d", ], ) - def test_vdot(self, dtype, array_info): - size1, size2, shape1, shape2 = array_info + def test_vdot(self, dtype, shape_pair): + shape1, shape2 = shape_pair + size1 = numpy.prod(shape1, dtype=int) + size2 = numpy.prod(shape2, dtype=int) a = numpy.array( numpy.random.uniform(-5, 5, size1), dtype=dtype ).reshape(shape1) @@ -558,15 +754,15 @@ def test_vdot(self, dtype, array_info): @pytest.mark.parametrize("dtype", get_complex_dtypes()) @pytest.mark.parametrize( - "array_info", + "shape_pair", [ - (1, 1, (), ()), - (10, 10, (10,), (10,)), - (12, 12, (4, 3), (3, 4)), - (12, 12, (4, 3), (12,)), - (60, 60, (5, 4, 3), (60,)), - (8, 8, (8,), (4, 2)), - (60, 60, (5, 3, 4), (3, 4, 5)), + ((), ()), + ((10,), (10,)), + ((4, 3), (3, 4)), + ((4, 3), (12,)), + ((5, 4, 3), (60,)), + ((8,), (4, 2)), + ((5, 3, 4), (3, 4, 5)), ], ids=[ "0d_0d", @@ -578,8 +774,10 @@ def test_vdot(self, dtype, array_info): "3d_3d", ], ) - def test_vdot_complex(self, dtype, array_info): - size1, size2, shape1, shape2 = array_info + def test_vdot_complex(self, dtype, shape_pair): + shape1, shape2 = shape_pair + size1 = numpy.prod(shape1, dtype=int) + size2 = numpy.prod(shape2, dtype=int) x11 = numpy.random.uniform(-5, 5, size1) x12 = numpy.random.uniform(-5, 5, size1) x21 = numpy.random.uniform(-5, 5, size2) diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index fc9993642eb..26c7e9e3057 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -492,6 +492,7 @@ def test_1in_1out(func, data, usm_type): ), pytest.param("arctan2", [[-1, +1, +1, -1]], [[-1, -1, +1, +1]]), pytest.param("copysign", 
[0.0, 1.0, 2.0], [-1.0, 0.0, 1.0]), + pytest.param("cross", [1.0, 2.0, 3.0], [4.0, 5.0, 6.0]), # dpnp.dot has 3 different implementations based on input arrays dtype # checking all of them pytest.param("dot", [3.0, 4.0, 5.0], [1.0, 2.0, 3.0]), diff --git a/tests/third_party/cupy/linalg_tests/test_product.py b/tests/third_party/cupy/linalg_tests/test_product.py index e59b30dcd6e..1993165f86e 100644 --- a/tests/third_party/cupy/linalg_tests/test_product.py +++ b/tests/third_party/cupy/linalg_tests/test_product.py @@ -101,14 +101,9 @@ def test_dot_with_out(self, xp, dtype_a, dtype_b, dtype_c): } ) ) -@pytest.mark.usefixtures("allow_fall_back_on_numpy") -@testing.gpu class TestCrossProduct(unittest.TestCase): @testing.for_all_dtypes_combination(["dtype_a", "dtype_b"]) - # TODO: remove 'contiguous_check=False' once fixed in dpnp.cross() - @testing.numpy_cupy_allclose( - type_check=has_support_aspect64(), contiguous_check=False - ) + @testing.numpy_cupy_allclose(type_check=has_support_aspect64()) def test_cross(self, xp, dtype_a, dtype_b): if dtype_a == dtype_b == numpy.bool_: # cross does not support bool-bool inputs.
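A small usage sketch of the broadcasting and axis handling the updated dpnp.cross supports (shapes follow numpy.cross semantics; assumes a default SYCL device is available to dpnp):

    import dpnp as np

    a = np.ones((1, 3, 4, 5))   # vectors live along axis 1
    b = np.ones((2, 3, 1, 5))
    # Non-vector dimensions broadcast: (1, 4, 5) x (2, 1, 5) -> (2, 4, 5).
    # axisa/axisb select the vector axis of each input; axisc places the
    # resulting 3-vector axis in the output.
    c = np.cross(a, b, axisa=1, axisb=1, axisc=0)
    print(c.shape)   # (3, 2, 4, 5)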