From a7bcf9db9e8c7d29ab2ae60be8bd3b82a9dca068 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Wed, 3 May 2023 06:48:49 -0500 Subject: [PATCH 1/6] Rework transpose methods to call dpctl.tensor functions --- dpnp/backend/include/dpnp_iface_fptr.hpp | 1 - .../kernels/dpnp_krnl_manipulation.cpp | 24 +----- dpnp/dpnp_algo/dpnp_algo.pxd | 2 - dpnp/dpnp_algo/dpnp_algo_manipulation.pxi | 74 ----------------- dpnp/dpnp_array.py | 67 ++++++++++++---- dpnp/dpnp_iface_manipulation.py | 79 +++++++++++-------- dpnp/dpnp_iface_statistics.py | 3 +- tests/skipped_tests.tbl | 4 - tests/skipped_tests_gpu.tbl | 9 +-- .../cupy/linalg_tests/test_eigenvalue.py | 3 +- 10 files changed, 103 insertions(+), 163 deletions(-) diff --git a/dpnp/backend/include/dpnp_iface_fptr.hpp b/dpnp/backend/include/dpnp_iface_fptr.hpp index 653471fd1b5..d9dd2dc6b40 100644 --- a/dpnp/backend/include/dpnp_iface_fptr.hpp +++ b/dpnp/backend/include/dpnp_iface_fptr.hpp @@ -360,7 +360,6 @@ enum class DPNPFuncName : size_t DPNP_FN_TANH, /**< Used in numpy.tanh() impl */ DPNP_FN_TANH_EXT, /**< Used in numpy.tanh() impl, requires extra parameters */ DPNP_FN_TRANSPOSE, /**< Used in numpy.transpose() impl */ - DPNP_FN_TRANSPOSE_EXT, /**< Used in numpy.transpose() impl, requires extra parameters */ DPNP_FN_TRACE, /**< Used in numpy.trace() impl */ DPNP_FN_TRACE_EXT, /**< Used in numpy.trace() impl, requires extra parameters */ DPNP_FN_TRAPZ, /**< Used in numpy.trapz() impl */ diff --git a/dpnp/backend/kernels/dpnp_krnl_manipulation.cpp b/dpnp/backend/kernels/dpnp_krnl_manipulation.cpp index 8a122dbf728..2fc7832b6ba 100644 --- a/dpnp/backend/kernels/dpnp_krnl_manipulation.cpp +++ b/dpnp/backend/kernels/dpnp_krnl_manipulation.cpp @@ -1,5 +1,5 @@ //***************************************************************************** -// Copyright (c) 2016-2020, Intel Corporation +// Copyright (c) 2016-2023, Intel Corporation // All rights reserved. // // Redistribution and use in source and binary forms, with or without @@ -211,6 +211,7 @@ void dpnp_elemwise_transpose_c(void* array1_in, size, dep_event_vec_ref); DPCTLEvent_WaitAndThrow(event_ref); + DPCTLEvent_Delete(event_ref); } template @@ -222,17 +223,6 @@ void (*dpnp_elemwise_transpose_default_c)(void*, void*, size_t) = dpnp_elemwise_transpose_c<_DataType>; -template -DPCTLSyclEventRef (*dpnp_elemwise_transpose_ext_c)(DPCTLSyclQueueRef, - void*, - const shape_elem_type*, - const shape_elem_type*, - const shape_elem_type*, - size_t, - void*, - size_t, - const DPCTLEventVectorRef) = dpnp_elemwise_transpose_c<_DataType>; - void func_map_init_manipulation(func_map_t& fmap) { fmap[DPNPFuncName::DPNP_FN_REPEAT][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_repeat_default_c}; @@ -253,15 +243,5 @@ void func_map_init_manipulation(func_map_t& fmap) (void*)dpnp_elemwise_transpose_default_c}; fmap[DPNPFuncName::DPNP_FN_TRANSPOSE][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_elemwise_transpose_default_c}; - - fmap[DPNPFuncName::DPNP_FN_TRANSPOSE_EXT][eft_INT][eft_INT] = {eft_INT, - (void*)dpnp_elemwise_transpose_ext_c}; - fmap[DPNPFuncName::DPNP_FN_TRANSPOSE_EXT][eft_LNG][eft_LNG] = {eft_LNG, - (void*)dpnp_elemwise_transpose_ext_c}; - fmap[DPNPFuncName::DPNP_FN_TRANSPOSE_EXT][eft_FLT][eft_FLT] = {eft_FLT, - (void*)dpnp_elemwise_transpose_ext_c}; - fmap[DPNPFuncName::DPNP_FN_TRANSPOSE_EXT][eft_DBL][eft_DBL] = {eft_DBL, - (void*)dpnp_elemwise_transpose_ext_c}; - return; } diff --git a/dpnp/dpnp_algo/dpnp_algo.pxd b/dpnp/dpnp_algo/dpnp_algo.pxd index 09af5667f8c..d6b5429669f 100644 --- a/dpnp/dpnp_algo/dpnp_algo.pxd +++ b/dpnp/dpnp_algo/dpnp_algo.pxd @@ -339,7 +339,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na DPNP_FN_TRACE DPNP_FN_TRACE_EXT DPNP_FN_TRANSPOSE - DPNP_FN_TRANSPOSE_EXT DPNP_FN_TRAPZ DPNP_FN_TRAPZ_EXT DPNP_FN_TRI @@ -555,7 +554,6 @@ cpdef dpnp_descriptor dpnp_subtract(dpnp_descriptor x1_obj, dpnp_descriptor x2_o Array manipulation routines """ cpdef dpnp_descriptor dpnp_repeat(dpnp_descriptor array1, repeats, axes=*) -cpdef dpnp_descriptor dpnp_transpose(dpnp_descriptor array1, axes=*) """ diff --git a/dpnp/dpnp_algo/dpnp_algo_manipulation.pxi b/dpnp/dpnp_algo/dpnp_algo_manipulation.pxi index 3e27af363c3..9176092b7c0 100644 --- a/dpnp/dpnp_algo/dpnp_algo_manipulation.pxi +++ b/dpnp/dpnp_algo/dpnp_algo_manipulation.pxi @@ -42,21 +42,11 @@ __all__ += [ "dpnp_expand_dims", "dpnp_repeat", "dpnp_reshape", - "dpnp_transpose", "dpnp_squeeze", ] # C function pointer to the C library template functions -ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_custom_elemwise_transpose_1in_1out_t)(c_dpctl.DPCTLSyclQueueRef, - void * , - shape_elem_type * , - shape_elem_type * , - shape_elem_type * , - size_t, - void * , - size_t, - const c_dpctl.DPCTLEventVectorRef) ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_repeat_t)(c_dpctl.DPCTLSyclQueueRef, const void *, void * , const size_t , const size_t, const c_dpctl.DPCTLEventVectorRef) @@ -232,70 +222,6 @@ cpdef utils.dpnp_descriptor dpnp_reshape(utils.dpnp_descriptor array1, newshape, copy_when_nondefault_queue=False) -cpdef utils.dpnp_descriptor dpnp_transpose(utils.dpnp_descriptor array1, axes=None): - cdef shape_type_c input_shape = array1.shape - cdef size_t input_shape_size = array1.ndim - cdef shape_type_c result_shape = shape_type_c(input_shape_size, 1) - - cdef shape_type_c permute_axes - if axes is None: - """ - template to do transpose a tensor - input_shape=[2, 3, 4] - permute_axes=[2, 1, 0] - after application `permute_axes` to `input_shape` result: - result_shape=[4, 3, 2] - - 'do nothing' axes variable is `permute_axes=[0, 1, 2]` - - test: pytest tests/third_party/cupy/manipulation_tests/test_transpose.py::TestTranspose::test_external_transpose_all - """ - permute_axes = list(reversed([i for i in range(input_shape_size)])) - else: - permute_axes = utils.normalize_axis(axes, input_shape_size) - - for i in range(input_shape_size): - """ construct output shape """ - result_shape[i] = input_shape[permute_axes[i]] - - # convert string type names (array.dtype) to C enum DPNPFuncType - cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(array1.dtype) - - # get the FPTR data structure - cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_TRANSPOSE_EXT, param1_type, param1_type) - - array1_obj = array1.get_array() - - # ceate result array with type given by FPTR data - cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape, - kernel_data.return_type, - None, - device=array1_obj.sycl_device, - usm_type=array1_obj.usm_type, - sycl_queue=array1_obj.sycl_queue) - result_sycl_queue = result.get_array().sycl_queue - - cdef c_dpctl.SyclQueue q = result_sycl_queue - cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref() - - cdef fptr_custom_elemwise_transpose_1in_1out_t func = kernel_data.ptr - # call FPTR function - cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref, - array1.get_data(), - input_shape.data(), - result_shape.data(), - permute_axes.data(), - input_shape_size, - result.get_data(), - array1.size, - NULL) # dep_events_ref - - with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref) - c_dpctl.DPCTLEvent_Delete(event_ref) - - return result - - cpdef utils.dpnp_descriptor dpnp_squeeze(utils.dpnp_descriptor in_array, axis): cdef shape_type_c shape_list if axis is None: diff --git a/dpnp/dpnp_array.py b/dpnp/dpnp_array.py index 5741ea0fa14..9d8d4ecacc3 100644 --- a/dpnp/dpnp_array.py +++ b/dpnp/dpnp_array.py @@ -99,15 +99,8 @@ def get_array(self): @property def T(self): - """Shape-reversed view of the array. - - If ndim < 2, then this is just a reference to the array itself. - - """ - if self.ndim < 2: - return self - else: - return dpnp.transpose(self) + """View of the transposed array.""" + return self.transpose() def to_device(self, target_device): """ @@ -1000,15 +993,61 @@ def take(self, indices, axis=None, out=None, mode='raise'): def transpose(self, *axes): """ - Returns a view of the array with axes permuted. + Returns a view of the array with axes transposed. - .. seealso:: - :obj:`dpnp.transpose` for full documentation, - :meth:`numpy.ndarray.reshape` + For full documentation refer to :obj:`numpy.ndarray.transpose`. + + Returns + ------- + y : dpnp.ndarray + View of the array with its axes suitably permuted. + See Also + -------- + :obj:`dpnp.transpose` : Equivalent function. + :obj:`dpnp.ndarray.ndarray.T` : Array property returning the array transposed. + :obj:`dpnp.ndarray.reshape` : Give a new shape to an array without changing its data. + + Examples + -------- + >>> import dpnp as dp + >>> a = dp.array([[1, 2], [3, 4]]) + >>> a + array([[1, 2], + [3, 4]]) + >>> a.transpose() + array([[1, 3], + [2, 4]]) + >>> a.transpose((1, 0)) + array([[1, 3], + [2, 4]]) + >>> a.transpose(1, 0) + array([[1, 3], + [2, 4]]) + + >>> a = dp.array([1, 2, 3, 4]) + >>> a + array([1, 2, 3, 4]) + >>> a.transpose() + array([1, 2, 3, 4]) + """ - return dpnp.transpose(self, axes) + ndim = self.ndim + if ndim < 2: + return self + + res = self.__new__(dpnp_array) + if ndim == 2: + res._array_obj = self._array_obj.T + else: + if len(axes) == 0 or len(axes) == 1 and axes[0] is None: + # self.transpose().shape == self.shape[::-1] + # self.transpose(None).shape == self.shape[::-1] + axes = tuple((ndim - x - 1) for x in range(ndim)) + + res._array_obj = dpt.permute_dims(self._array_obj, axes) + return res def var(self, axis=None, dtype=None, out=None, ddof=0, keepdims=False): """ diff --git a/dpnp/dpnp_iface_manipulation.py b/dpnp/dpnp_iface_manipulation.py index 567661bdb57..4483d4f3faa 100644 --- a/dpnp/dpnp_iface_manipulation.py +++ b/dpnp/dpnp_iface_manipulation.py @@ -678,54 +678,65 @@ def swapaxes(x1, axis1, axis2): return call_origin(numpy.swapaxes, x1, axis1, axis2) -def transpose(x1, axes=None): +def transpose(a, axes=None): """ - Reverse or permute the axes of an array; returns the modified array. + Returns an array with axes transposed. For full documentation refer to :obj:`numpy.transpose`. + Returns + ------- + y : dpnp.ndarray + `a` with its axes permuted. A view is returned whenever possible. + Limitations ----------- - Input array is supported as :obj:`dpnp.ndarray`. - Otherwise the function will be executed sequentially on CPU. - Value of the parameter ``axes`` likely to be replaced with ``None``. - Input array data types are limited by supported DPNP :ref:`Data types`. + Input array is supported as either :class:`dpnp.ndarray` + or :class:`dpctl.tensor.usm_ndarray`. See Also -------- + :obj:`dpnp.ndarray.transpose` : Equivalent method. :obj:`dpnp.moveaxis` : Move array axes to new positions. :obj:`dpnp.argsort` : Returns the indices that would sort an array. Examples -------- - >>> import dpnp as np - >>> x = np.arange(4).reshape((2,2)) - >>> x.shape - (2, 2) - >>> [i for i in x] - [0, 1, 2, 3] - >>> out = np.transpose(x) - >>> out.shape - (2, 2) - >>> [i for i in out] - [0, 2, 1, 3] - - """ - - x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False) - if x1_desc: - if axes is not None: - if not any(axes): - """ - pytest tests/third_party/cupy/manipulation_tests/test_transpose.py - """ - axes = None - - result = dpnp_transpose(x1_desc, axes).get_pyobj() - - return result - - return call_origin(numpy.transpose, x1, axes=axes) + >>> import dpnp as dp + >>> a = dp.array([[1, 2], [3, 4]]) + >>> a + array([[1, 2], + [3, 4]]) + >>> dp.transpose(a) + array([[1, 3], + [2, 4]]) + + >>> a = dp.array([1, 2, 3, 4]) + >>> a + array([1, 2, 3, 4]) + >>> dp.transpose(a) + array([1, 2, 3, 4]) + + >>> a = dp.ones((1, 2, 3)) + >>> dp.transpose(a, (1, 0, 2)).shape + (2, 1, 3) + + >>> a = dp.ones((2, 3, 4, 5)) + >>> dp.transpose(a).shape + (5, 4, 3, 2) + + """ + + if isinstance(a, dpnp_array): + array = a + elif isinstance(a, dpt.usm_ndarray): + array = dpnp_array._create_from_usm_ndarray(a.get_array()) + else: + raise TypeError("An array must be any of supported type, but got {}".format(type(a))) + + if axes is None: + return array.transpose() + return array.transpose(*axes) def unique(x1, **kwargs): diff --git a/dpnp/dpnp_iface_statistics.py b/dpnp/dpnp_iface_statistics.py index 966a7214269..35e48e1ea94 100644 --- a/dpnp/dpnp_iface_statistics.py +++ b/dpnp/dpnp_iface_statistics.py @@ -290,8 +290,7 @@ def cov(x1, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights= pass else: if not rowvar and x1.shape[0] != 1: - x1 = x1.get_array() if isinstance(x1, dpnp_array) else x1 - x1 = dpnp_array._create_from_usm_ndarray(x1.mT) + x1 = x1.T if not x1.dtype in (dpnp.float32, dpnp.float64): x1 = dpnp.astype(x1, dpnp.default_float_type(sycl_queue=x1.sycl_queue)) diff --git a/tests/skipped_tests.tbl b/tests/skipped_tests.tbl index a41b881ae7b..cc4bdd57c95 100644 --- a/tests/skipped_tests.tbl +++ b/tests/skipped_tests.tbl @@ -234,8 +234,6 @@ tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAn tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype_strides_swapped tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype_type_c_contiguous_no_copy tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype_type_f_contiguous_no_copy -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_transposed_fill -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_transposed_flatten tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_flatten tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_flatten_copied tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_isinstance_numpy_copy @@ -837,7 +835,6 @@ tests/third_party/cupy/math_tests/test_rounding.py::TestRounding::test_rint_nega tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all2 tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all_keepdims -tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all_transposed tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all_transposed2 tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_axes tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_axes2 @@ -886,7 +883,6 @@ tests/third_party/cupy/math_tests/test_sumprod.py::TestNansumNanprodLong_param_1 tests/third_party/cupy/math_tests/test_sumprod.py::TestNansumNanprodLong_param_9_{axis=0, func='nanprod', keepdims=True, shape=(2, 3, 4), transpose_axes=False}::test_nansum_all tests/third_party/cupy/math_tests/test_sumprod.py::TestNansumNanprodLong_param_9_{axis=0, func='nanprod', keepdims=True, shape=(2, 3, 4), transpose_axes=False}::test_nansum_axis_transposed tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all2 -tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all_transposed2 tests/third_party/cupy/math_tests/test_trigonometric.py::TestUnwrap::test_unwrap_1dim tests/third_party/cupy/math_tests/test_trigonometric.py::TestUnwrap::test_unwrap_1dim_with_discont tests/third_party/cupy/math_tests/test_trigonometric.py::TestUnwrap::test_unwrap_2dim_with_axis diff --git a/tests/skipped_tests_gpu.tbl b/tests/skipped_tests_gpu.tbl index 36b17d5edbc..e1cfd00028d 100644 --- a/tests/skipped_tests_gpu.tbl +++ b/tests/skipped_tests_gpu.tbl @@ -142,7 +142,6 @@ tests/third_party/cupy/random_tests/test_distributions.py::TestDistributionsPois tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray([[i, i] for i in x])] -tests/test_arraymanipulation.py::TestConcatenate::test_concatenate tests/test_histograms.py::TestHistogram::test_density tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.astype(dpnp.asarray(x), dpnp.int8)] @@ -376,8 +375,6 @@ tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAn tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype_strides_swapped tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype_type_c_contiguous_no_copy tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype_type_f_contiguous_no_copy -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_transposed_fill -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_transposed_flatten tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_flatten_copied tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_flatten tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_isinstance_numpy_copy @@ -783,15 +780,11 @@ tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumLarge_param_9_{opt tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumUnaryOperationWithScalar::test_scalar_float tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumUnaryOperationWithScalar::test_scalar_int tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test_invalid_sub1 -tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_dot tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_invlarge tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_large tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_of_two tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_dot_vec2 tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_multidim_vdot -tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_dot -tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_dot_with_out -tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_higher_order_inner tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot_with_int_axes tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot_with_list_axes @@ -1005,7 +998,7 @@ tests/third_party/cupy/math_tests/test_rounding.py::TestRounding::test_rint tests/third_party/cupy/math_tests/test_rounding.py::TestRounding::test_rint_negative tests/third_party/cupy/math_tests/test_rounding.py::TestRounding::test_round_ tests/third_party/cupy/math_tests/test_rounding.py::TestRounding::test_trunc -tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all_transposed + tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all_transposed2 tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_axes tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_axes2 diff --git a/tests/third_party/cupy/linalg_tests/test_eigenvalue.py b/tests/third_party/cupy/linalg_tests/test_eigenvalue.py index fe577e32b28..1a5309e481c 100644 --- a/tests/third_party/cupy/linalg_tests/test_eigenvalue.py +++ b/tests/third_party/cupy/linalg_tests/test_eigenvalue.py @@ -8,8 +8,7 @@ def _get_hermitian(xp, a, UPLO): - # TODO: fix this, currently dpnp.transpose() doesn't support complex types - # and no dpnp_array.swapaxes() + # TODO: remove wrapping, but now there is no dpnp_array.swapaxes() a = _wrap_as_numpy_array(xp, a) _xp = numpy From 7c0d400674ccedcd2fbbea834cda9d0c28004ece Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Wed, 3 May 2023 11:11:17 -0500 Subject: [PATCH 2/6] Reuse dpctl.tensor.reshape --- dpnp/dpnp_array.py | 44 ++++++++------------ dpnp/dpnp_iface_manipulation.py | 73 +++++++++++++++++++++++++++++---- 2 files changed, 82 insertions(+), 35 deletions(-) diff --git a/dpnp/dpnp_array.py b/dpnp/dpnp_array.py index 9d8d4ecacc3..0a9b56a610a 100644 --- a/dpnp/dpnp_array.py +++ b/dpnp/dpnp_array.py @@ -845,14 +845,21 @@ def prod(self, axis=None, dtype=None, out=None, keepdims=False, initial=None, wh # 'real', # 'repeat', - def reshape(self, d0, *dn, order=b'C'): + def reshape(self, *sh, **kwargs): """ Returns an array containing the same data with a new shape. - Refer to `dpnp.reshape` for full documentation. + For full documentation refer to :obj:`numpy.ndarray.reshape`. - .. seealso:: - :meth:`numpy.ndarray.reshape` + Returns + ------- + y : dpnp.ndarray + This will be a new view object if possible; + otherwise, it will be a copy. + + See Also + -------- + :obj:`dpnp.reshape` : Equivalent function. Notes ----- @@ -863,17 +870,9 @@ def reshape(self, d0, *dn, order=b'C'): """ - if dn: - if not isinstance(d0, int): - msg_tmpl = "'{}' object cannot be interpreted as an integer" - raise TypeError(msg_tmpl.format(type(d0).__name__)) - shape = [d0, *dn] - else: - shape = d0 - - shape_tup = dpnp.dpnp_utils._object_to_tuple(shape) - - return dpnp.reshape(self, shape_tup) + if len(sh) == 1: + sh = sh[0] + return dpnp.reshape(self, sh, **kwargs) # 'resize', @@ -915,14 +914,7 @@ def shape(self, newshape): """ - dpnp.reshape(self, newshape) - - @property - def shape(self): - """ - """ - - return self._array_obj.shape + dpnp.reshape(self, newshape=newshape) @property def size(self): @@ -1004,9 +996,9 @@ def transpose(self, *axes): See Also -------- - :obj:`dpnp.transpose` : Equivalent function. - :obj:`dpnp.ndarray.ndarray.T` : Array property returning the array transposed. - :obj:`dpnp.ndarray.reshape` : Give a new shape to an array without changing its data. + :obj:`dpnp.transpose` : Equivalent function. + :obj:`dpnp.ndarray.ndarray.T` : Array property returning the array transposed. + :obj:`dpnp.ndarray.reshape` : Give a new shape to an array without changing its data. Examples -------- diff --git a/dpnp/dpnp_iface_manipulation.py b/dpnp/dpnp_iface_manipulation.py index 4483d4f3faa..31fefaec9dd 100644 --- a/dpnp/dpnp_iface_manipulation.py +++ b/dpnp/dpnp_iface_manipulation.py @@ -513,26 +513,81 @@ def repeat(x1, repeats, axis=None): return call_origin(numpy.repeat, x1, repeats, axis) -def reshape(x1, newshape, order='C'): +def reshape(x, /, newshape, order='C', copy=None): """ Gives a new shape to an array without changing its data. For full documentation refer to :obj:`numpy.reshape`. + Parameters + ---------- + x : {dpnp_array, usm_ndarray} + Array to be reshaped. + newshape : int or tuple of ints + The new shape should be compatible with the original shape. If + an integer, then the result will be a 1-D array of that length. + One shape dimension can be -1. In this case, the value is + inferred from the length of the array and remaining dimensions. + order : {'C', 'F'}, optional + Read the elements of `x` using this index order, and place the + elements into the reshaped array using this index order. 'C' + means to read / write the elements using C-like index order, + with the last axis index changing fastest, back to the first + axis index changing slowest. 'F' means to read / write the + elements using Fortran-like index order, with the first index + changing fastest, and the last index changing slowest. Note that + the 'C' and 'F' options take no account of the memory layout of + the underlying array, and only refer to the order of indexing. + copy : bool, optional + Boolean indicating whether or not to copy the input array. + If ``True``, the result array will always be a copy of input `x`. + If ``False``, the result array can never be a copy + and a ValueError exception will be raised in case the copy is necessary. + If ``None``, the result array will reuse existing memory buffer of `x` + if possible and copy otherwise. Default: None. + + Returns + ------- + y : dpnp.ndarray + This will be a new view object if possible; otherwise, it will + be a copy. Note there is no guarantee of the *memory layout* (C- or + Fortran- contiguous) of the returned array. + Limitations ----------- - Only 'C' order is supported. + Parameter `order` is supported only with values ``"C"`` and ``"F"``. + + See Also + -------- + :obj:`dpnp.ndarray.reshape` : Equivalent method. + + Examples + -------- + >>> import dpnp as dp + >>> a = dp.array([[1, 2, 3], [4, 5, 6]]) + >>> dp.reshape(a, 6) + array([1, 2, 3, 4, 5, 6]) + >>> dp.reshape(a, 6, order='F') + array([1, 4, 2, 5, 3, 6]) + + >>> dp.reshape(a, (3, -1)) # the unspecified value is inferred to be 2 + array([[1, 2], + [3, 4], + [5, 6]]) """ - x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False) - if x1_desc: - if order != 'C': - pass - else: - return dpnp_reshape(x1_desc, newshape, order).get_pyobj() + if newshape is None: + newshape = x.shape + + if order is None: + order = 'C' + elif not order in "cfCF": + raise ValueError(f"order must be one of 'C' or 'F' (got {order})") - return call_origin(numpy.reshape, x1, newshape, order) + usm_arr = dpnp.get_usm_ndarray(x) + usm_arr = dpt.reshape(usm_arr, shape=newshape, order=order, copy=copy) + return dpnp_array._create_from_usm_ndarray(usm_arr) def rollaxis(x1, axis, start=0): From 10d4826d9586659ac28e7617e41a75f5bae469eb Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Thu, 4 May 2023 09:44:51 -0500 Subject: [PATCH 3/6] added dpnp.shape() and unmuted more tests --- dpnp/dpnp_array.py | 11 ++-- dpnp/dpnp_iface_manipulation.py | 44 +++++++++++++ tests/skipped_tests.tbl | 16 +---- tests/skipped_tests_gpu.tbl | 18 ++--- .../cupy/manipulation_tests/test_shape.py | 66 +++++++++++++++---- 5 files changed, 111 insertions(+), 44 deletions(-) diff --git a/dpnp/dpnp_array.py b/dpnp/dpnp_array.py index 0a9b56a610a..c8a910efcf2 100644 --- a/dpnp/dpnp_array.py +++ b/dpnp/dpnp_array.py @@ -1011,9 +1011,6 @@ def transpose(self, *axes): array([[1, 3], [2, 4]]) >>> a.transpose((1, 0)) - array([[1, 3], - [2, 4]]) - >>> a.transpose(1, 0) array([[1, 3], [2, 4]]) @@ -1029,11 +1026,15 @@ def transpose(self, *axes): if ndim < 2: return self + axes_len = len(axes) + if axes_len == 1 and isinstance(axes[0], tuple): + axes = axes[0] + res = self.__new__(dpnp_array) - if ndim == 2: + if ndim == 2 and axes_len == 0: res._array_obj = self._array_obj.T else: - if len(axes) == 0 or len(axes) == 1 and axes[0] is None: + if len(axes) == 0 or axes[0] is None: # self.transpose().shape == self.shape[::-1] # self.transpose(None).shape == self.shape[::-1] axes = tuple((ndim - x - 1) for x in range(ndim)) diff --git a/dpnp/dpnp_iface_manipulation.py b/dpnp/dpnp_iface_manipulation.py index 31fefaec9dd..eb12d23e65c 100644 --- a/dpnp/dpnp_iface_manipulation.py +++ b/dpnp/dpnp_iface_manipulation.py @@ -68,6 +68,7 @@ "repeat", "reshape", "rollaxis", + "shape", "squeeze", "stack", "swapaxes", @@ -638,6 +639,49 @@ def rollaxis(x1, axis, start=0): return call_origin(numpy.rollaxis, x1, axis, start) +def shape(a): + """ + Return the shape of an array. + + For full documentation refer to :obj:`numpy.shape`. + + Parameters + ---------- + a : array_like + Input array. + + Returns + ------- + shape : tuple of ints + The elements of the shape tuple give the lengths of the + corresponding array dimensions. + + See Also + -------- + len : ``len(a)`` is equivalent to ``np.shape(a)[0]`` for N-D arrays with + ``N>=1``. + :obj:`dpnp.ndarray.shape` : Equivalent array method. + + Examples + -------- + >>> import dpnp as dp + >>> dp.shape(dp.eye(3)) + (3, 3) + >>> dp.shape([[1, 3]]) + (1, 2) + >>> dp.shape([0]) + (1,) + >>> dp.shape(0) + () + + """ + + if dpnp.is_supported_array_type(a): + return a.shape + else: + return numpy.shape(a) + + def squeeze(x1, axis=None): """ Remove single-dimensional entries from the shape of an array. diff --git a/tests/skipped_tests.tbl b/tests/skipped_tests.tbl index cc4bdd57c95..9c9a0745760 100644 --- a/tests/skipped_tests.tbl +++ b/tests/skipped_tests.tbl @@ -722,19 +722,9 @@ tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_ tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_7_{order_init='C', order_reshape='A', shape_in_out=((6,), (2, 3))}::test_reshape_contiguity tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_8_{order_init='C', order_reshape='A', shape_in_out=((3, 3, 3), (9, 3))}::test_reshape_contiguity tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_9_{order_init='C', order_reshape='c', shape_in_out=((2, 3), (1, 6, 1))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshape::test_external_reshape -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshape::test_nocopy_reshape -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshape::test_nocopy_reshape_with_order -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshape::test_reshape2 -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshape::test_reshape_strides -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshape::test_reshape_with_unknown_dimension -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshape::test_transposed_reshape2 -tests/third_party/cupy/manipulation_tests/test_shape.py::TestShape_param_0_{shape=(2, 3)}::test_shape -tests/third_party/cupy/manipulation_tests/test_shape.py::TestShape_param_0_{shape=(2, 3)}::test_shape_list -tests/third_party/cupy/manipulation_tests/test_shape.py::TestShape_param_1_{shape=()}::test_shape -tests/third_party/cupy/manipulation_tests/test_shape.py::TestShape_param_1_{shape=()}::test_shape_list -tests/third_party/cupy/manipulation_tests/test_shape.py::TestShape_param_2_{shape=(4,)}::test_shape -tests/third_party/cupy/manipulation_tests/test_shape.py::TestShape_param_2_{shape=(4,)}::test_shape_list +tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshape::test_reshape_zerosize +tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshape::test_reshape_zerosize2 + tests/third_party/cupy/manipulation_tests/test_tiling.py::TestRepeatRepeatsNdarray::test_func tests/third_party/cupy/manipulation_tests/test_tiling.py::TestRepeatRepeatsNdarray::test_method tests/third_party/cupy/manipulation_tests/test_tiling.py::TestTileFailure_param_0_{reps=-1}::test_tile_failure diff --git a/tests/skipped_tests_gpu.tbl b/tests/skipped_tests_gpu.tbl index e1cfd00028d..bf22e0fad8c 100644 --- a/tests/skipped_tests_gpu.tbl +++ b/tests/skipped_tests_gpu.tbl @@ -243,7 +243,7 @@ tests/third_party/cupy/manipulation_tests/test_basic.py::TestCopytoFromScalar_pa tests/third_party/cupy/manipulation_tests/test_basic.py::TestCopytoFromScalar_param_32_{dst_shape=(2, 2), src=True}::test_copyto_where tests/third_party/cupy/manipulation_tests/test_basic.py::TestCopytoFromScalar_param_33_{dst_shape=(2, 2), src=False}::test_copyto_where tests/third_party/cupy/manipulation_tests/test_basic.py::TestCopytoFromScalar_param_34_{dst_shape=(2, 2), src=(1+1j)}::test_copyto_where -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshape::test_reshape_zerosize + tests/third_party/cupy/math_tests/test_sumprod.py::TestCumprod::test_cumprod_out_noncontiguous tests/third_party/cupy/math_tests/test_sumprod.py::TestCumsum_param_0_{axis=0}::test_cumsum_axis_out_noncontiguous tests/third_party/cupy/math_tests/test_sumprod.py::TestCumsum_param_0_{axis=0}::test_cumsum_out_noncontiguous @@ -883,19 +883,9 @@ tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_ tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_7_{order_init='C', order_reshape='A', shape_in_out=((6,), (2, 3))}::test_reshape_contiguity tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_8_{order_init='C', order_reshape='A', shape_in_out=((3, 3, 3), (9, 3))}::test_reshape_contiguity tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_9_{order_init='C', order_reshape='c', shape_in_out=((2, 3), (1, 6, 1))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshape::test_external_reshape -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshape::test_nocopy_reshape -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshape::test_nocopy_reshape_with_order -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshape::test_reshape2 -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshape::test_reshape_strides -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshape::test_reshape_with_unknown_dimension -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshape::test_transposed_reshape2 -tests/third_party/cupy/manipulation_tests/test_shape.py::TestShape_param_0_{shape=(2, 3)}::test_shape -tests/third_party/cupy/manipulation_tests/test_shape.py::TestShape_param_0_{shape=(2, 3)}::test_shape_list -tests/third_party/cupy/manipulation_tests/test_shape.py::TestShape_param_1_{shape=()}::test_shape -tests/third_party/cupy/manipulation_tests/test_shape.py::TestShape_param_1_{shape=()}::test_shape_list -tests/third_party/cupy/manipulation_tests/test_shape.py::TestShape_param_2_{shape=(4,)}::test_shape -tests/third_party/cupy/manipulation_tests/test_shape.py::TestShape_param_2_{shape=(4,)}::test_shape_list +tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshape::test_reshape_zerosize +tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshape::test_reshape_zerosize2 + tests/third_party/cupy/manipulation_tests/test_tiling.py::TestRepeatRepeatsNdarray::test_func tests/third_party/cupy/manipulation_tests/test_tiling.py::TestRepeatRepeatsNdarray::test_method tests/third_party/cupy/manipulation_tests/test_tiling.py::TestTileFailure_param_0_{reps=-1}::test_tile_failure diff --git a/tests/third_party/cupy/manipulation_tests/test_shape.py b/tests/third_party/cupy/manipulation_tests/test_shape.py index b80437dba89..33d094bf827 100644 --- a/tests/third_party/cupy/manipulation_tests/test_shape.py +++ b/tests/third_party/cupy/manipulation_tests/test_shape.py @@ -29,19 +29,19 @@ def test_shape_list(self): @testing.gpu class TestReshape(unittest.TestCase): - def test_reshape_strides(self): + def test_reshape_shapes(self): def func(xp): a = testing.shaped_arange((1, 1, 1, 2, 2), xp) - return a.strides - self.assertEqual(func(numpy), func(cupy)) + return a.shape + assert func(numpy) == func(cupy) def test_reshape2(self): def func(xp): a = xp.zeros((8,), dtype=xp.float32) - return a.reshape((1, 1, 1, 4, 1, 2)).strides - self.assertEqual(func(numpy), func(cupy)) + return a.reshape((1, 1, 1, 4, 1, 2)).shape + assert func(numpy) == func(cupy) - @testing.for_orders('CFA') + @testing.for_orders('CF') @testing.for_all_dtypes() @testing.numpy_cupy_array_equal() def test_nocopy_reshape(self, xp, dtype, order): @@ -50,7 +50,7 @@ def test_nocopy_reshape(self, xp, dtype, order): b[1] = 1 return a - @testing.for_orders('CFA') + @testing.for_orders('CF') @testing.for_all_dtypes() @testing.numpy_cupy_array_equal() def test_nocopy_reshape_with_order(self, xp, dtype, order): @@ -59,13 +59,13 @@ def test_nocopy_reshape_with_order(self, xp, dtype, order): b[1] = 1 return a - @testing.for_orders('CFA') + @testing.for_orders('CF') @testing.numpy_cupy_array_equal() def test_transposed_reshape2(self, xp, order): a = testing.shaped_arange((2, 3, 4), xp).transpose(2, 0, 1) return a.reshape(2, 3, 4, order=order) - @testing.for_orders('CFA') + @testing.for_orders('CF') @testing.numpy_cupy_array_equal() def test_reshape_with_unknown_dimension(self, xp, order): a = testing.shaped_arange((2, 3, 4), xp) @@ -95,17 +95,59 @@ def test_reshape_zerosize_invalid(self): with pytest.raises(ValueError): a.reshape(()) + @pytest.mark.skip("until dpctl gh-1197 is resolved") + def test_reshape_zerosize_invalid_unknown(self): + for xp in (numpy, cupy): + a = xp.zeros((0,)) + with pytest.raises(ValueError): + a.reshape((-1, 0)) + @testing.numpy_cupy_array_equal() def test_reshape_zerosize(self, xp): a = xp.zeros((0,)) - return a.reshape((0,)) - - @testing.for_orders('CFA') + b = a.reshape((0,)) + assert b.base is a + return b + + @testing.for_orders('CF') + @testing.numpy_cupy_array_equal(strides_check=True) + def test_reshape_zerosize2(self, xp, order): + a = xp.zeros((2, 0, 3)) + b = a.reshape((5, 0, 4), order=order) + assert b.base is a + return b + + @testing.for_orders('CF') @testing.numpy_cupy_array_equal() def test_external_reshape(self, xp, order): a = xp.zeros((8,), dtype=xp.float32) return xp.reshape(a, (1, 1, 1, 4, 1, 2), order=order) + def _test_ndim_limit(self, xp, ndim, dtype, order): + idx = [1]*ndim + idx[-1] = ndim + a = xp.ones(ndim, dtype=dtype) + a = a.reshape(idx, order=order) + assert a.ndim == ndim + return a + + @pytest.mark.skip("until dpctl gh-1196 is resolved") + @testing.for_orders('CF') + @testing.for_all_dtypes() + @testing.numpy_cupy_array_equal() + def test_ndim_limit1(self, xp, dtype, order): + # from cupy/cupy#4193 + a = self._test_ndim_limit(xp, 32, dtype, order) + return a + + @testing.for_orders('CF') + @testing.for_all_dtypes() + def test_ndim_limit2(self, dtype, order): + # from cupy/cupy#4193 + for xp in (numpy, cupy): + with pytest.raises(ValueError): + self._test_ndim_limit(xp, 33, dtype, order) + @testing.gpu class TestRavel(unittest.TestCase): From 98c7c3bb13001bc8bad745cb9dd529375de7e809 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Fri, 5 May 2023 08:47:08 -0500 Subject: [PATCH 4/6] fixed compiling issue & unmuted reshaper tests with order param --- .../kernels/dpnp_krnl_mathematical.cpp | 5 ++- dpnp/backend/src/dpnp_utils.hpp | 8 +++++ tests/skipped_tests.tbl | 36 ------------------- tests/skipped_tests_gpu.tbl | 36 ------------------- .../cupy/manipulation_tests/test_shape.py | 26 +++++++------- tests/third_party/cupy/testing/__init__.py | 2 +- 6 files changed, 27 insertions(+), 86 deletions(-) diff --git a/dpnp/backend/kernels/dpnp_krnl_mathematical.cpp b/dpnp/backend/kernels/dpnp_krnl_mathematical.cpp index cbcd191fae6..7ad8061dfb6 100644 --- a/dpnp/backend/kernels/dpnp_krnl_mathematical.cpp +++ b/dpnp/backend/kernels/dpnp_krnl_mathematical.cpp @@ -170,11 +170,14 @@ DPCTLSyclEventRef dpnp_elemwise_absolute_c(DPCTLSyclQueueRef q_ref, sycl::vec<_DataType_input, vec_sz> data_vec = sg.load(input_ptrT(&array1[start])); +#if (__SYCL_COMPILER_VERSION < __SYCL_COMPILER_VECTOR_ABS_CHANGED) // sycl::abs() returns unsigned integers only, so explicit casting to signed ones is required using result_absT = typename cl::sycl::detail::make_unsigned<_DataType_output>::type; sycl::vec<_DataType_output, vec_sz> res_vec = dpnp_vec_cast<_DataType_output, result_absT, vec_sz>(sycl::abs(data_vec)); - +#else + sycl::vec<_DataType_output, vec_sz> res_vec = sycl::abs(data_vec); +#endif sg.store(result_ptrT(&result[start]), res_vec); } else diff --git a/dpnp/backend/src/dpnp_utils.hpp b/dpnp/backend/src/dpnp_utils.hpp index 985d5a61494..717bb6ea7bf 100644 --- a/dpnp/backend/src/dpnp_utils.hpp +++ b/dpnp/backend/src/dpnp_utils.hpp @@ -49,6 +49,14 @@ #define __SYCL_COMPILER_VERSION_REQUIRED 20221102L #endif +/** + * Version of SYCL DPC++ 2023 compiler where a return type of sycl::abs() is changed + * from unsinged integer to signed one of input vector. + */ +#ifndef __SYCL_COMPILER_VECTOR_ABS_CHANGED +#define __SYCL_COMPILER_VECTOR_ABS_CHANGED 20230503L +#endif + /** * Version of Intel MKL at which transition to OneMKL release 2023.0.0 occurs. */ diff --git a/tests/skipped_tests.tbl b/tests/skipped_tests.tbl index 9c9a0745760..cc758d025fd 100644 --- a/tests/skipped_tests.tbl +++ b/tests/skipped_tests.tbl @@ -686,42 +686,6 @@ tests/third_party/cupy/manipulation_tests/test_shape.py::TestRavel::test_ravel2 tests/third_party/cupy/manipulation_tests/test_shape.py::TestRavel::test_ravel3 tests/third_party/cupy/manipulation_tests/test_shape.py::TestRavel::test_external_ravel tests/third_party/cupy/manipulation_tests/test_shape.py::TestRavel::test_ravel -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_0_{order_init='C', order_reshape='C', shape_in_out=((2, 3), (1, 6, 1))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_10_{order_init='C', order_reshape='c', shape_in_out=((6,), (2, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_11_{order_init='C', order_reshape='c', shape_in_out=((3, 3, 3), (9, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_12_{order_init='C', order_reshape='f', shape_in_out=((2, 3), (1, 6, 1))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_13_{order_init='C', order_reshape='f', shape_in_out=((6,), (2, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_14_{order_init='C', order_reshape='f', shape_in_out=((3, 3, 3), (9, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_15_{order_init='C', order_reshape='a', shape_in_out=((2, 3), (1, 6, 1))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_16_{order_init='C', order_reshape='a', shape_in_out=((6,), (2, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_17_{order_init='C', order_reshape='a', shape_in_out=((3, 3, 3), (9, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_18_{order_init='F', order_reshape='C', shape_in_out=((2, 3), (1, 6, 1))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_19_{order_init='F', order_reshape='C', shape_in_out=((6,), (2, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_1_{order_init='C', order_reshape='C', shape_in_out=((6,), (2, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_20_{order_init='F', order_reshape='C', shape_in_out=((3, 3, 3), (9, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_21_{order_init='F', order_reshape='F', shape_in_out=((2, 3), (1, 6, 1))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_22_{order_init='F', order_reshape='F', shape_in_out=((6,), (2, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_23_{order_init='F', order_reshape='F', shape_in_out=((3, 3, 3), (9, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_24_{order_init='F', order_reshape='A', shape_in_out=((2, 3), (1, 6, 1))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_25_{order_init='F', order_reshape='A', shape_in_out=((6,), (2, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_26_{order_init='F', order_reshape='A', shape_in_out=((3, 3, 3), (9, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_27_{order_init='F', order_reshape='c', shape_in_out=((2, 3), (1, 6, 1))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_28_{order_init='F', order_reshape='c', shape_in_out=((6,), (2, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_29_{order_init='F', order_reshape='c', shape_in_out=((3, 3, 3), (9, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_2_{order_init='C', order_reshape='C', shape_in_out=((3, 3, 3), (9, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_30_{order_init='F', order_reshape='f', shape_in_out=((2, 3), (1, 6, 1))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_31_{order_init='F', order_reshape='f', shape_in_out=((6,), (2, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_32_{order_init='F', order_reshape='f', shape_in_out=((3, 3, 3), (9, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_33_{order_init='F', order_reshape='a', shape_in_out=((2, 3), (1, 6, 1))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_34_{order_init='F', order_reshape='a', shape_in_out=((6,), (2, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_35_{order_init='F', order_reshape='a', shape_in_out=((3, 3, 3), (9, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_3_{order_init='C', order_reshape='F', shape_in_out=((2, 3), (1, 6, 1))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_4_{order_init='C', order_reshape='F', shape_in_out=((6,), (2, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_5_{order_init='C', order_reshape='F', shape_in_out=((3, 3, 3), (9, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_6_{order_init='C', order_reshape='A', shape_in_out=((2, 3), (1, 6, 1))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_7_{order_init='C', order_reshape='A', shape_in_out=((6,), (2, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_8_{order_init='C', order_reshape='A', shape_in_out=((3, 3, 3), (9, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_9_{order_init='C', order_reshape='c', shape_in_out=((2, 3), (1, 6, 1))}::test_reshape_contiguity tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshape::test_reshape_zerosize tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshape::test_reshape_zerosize2 diff --git a/tests/skipped_tests_gpu.tbl b/tests/skipped_tests_gpu.tbl index bf22e0fad8c..f9a298327db 100644 --- a/tests/skipped_tests_gpu.tbl +++ b/tests/skipped_tests_gpu.tbl @@ -847,42 +847,6 @@ tests/third_party/cupy/manipulation_tests/test_shape.py::TestRavel::test_ravel2 tests/third_party/cupy/manipulation_tests/test_shape.py::TestRavel::test_ravel3 tests/third_party/cupy/manipulation_tests/test_shape.py::TestRavel::test_external_ravel tests/third_party/cupy/manipulation_tests/test_shape.py::TestRavel::test_ravel -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_0_{order_init='C', order_reshape='C', shape_in_out=((2, 3), (1, 6, 1))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_10_{order_init='C', order_reshape='c', shape_in_out=((6,), (2, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_11_{order_init='C', order_reshape='c', shape_in_out=((3, 3, 3), (9, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_12_{order_init='C', order_reshape='f', shape_in_out=((2, 3), (1, 6, 1))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_13_{order_init='C', order_reshape='f', shape_in_out=((6,), (2, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_14_{order_init='C', order_reshape='f', shape_in_out=((3, 3, 3), (9, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_15_{order_init='C', order_reshape='a', shape_in_out=((2, 3), (1, 6, 1))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_16_{order_init='C', order_reshape='a', shape_in_out=((6,), (2, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_17_{order_init='C', order_reshape='a', shape_in_out=((3, 3, 3), (9, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_18_{order_init='F', order_reshape='C', shape_in_out=((2, 3), (1, 6, 1))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_19_{order_init='F', order_reshape='C', shape_in_out=((6,), (2, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_1_{order_init='C', order_reshape='C', shape_in_out=((6,), (2, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_20_{order_init='F', order_reshape='C', shape_in_out=((3, 3, 3), (9, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_21_{order_init='F', order_reshape='F', shape_in_out=((2, 3), (1, 6, 1))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_22_{order_init='F', order_reshape='F', shape_in_out=((6,), (2, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_23_{order_init='F', order_reshape='F', shape_in_out=((3, 3, 3), (9, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_24_{order_init='F', order_reshape='A', shape_in_out=((2, 3), (1, 6, 1))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_25_{order_init='F', order_reshape='A', shape_in_out=((6,), (2, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_26_{order_init='F', order_reshape='A', shape_in_out=((3, 3, 3), (9, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_27_{order_init='F', order_reshape='c', shape_in_out=((2, 3), (1, 6, 1))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_28_{order_init='F', order_reshape='c', shape_in_out=((6,), (2, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_29_{order_init='F', order_reshape='c', shape_in_out=((3, 3, 3), (9, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_2_{order_init='C', order_reshape='C', shape_in_out=((3, 3, 3), (9, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_30_{order_init='F', order_reshape='f', shape_in_out=((2, 3), (1, 6, 1))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_31_{order_init='F', order_reshape='f', shape_in_out=((6,), (2, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_32_{order_init='F', order_reshape='f', shape_in_out=((3, 3, 3), (9, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_33_{order_init='F', order_reshape='a', shape_in_out=((2, 3), (1, 6, 1))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_34_{order_init='F', order_reshape='a', shape_in_out=((6,), (2, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_35_{order_init='F', order_reshape='a', shape_in_out=((3, 3, 3), (9, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_3_{order_init='C', order_reshape='F', shape_in_out=((2, 3), (1, 6, 1))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_4_{order_init='C', order_reshape='F', shape_in_out=((6,), (2, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_5_{order_init='C', order_reshape='F', shape_in_out=((3, 3, 3), (9, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_6_{order_init='C', order_reshape='A', shape_in_out=((2, 3), (1, 6, 1))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_7_{order_init='C', order_reshape='A', shape_in_out=((6,), (2, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_8_{order_init='C', order_reshape='A', shape_in_out=((3, 3, 3), (9, 3))}::test_reshape_contiguity -tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshapeOrder_param_9_{order_init='C', order_reshape='c', shape_in_out=((2, 3), (1, 6, 1))}::test_reshape_contiguity tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshape::test_reshape_zerosize tests/third_party/cupy/manipulation_tests/test_shape.py::TestReshape::test_reshape_zerosize2 diff --git a/tests/third_party/cupy/manipulation_tests/test_shape.py b/tests/third_party/cupy/manipulation_tests/test_shape.py index 33d094bf827..826c0e49011 100644 --- a/tests/third_party/cupy/manipulation_tests/test_shape.py +++ b/tests/third_party/cupy/manipulation_tests/test_shape.py @@ -28,6 +28,8 @@ def test_shape_list(self): @testing.gpu class TestReshape(unittest.TestCase): + # order = 'A' is out of support currently + _supported_orders = 'CF' def test_reshape_shapes(self): def func(xp): @@ -41,7 +43,7 @@ def func(xp): return a.reshape((1, 1, 1, 4, 1, 2)).shape assert func(numpy) == func(cupy) - @testing.for_orders('CF') + @testing.for_orders(_supported_orders) @testing.for_all_dtypes() @testing.numpy_cupy_array_equal() def test_nocopy_reshape(self, xp, dtype, order): @@ -50,7 +52,7 @@ def test_nocopy_reshape(self, xp, dtype, order): b[1] = 1 return a - @testing.for_orders('CF') + @testing.for_orders(_supported_orders) @testing.for_all_dtypes() @testing.numpy_cupy_array_equal() def test_nocopy_reshape_with_order(self, xp, dtype, order): @@ -59,13 +61,13 @@ def test_nocopy_reshape_with_order(self, xp, dtype, order): b[1] = 1 return a - @testing.for_orders('CF') + @testing.for_orders(_supported_orders) @testing.numpy_cupy_array_equal() def test_transposed_reshape2(self, xp, order): a = testing.shaped_arange((2, 3, 4), xp).transpose(2, 0, 1) return a.reshape(2, 3, 4, order=order) - @testing.for_orders('CF') + @testing.for_orders(_supported_orders) @testing.numpy_cupy_array_equal() def test_reshape_with_unknown_dimension(self, xp, order): a = testing.shaped_arange((2, 3, 4), xp) @@ -95,7 +97,6 @@ def test_reshape_zerosize_invalid(self): with pytest.raises(ValueError): a.reshape(()) - @pytest.mark.skip("until dpctl gh-1197 is resolved") def test_reshape_zerosize_invalid_unknown(self): for xp in (numpy, cupy): a = xp.zeros((0,)) @@ -109,7 +110,7 @@ def test_reshape_zerosize(self, xp): assert b.base is a return b - @testing.for_orders('CF') + @testing.for_orders(_supported_orders) @testing.numpy_cupy_array_equal(strides_check=True) def test_reshape_zerosize2(self, xp, order): a = xp.zeros((2, 0, 3)) @@ -117,7 +118,7 @@ def test_reshape_zerosize2(self, xp, order): assert b.base is a return b - @testing.for_orders('CF') + @testing.for_orders(_supported_orders) @testing.numpy_cupy_array_equal() def test_external_reshape(self, xp, order): a = xp.zeros((8,), dtype=xp.float32) @@ -131,8 +132,7 @@ def _test_ndim_limit(self, xp, ndim, dtype, order): assert a.ndim == ndim return a - @pytest.mark.skip("until dpctl gh-1196 is resolved") - @testing.for_orders('CF') + @testing.for_orders(_supported_orders) @testing.for_all_dtypes() @testing.numpy_cupy_array_equal() def test_ndim_limit1(self, xp, dtype, order): @@ -140,7 +140,8 @@ def test_ndim_limit1(self, xp, dtype, order): a = self._test_ndim_limit(xp, 32, dtype, order) return a - @testing.for_orders('CF') + @pytest.mark.skip("no max ndim limit for reshape in dpctl") + @testing.for_orders(_supported_orders) @testing.for_all_dtypes() def test_ndim_limit2(self, dtype, order): # from cupy/cupy#4193 @@ -181,7 +182,9 @@ def test_external_ravel(self, xp): @testing.parameterize(*testing.product({ 'order_init': ['C', 'F'], - 'order_reshape': ['C', 'F', 'A', 'c', 'f', 'a'], + # order = 'A' is out of support currently + # 'order_reshape': ['C', 'F', 'A', 'c', 'f', 'a'], + 'order_reshape': ['C', 'F', 'c', 'f'], 'shape_in_out': [((2, 3), (1, 6, 1)), # (shape_init, shape_final) ((6,), (2, 3)), ((3, 3, 3), (9, 3))], @@ -203,5 +206,4 @@ def test_reshape_contiguity(self): assert b_cupy.flags.f_contiguous == b_numpy.flags.f_contiguous assert b_cupy.flags.c_contiguous == b_numpy.flags.c_contiguous - testing.assert_array_equal(b_cupy.strides, b_numpy.strides) testing.assert_array_equal(b_cupy, b_numpy) diff --git a/tests/third_party/cupy/testing/__init__.py b/tests/third_party/cupy/testing/__init__.py index 09f30ade6d9..56b5d052958 100644 --- a/tests/third_party/cupy/testing/__init__.py +++ b/tests/third_party/cupy/testing/__init__.py @@ -7,7 +7,7 @@ from tests.third_party.cupy.testing.array import assert_allclose from tests.third_party.cupy.testing.array import assert_array_almost_equal # from tests.third_party.cupy.testing.array import assert_array_almost_equal_nulp -# from tests.third_party.cupy.testing.array import assert_array_equal +from tests.third_party.cupy.testing.array import assert_array_equal # from tests.third_party.cupy.testing.array import assert_array_less # from tests.third_party.cupy.testing.array import assert_array_list_equal # from tests.third_party.cupy.testing.array import assert_array_max_ulp From 9346500f72a24bb1266479f2bf71b1265cbe59e5 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Tue, 13 Jun 2023 07:17:30 -0500 Subject: [PATCH 5/6] Resolve merge issues --- dpnp/backend/kernels/dpnp_krnl_mathematical.cpp | 1 + dpnp/backend/src/dpnp_utils.hpp | 8 -------- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/dpnp/backend/kernels/dpnp_krnl_mathematical.cpp b/dpnp/backend/kernels/dpnp_krnl_mathematical.cpp index f4add077ab9..b82cbb49b1a 100644 --- a/dpnp/backend/kernels/dpnp_krnl_mathematical.cpp +++ b/dpnp/backend/kernels/dpnp_krnl_mathematical.cpp @@ -177,6 +177,7 @@ DPCTLSyclEventRef dpnp_elemwise_absolute_c(DPCTLSyclQueueRef q_ref, dpnp_vec_cast<_DataType_output, result_absT, vec_sz>(sycl::abs(data_vec)); #else sycl::vec<_DataType_output, vec_sz> res_vec = sycl::abs(data_vec); +#endif sg.store(result_ptrT(&result[start]), res_vec); } diff --git a/dpnp/backend/src/dpnp_utils.hpp b/dpnp/backend/src/dpnp_utils.hpp index 66d9787f38e..6c1bda90cba 100644 --- a/dpnp/backend/src/dpnp_utils.hpp +++ b/dpnp/backend/src/dpnp_utils.hpp @@ -57,14 +57,6 @@ #define __SYCL_COMPILER_VERSION_REQUIRED 20221102L #endif -/** - * Version of SYCL DPC++ 2023 compiler where a return type of sycl::abs() is changed - * from unsinged integer to signed one of input vector. - */ -#ifndef __SYCL_COMPILER_VECTOR_ABS_CHANGED -#define __SYCL_COMPILER_VECTOR_ABS_CHANGED 20230503L -#endif - /** * Version of Intel MKL at which transition to OneMKL release 2023.0.0 occurs. */ From e90956b2ea0604ef79192932d36404cbc22cd52b Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Tue, 13 Jun 2023 09:37:16 -0500 Subject: [PATCH 6/6] resolve type mismatch on Win --- tests/third_party/cupy/math_tests/test_sumprod.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/third_party/cupy/math_tests/test_sumprod.py b/tests/third_party/cupy/math_tests/test_sumprod.py index d9fe3b22b26..17250257fe8 100644 --- a/tests/third_party/cupy/math_tests/test_sumprod.py +++ b/tests/third_party/cupy/math_tests/test_sumprod.py @@ -41,7 +41,7 @@ def test_sum_all2(self, xp, dtype): return a.sum() @testing.for_all_dtypes() - @testing.numpy_cupy_allclose() + @testing.numpy_cupy_allclose(type_check=False) def test_sum_all_transposed(self, xp, dtype): a = testing.shaped_arange((2, 3, 4), xp, dtype).transpose(2, 0, 1) return a.sum()