From a7bcf9db9e8c7d29ab2ae60be8bd3b82a9dca068 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Wed, 3 May 2023 06:48:49 -0500 Subject: [PATCH 1/2] Rework transpose methods to call dpctl.tensor functions --- dpnp/backend/include/dpnp_iface_fptr.hpp | 1 - .../kernels/dpnp_krnl_manipulation.cpp | 24 +----- dpnp/dpnp_algo/dpnp_algo.pxd | 2 - dpnp/dpnp_algo/dpnp_algo_manipulation.pxi | 74 ----------------- dpnp/dpnp_array.py | 67 ++++++++++++---- dpnp/dpnp_iface_manipulation.py | 79 +++++++++++-------- dpnp/dpnp_iface_statistics.py | 3 +- tests/skipped_tests.tbl | 4 - tests/skipped_tests_gpu.tbl | 9 +-- .../cupy/linalg_tests/test_eigenvalue.py | 3 +- 10 files changed, 103 insertions(+), 163 deletions(-) diff --git a/dpnp/backend/include/dpnp_iface_fptr.hpp b/dpnp/backend/include/dpnp_iface_fptr.hpp index 653471fd1b5..d9dd2dc6b40 100644 --- a/dpnp/backend/include/dpnp_iface_fptr.hpp +++ b/dpnp/backend/include/dpnp_iface_fptr.hpp @@ -360,7 +360,6 @@ enum class DPNPFuncName : size_t DPNP_FN_TANH, /**< Used in numpy.tanh() impl */ DPNP_FN_TANH_EXT, /**< Used in numpy.tanh() impl, requires extra parameters */ DPNP_FN_TRANSPOSE, /**< Used in numpy.transpose() impl */ - DPNP_FN_TRANSPOSE_EXT, /**< Used in numpy.transpose() impl, requires extra parameters */ DPNP_FN_TRACE, /**< Used in numpy.trace() impl */ DPNP_FN_TRACE_EXT, /**< Used in numpy.trace() impl, requires extra parameters */ DPNP_FN_TRAPZ, /**< Used in numpy.trapz() impl */ diff --git a/dpnp/backend/kernels/dpnp_krnl_manipulation.cpp b/dpnp/backend/kernels/dpnp_krnl_manipulation.cpp index 8a122dbf728..2fc7832b6ba 100644 --- a/dpnp/backend/kernels/dpnp_krnl_manipulation.cpp +++ b/dpnp/backend/kernels/dpnp_krnl_manipulation.cpp @@ -1,5 +1,5 @@ //***************************************************************************** -// Copyright (c) 2016-2020, Intel Corporation +// Copyright (c) 2016-2023, Intel Corporation // All rights reserved. // // Redistribution and use in source and binary forms, with or without @@ -211,6 +211,7 @@ void dpnp_elemwise_transpose_c(void* array1_in, size, dep_event_vec_ref); DPCTLEvent_WaitAndThrow(event_ref); + DPCTLEvent_Delete(event_ref); } template @@ -222,17 +223,6 @@ void (*dpnp_elemwise_transpose_default_c)(void*, void*, size_t) = dpnp_elemwise_transpose_c<_DataType>; -template -DPCTLSyclEventRef (*dpnp_elemwise_transpose_ext_c)(DPCTLSyclQueueRef, - void*, - const shape_elem_type*, - const shape_elem_type*, - const shape_elem_type*, - size_t, - void*, - size_t, - const DPCTLEventVectorRef) = dpnp_elemwise_transpose_c<_DataType>; - void func_map_init_manipulation(func_map_t& fmap) { fmap[DPNPFuncName::DPNP_FN_REPEAT][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_repeat_default_c}; @@ -253,15 +243,5 @@ void func_map_init_manipulation(func_map_t& fmap) (void*)dpnp_elemwise_transpose_default_c}; fmap[DPNPFuncName::DPNP_FN_TRANSPOSE][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_elemwise_transpose_default_c}; - - fmap[DPNPFuncName::DPNP_FN_TRANSPOSE_EXT][eft_INT][eft_INT] = {eft_INT, - (void*)dpnp_elemwise_transpose_ext_c}; - fmap[DPNPFuncName::DPNP_FN_TRANSPOSE_EXT][eft_LNG][eft_LNG] = {eft_LNG, - (void*)dpnp_elemwise_transpose_ext_c}; - fmap[DPNPFuncName::DPNP_FN_TRANSPOSE_EXT][eft_FLT][eft_FLT] = {eft_FLT, - (void*)dpnp_elemwise_transpose_ext_c}; - fmap[DPNPFuncName::DPNP_FN_TRANSPOSE_EXT][eft_DBL][eft_DBL] = {eft_DBL, - (void*)dpnp_elemwise_transpose_ext_c}; - return; } diff --git a/dpnp/dpnp_algo/dpnp_algo.pxd b/dpnp/dpnp_algo/dpnp_algo.pxd index 09af5667f8c..d6b5429669f 100644 --- a/dpnp/dpnp_algo/dpnp_algo.pxd +++ b/dpnp/dpnp_algo/dpnp_algo.pxd @@ -339,7 +339,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na DPNP_FN_TRACE DPNP_FN_TRACE_EXT DPNP_FN_TRANSPOSE - DPNP_FN_TRANSPOSE_EXT DPNP_FN_TRAPZ DPNP_FN_TRAPZ_EXT DPNP_FN_TRI @@ -555,7 +554,6 @@ cpdef dpnp_descriptor dpnp_subtract(dpnp_descriptor x1_obj, dpnp_descriptor x2_o Array manipulation routines """ cpdef dpnp_descriptor dpnp_repeat(dpnp_descriptor array1, repeats, axes=*) -cpdef dpnp_descriptor dpnp_transpose(dpnp_descriptor array1, axes=*) """ diff --git a/dpnp/dpnp_algo/dpnp_algo_manipulation.pxi b/dpnp/dpnp_algo/dpnp_algo_manipulation.pxi index 3e27af363c3..9176092b7c0 100644 --- a/dpnp/dpnp_algo/dpnp_algo_manipulation.pxi +++ b/dpnp/dpnp_algo/dpnp_algo_manipulation.pxi @@ -42,21 +42,11 @@ __all__ += [ "dpnp_expand_dims", "dpnp_repeat", "dpnp_reshape", - "dpnp_transpose", "dpnp_squeeze", ] # C function pointer to the C library template functions -ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_custom_elemwise_transpose_1in_1out_t)(c_dpctl.DPCTLSyclQueueRef, - void * , - shape_elem_type * , - shape_elem_type * , - shape_elem_type * , - size_t, - void * , - size_t, - const c_dpctl.DPCTLEventVectorRef) ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_repeat_t)(c_dpctl.DPCTLSyclQueueRef, const void *, void * , const size_t , const size_t, const c_dpctl.DPCTLEventVectorRef) @@ -232,70 +222,6 @@ cpdef utils.dpnp_descriptor dpnp_reshape(utils.dpnp_descriptor array1, newshape, copy_when_nondefault_queue=False) -cpdef utils.dpnp_descriptor dpnp_transpose(utils.dpnp_descriptor array1, axes=None): - cdef shape_type_c input_shape = array1.shape - cdef size_t input_shape_size = array1.ndim - cdef shape_type_c result_shape = shape_type_c(input_shape_size, 1) - - cdef shape_type_c permute_axes - if axes is None: - """ - template to do transpose a tensor - input_shape=[2, 3, 4] - permute_axes=[2, 1, 0] - after application `permute_axes` to `input_shape` result: - result_shape=[4, 3, 2] - - 'do nothing' axes variable is `permute_axes=[0, 1, 2]` - - test: pytest tests/third_party/cupy/manipulation_tests/test_transpose.py::TestTranspose::test_external_transpose_all - """ - permute_axes = list(reversed([i for i in range(input_shape_size)])) - else: - permute_axes = utils.normalize_axis(axes, input_shape_size) - - for i in range(input_shape_size): - """ construct output shape """ - result_shape[i] = input_shape[permute_axes[i]] - - # convert string type names (array.dtype) to C enum DPNPFuncType - cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(array1.dtype) - - # get the FPTR data structure - cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_TRANSPOSE_EXT, param1_type, param1_type) - - array1_obj = array1.get_array() - - # ceate result array with type given by FPTR data - cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape, - kernel_data.return_type, - None, - device=array1_obj.sycl_device, - usm_type=array1_obj.usm_type, - sycl_queue=array1_obj.sycl_queue) - result_sycl_queue = result.get_array().sycl_queue - - cdef c_dpctl.SyclQueue q = result_sycl_queue - cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref() - - cdef fptr_custom_elemwise_transpose_1in_1out_t func = kernel_data.ptr - # call FPTR function - cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref, - array1.get_data(), - input_shape.data(), - result_shape.data(), - permute_axes.data(), - input_shape_size, - result.get_data(), - array1.size, - NULL) # dep_events_ref - - with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref) - c_dpctl.DPCTLEvent_Delete(event_ref) - - return result - - cpdef utils.dpnp_descriptor dpnp_squeeze(utils.dpnp_descriptor in_array, axis): cdef shape_type_c shape_list if axis is None: diff --git a/dpnp/dpnp_array.py b/dpnp/dpnp_array.py index 5741ea0fa14..9d8d4ecacc3 100644 --- a/dpnp/dpnp_array.py +++ b/dpnp/dpnp_array.py @@ -99,15 +99,8 @@ def get_array(self): @property def T(self): - """Shape-reversed view of the array. - - If ndim < 2, then this is just a reference to the array itself. - - """ - if self.ndim < 2: - return self - else: - return dpnp.transpose(self) + """View of the transposed array.""" + return self.transpose() def to_device(self, target_device): """ @@ -1000,15 +993,61 @@ def take(self, indices, axis=None, out=None, mode='raise'): def transpose(self, *axes): """ - Returns a view of the array with axes permuted. + Returns a view of the array with axes transposed. - .. seealso:: - :obj:`dpnp.transpose` for full documentation, - :meth:`numpy.ndarray.reshape` + For full documentation refer to :obj:`numpy.ndarray.transpose`. + + Returns + ------- + y : dpnp.ndarray + View of the array with its axes suitably permuted. + See Also + -------- + :obj:`dpnp.transpose` : Equivalent function. + :obj:`dpnp.ndarray.ndarray.T` : Array property returning the array transposed. + :obj:`dpnp.ndarray.reshape` : Give a new shape to an array without changing its data. + + Examples + -------- + >>> import dpnp as dp + >>> a = dp.array([[1, 2], [3, 4]]) + >>> a + array([[1, 2], + [3, 4]]) + >>> a.transpose() + array([[1, 3], + [2, 4]]) + >>> a.transpose((1, 0)) + array([[1, 3], + [2, 4]]) + >>> a.transpose(1, 0) + array([[1, 3], + [2, 4]]) + + >>> a = dp.array([1, 2, 3, 4]) + >>> a + array([1, 2, 3, 4]) + >>> a.transpose() + array([1, 2, 3, 4]) + """ - return dpnp.transpose(self, axes) + ndim = self.ndim + if ndim < 2: + return self + + res = self.__new__(dpnp_array) + if ndim == 2: + res._array_obj = self._array_obj.T + else: + if len(axes) == 0 or len(axes) == 1 and axes[0] is None: + # self.transpose().shape == self.shape[::-1] + # self.transpose(None).shape == self.shape[::-1] + axes = tuple((ndim - x - 1) for x in range(ndim)) + + res._array_obj = dpt.permute_dims(self._array_obj, axes) + return res def var(self, axis=None, dtype=None, out=None, ddof=0, keepdims=False): """ diff --git a/dpnp/dpnp_iface_manipulation.py b/dpnp/dpnp_iface_manipulation.py index 567661bdb57..4483d4f3faa 100644 --- a/dpnp/dpnp_iface_manipulation.py +++ b/dpnp/dpnp_iface_manipulation.py @@ -678,54 +678,65 @@ def swapaxes(x1, axis1, axis2): return call_origin(numpy.swapaxes, x1, axis1, axis2) -def transpose(x1, axes=None): +def transpose(a, axes=None): """ - Reverse or permute the axes of an array; returns the modified array. + Returns an array with axes transposed. For full documentation refer to :obj:`numpy.transpose`. + Returns + ------- + y : dpnp.ndarray + `a` with its axes permuted. A view is returned whenever possible. + Limitations ----------- - Input array is supported as :obj:`dpnp.ndarray`. - Otherwise the function will be executed sequentially on CPU. - Value of the parameter ``axes`` likely to be replaced with ``None``. - Input array data types are limited by supported DPNP :ref:`Data types`. + Input array is supported as either :class:`dpnp.ndarray` + or :class:`dpctl.tensor.usm_ndarray`. See Also -------- + :obj:`dpnp.ndarray.transpose` : Equivalent method. :obj:`dpnp.moveaxis` : Move array axes to new positions. :obj:`dpnp.argsort` : Returns the indices that would sort an array. Examples -------- - >>> import dpnp as np - >>> x = np.arange(4).reshape((2,2)) - >>> x.shape - (2, 2) - >>> [i for i in x] - [0, 1, 2, 3] - >>> out = np.transpose(x) - >>> out.shape - (2, 2) - >>> [i for i in out] - [0, 2, 1, 3] - - """ - - x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False) - if x1_desc: - if axes is not None: - if not any(axes): - """ - pytest tests/third_party/cupy/manipulation_tests/test_transpose.py - """ - axes = None - - result = dpnp_transpose(x1_desc, axes).get_pyobj() - - return result - - return call_origin(numpy.transpose, x1, axes=axes) + >>> import dpnp as dp + >>> a = dp.array([[1, 2], [3, 4]]) + >>> a + array([[1, 2], + [3, 4]]) + >>> dp.transpose(a) + array([[1, 3], + [2, 4]]) + + >>> a = dp.array([1, 2, 3, 4]) + >>> a + array([1, 2, 3, 4]) + >>> dp.transpose(a) + array([1, 2, 3, 4]) + + >>> a = dp.ones((1, 2, 3)) + >>> dp.transpose(a, (1, 0, 2)).shape + (2, 1, 3) + + >>> a = dp.ones((2, 3, 4, 5)) + >>> dp.transpose(a).shape + (5, 4, 3, 2) + + """ + + if isinstance(a, dpnp_array): + array = a + elif isinstance(a, dpt.usm_ndarray): + array = dpnp_array._create_from_usm_ndarray(a.get_array()) + else: + raise TypeError("An array must be any of supported type, but got {}".format(type(a))) + + if axes is None: + return array.transpose() + return array.transpose(*axes) def unique(x1, **kwargs): diff --git a/dpnp/dpnp_iface_statistics.py b/dpnp/dpnp_iface_statistics.py index 966a7214269..35e48e1ea94 100644 --- a/dpnp/dpnp_iface_statistics.py +++ b/dpnp/dpnp_iface_statistics.py @@ -290,8 +290,7 @@ def cov(x1, y=None, rowvar=True, bias=False, ddof=None, fweights=None, aweights= pass else: if not rowvar and x1.shape[0] != 1: - x1 = x1.get_array() if isinstance(x1, dpnp_array) else x1 - x1 = dpnp_array._create_from_usm_ndarray(x1.mT) + x1 = x1.T if not x1.dtype in (dpnp.float32, dpnp.float64): x1 = dpnp.astype(x1, dpnp.default_float_type(sycl_queue=x1.sycl_queue)) diff --git a/tests/skipped_tests.tbl b/tests/skipped_tests.tbl index a41b881ae7b..cc4bdd57c95 100644 --- a/tests/skipped_tests.tbl +++ b/tests/skipped_tests.tbl @@ -234,8 +234,6 @@ tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAn tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype_strides_swapped tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype_type_c_contiguous_no_copy tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype_type_f_contiguous_no_copy -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_transposed_fill -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_transposed_flatten tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_flatten tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_flatten_copied tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_isinstance_numpy_copy @@ -837,7 +835,6 @@ tests/third_party/cupy/math_tests/test_rounding.py::TestRounding::test_rint_nega tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all2 tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all_keepdims -tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all_transposed tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all_transposed2 tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_axes tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_axes2 @@ -886,7 +883,6 @@ tests/third_party/cupy/math_tests/test_sumprod.py::TestNansumNanprodLong_param_1 tests/third_party/cupy/math_tests/test_sumprod.py::TestNansumNanprodLong_param_9_{axis=0, func='nanprod', keepdims=True, shape=(2, 3, 4), transpose_axes=False}::test_nansum_all tests/third_party/cupy/math_tests/test_sumprod.py::TestNansumNanprodLong_param_9_{axis=0, func='nanprod', keepdims=True, shape=(2, 3, 4), transpose_axes=False}::test_nansum_axis_transposed tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all2 -tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all_transposed2 tests/third_party/cupy/math_tests/test_trigonometric.py::TestUnwrap::test_unwrap_1dim tests/third_party/cupy/math_tests/test_trigonometric.py::TestUnwrap::test_unwrap_1dim_with_discont tests/third_party/cupy/math_tests/test_trigonometric.py::TestUnwrap::test_unwrap_2dim_with_axis diff --git a/tests/skipped_tests_gpu.tbl b/tests/skipped_tests_gpu.tbl index 36b17d5edbc..e1cfd00028d 100644 --- a/tests/skipped_tests_gpu.tbl +++ b/tests/skipped_tests_gpu.tbl @@ -142,7 +142,6 @@ tests/third_party/cupy/random_tests/test_distributions.py::TestDistributionsPois tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray([[i, i] for i in x])] -tests/test_arraymanipulation.py::TestConcatenate::test_concatenate tests/test_histograms.py::TestHistogram::test_density tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.astype(dpnp.asarray(x), dpnp.int8)] @@ -376,8 +375,6 @@ tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAn tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype_strides_swapped tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype_type_c_contiguous_no_copy tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_astype_type_f_contiguous_no_copy -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_transposed_fill -tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_transposed_flatten tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_flatten_copied tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_flatten tests/third_party/cupy/core_tests/test_ndarray_copy_and_view.py::TestArrayCopyAndView::test_isinstance_numpy_copy @@ -783,15 +780,11 @@ tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumLarge_param_9_{opt tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumUnaryOperationWithScalar::test_scalar_float tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumUnaryOperationWithScalar::test_scalar_int tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test_invalid_sub1 -tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_dot tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_invlarge tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_large tests/third_party/cupy/linalg_tests/test_product.py::TestMatrixPower::test_matrix_power_of_two tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_dot_vec2 tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_multidim_vdot -tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_dot -tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_dot_with_out -tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_higher_order_inner tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot_with_int_axes tests/third_party/cupy/linalg_tests/test_product.py::TestProduct::test_transposed_tensordot_with_list_axes @@ -1005,7 +998,7 @@ tests/third_party/cupy/math_tests/test_rounding.py::TestRounding::test_rint tests/third_party/cupy/math_tests/test_rounding.py::TestRounding::test_rint_negative tests/third_party/cupy/math_tests/test_rounding.py::TestRounding::test_round_ tests/third_party/cupy/math_tests/test_rounding.py::TestRounding::test_trunc -tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all_transposed + tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_all_transposed2 tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_axes tests/third_party/cupy/math_tests/test_sumprod.py::TestSumprod::test_sum_axes2 diff --git a/tests/third_party/cupy/linalg_tests/test_eigenvalue.py b/tests/third_party/cupy/linalg_tests/test_eigenvalue.py index fe577e32b28..1a5309e481c 100644 --- a/tests/third_party/cupy/linalg_tests/test_eigenvalue.py +++ b/tests/third_party/cupy/linalg_tests/test_eigenvalue.py @@ -8,8 +8,7 @@ def _get_hermitian(xp, a, UPLO): - # TODO: fix this, currently dpnp.transpose() doesn't support complex types - # and no dpnp_array.swapaxes() + # TODO: remove wrapping, but now there is no dpnp_array.swapaxes() a = _wrap_as_numpy_array(xp, a) _xp = numpy From c6ee0f22b9f8e685217d3e30b6689b3d338fd7f7 Mon Sep 17 00:00:00 2001 From: Anton Volkov Date: Thu, 4 May 2023 06:16:56 -0500 Subject: [PATCH 2/2] Applied review comments & added more tests --- dpnp/dpnp_array.py | 11 ++++++----- tests/test_manipulation.py | 39 +++++++++++++++++++++++++++++++++++--- 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/dpnp/dpnp_array.py b/dpnp/dpnp_array.py index 9d8d4ecacc3..ea4f896ba00 100644 --- a/dpnp/dpnp_array.py +++ b/dpnp/dpnp_array.py @@ -1019,9 +1019,6 @@ def transpose(self, *axes): array([[1, 3], [2, 4]]) >>> a.transpose((1, 0)) - array([[1, 3], - [2, 4]]) - >>> a.transpose(1, 0) array([[1, 3], [2, 4]]) @@ -1037,11 +1034,15 @@ def transpose(self, *axes): if ndim < 2: return self + axes_len = len(axes) + if axes_len == 1 and isinstance(axes[0], tuple): + axes = axes[0] + res = self.__new__(dpnp_array) - if ndim == 2: + if ndim == 2 and axes_len == 0: res._array_obj = self._array_obj.T else: - if len(axes) == 0 or len(axes) == 1 and axes[0] is None: + if len(axes) == 0 or axes[0] is None: # self.transpose().shape == self.shape[::-1] # self.transpose(None).shape == self.shape[::-1] axes = tuple((ndim - x - 1) for x in range(ndim)) diff --git a/tests/test_manipulation.py b/tests/test_manipulation.py index bb91f5d0d50..5c535d60f8e 100644 --- a/tests/test_manipulation.py +++ b/tests/test_manipulation.py @@ -1,5 +1,10 @@ import pytest + import numpy +from numpy.testing import ( + assert_array_equal +) + import dpnp @@ -20,7 +25,7 @@ def test_copyto_dtype(in_obj, out_dtype): result = dpnp.empty(dparr.size, dtype=out_dtype) dpnp.copyto(result, dparr) - numpy.testing.assert_array_equal(result, expected) + assert_array_equal(result, expected) @pytest.mark.usefixtures("allow_fall_back_on_numpy") @@ -32,7 +37,7 @@ def test_repeat(arr): dpnp_a = dpnp.array(arr) expected = numpy.repeat(a, 2) result = dpnp.repeat(dpnp_a, 2) - numpy.testing.assert_array_equal(expected, result) + assert_array_equal(expected, result) @pytest.mark.usefixtures("allow_fall_back_on_numpy") @@ -51,4 +56,32 @@ def test_unique(array): expected = numpy.unique(np_a) result = dpnp.unique(dpnp_a) - numpy.testing.assert_array_equal(expected, result) + assert_array_equal(expected, result) + + +class TestTranspose: + @pytest.mark.parametrize("axes", [(0, 1), (1, 0)]) + def test_2d_with_axes(self, axes): + na = numpy.array([[1, 2], [3, 4]]) + da = dpnp.array(na) + + expected = numpy.transpose(na, axes) + result = dpnp.transpose(da, axes) + assert_array_equal(expected, result) + + @pytest.mark.parametrize("axes", [(1, 0, 2), ((1, 0, 2),)]) + def test_3d_with_packed_axes(self, axes): + na = numpy.ones((1, 2, 3)) + da = dpnp.array(na) + + expected = na.transpose(*axes) + result = da.transpose(*axes) + assert_array_equal(expected, result) + + @pytest.mark.parametrize("shape", [(10,), (2, 4), (5, 3, 7), (3, 8, 4, 1)]) + def test_none_axes(self, shape): + na = numpy.ones(shape) + da = dpnp.ones(shape) + + assert_array_equal(na.transpose(), da.transpose()) + assert_array_equal(na.transpose(None), da.transpose(None))