diff --git a/dpnp/backend/include/dpnp_iface.hpp b/dpnp/backend/include/dpnp_iface.hpp index 02402a48c76c..755a095ebe4a 100644 --- a/dpnp/backend/include/dpnp_iface.hpp +++ b/dpnp/backend/include/dpnp_iface.hpp @@ -141,6 +141,17 @@ INP_DLLEXPORT void dpnp_any_c(const void* array, void* result, const size_t size template INP_DLLEXPORT void dpnp_arange_c(size_t start, size_t step, void* result1, size_t size); +/** + * @ingroup BACKEND_API + * @brief Copy of the array, cast to a specified type. + * + * @param [in] array Input array. + * @param [out] result Output array. + * @param [in] size Number of input elements in `array`. + */ +template +INP_DLLEXPORT void dpnp_astype_c(const void* array, void* result, const size_t size); + /** * @ingroup BACKEND_API * @brief Implementation of full function diff --git a/dpnp/backend/include/dpnp_iface_fptr.hpp b/dpnp/backend/include/dpnp_iface_fptr.hpp index e8881980ce02..b7354c4bc0a5 100644 --- a/dpnp/backend/include/dpnp_iface_fptr.hpp +++ b/dpnp/backend/include/dpnp_iface_fptr.hpp @@ -73,6 +73,7 @@ enum class DPNPFuncName : size_t DPNP_FN_ARGMAX, /**< Used in numpy.argmax() implementation */ DPNP_FN_ARGMIN, /**< Used in numpy.argmin() implementation */ DPNP_FN_ARGSORT, /**< Used in numpy.argsort() implementation */ + DPNP_FN_ASTYPE, /**< Used in numpy.astype() implementation */ DPNP_FN_BITWISE_AND, /**< Used in numpy.bitwise_and() implementation */ DPNP_FN_BITWISE_OR, /**< Used in numpy.bitwise_or() implementation */ DPNP_FN_BITWISE_XOR, /**< Used in numpy.bitwise_xor() implementation */ diff --git a/dpnp/backend/kernels/dpnp_krnl_common.cpp b/dpnp/backend/kernels/dpnp_krnl_common.cpp index 4458e313a157..1c49b7c56bfc 100644 --- a/dpnp/backend/kernels/dpnp_krnl_common.cpp +++ b/dpnp/backend/kernels/dpnp_krnl_common.cpp @@ -35,6 +35,42 @@ namespace mkl_blas = oneapi::mkl::blas; namespace mkl_lapack = oneapi::mkl::lapack; +template +class dpnp_astype_c_kernel; + +template +void dpnp_astype_c(const void* array1_in, void* result1, const size_t size) +{ + cl::sycl::event event; + + const _DataType* array_in = reinterpret_cast(array1_in); + _ResultType* result = reinterpret_cast<_ResultType*>(result1); + + if ((array_in == nullptr) || (result == nullptr)) + { + return; + } + + if (size == 0) + { + return; + } + + cl::sycl::range<1> gws(size); + auto kernel_parallel_for_func = [=](cl::sycl::id<1> global_id) { + size_t i = global_id[0]; + result[i] = array_in[i]; + }; + + auto kernel_func = [&](cl::sycl::handler& cgh) { + cgh.parallel_for>(gws, kernel_parallel_for_func); + }; + + event = DPNP_QUEUE.submit(kernel_func); + + event.wait(); +} + template class dpnp_dot_c_kernel; @@ -324,6 +360,33 @@ void dpnp_matmul_c(void* array1_in, void* array2_in, void* result1, size_t size_ void func_map_init_linalg(func_map_t& fmap) { + fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_BLN][eft_BLN] = {eft_BLN, (void*)dpnp_astype_c}; + fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_BLN][eft_INT] = {eft_INT, (void*)dpnp_astype_c}; + fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_BLN][eft_LNG] = {eft_LNG, (void*)dpnp_astype_c}; + fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_BLN][eft_FLT] = {eft_FLT, (void*)dpnp_astype_c}; + fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_BLN][eft_DBL] = {eft_DBL, (void*)dpnp_astype_c}; + fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_INT][eft_BLN] = {eft_BLN, (void*)dpnp_astype_c}; + fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_astype_c}; + fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_INT][eft_LNG] = {eft_LNG, (void*)dpnp_astype_c}; + fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_INT][eft_FLT] = {eft_FLT, (void*)dpnp_astype_c}; + fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_INT][eft_DBL] = {eft_DBL, (void*)dpnp_astype_c}; + fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_LNG][eft_BLN] = {eft_BLN, (void*)dpnp_astype_c}; + fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_LNG][eft_INT] = {eft_INT, (void*)dpnp_astype_c}; + fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_astype_c}; + fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_LNG][eft_FLT] = {eft_FLT, (void*)dpnp_astype_c}; + fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_LNG][eft_DBL] = {eft_DBL, (void*)dpnp_astype_c}; + fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_FLT][eft_BLN] = {eft_BLN, (void*)dpnp_astype_c}; + fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_FLT][eft_INT] = {eft_INT, (void*)dpnp_astype_c}; + fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_FLT][eft_LNG] = {eft_LNG, (void*)dpnp_astype_c}; + fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_astype_c}; + fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_FLT][eft_DBL] = {eft_DBL, (void*)dpnp_astype_c}; + fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_DBL][eft_BLN] = {eft_BLN, (void*)dpnp_astype_c}; + fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_DBL][eft_INT] = {eft_INT, (void*)dpnp_astype_c}; + fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_DBL][eft_LNG] = {eft_LNG, (void*)dpnp_astype_c}; + fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_DBL][eft_FLT] = {eft_FLT, (void*)dpnp_astype_c}; + fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_astype_c}; + fmap[DPNPFuncName::DPNP_FN_ASTYPE][eft_C128][eft_C128] = {eft_C128, (void*)dpnp_astype_c, std::complex>}; + fmap[DPNPFuncName::DPNP_FN_DOT][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_dot_c}; fmap[DPNPFuncName::DPNP_FN_DOT][eft_INT][eft_LNG] = {eft_LNG, (void*)dpnp_dot_c}; fmap[DPNPFuncName::DPNP_FN_DOT][eft_INT][eft_FLT] = {eft_DBL, (void*)dpnp_dot_c}; diff --git a/dpnp/dpnp_algo/dpnp_algo.pxd b/dpnp/dpnp_algo/dpnp_algo.pxd index b1dca274bd1d..923bb2e79dc6 100644 --- a/dpnp/dpnp_algo/dpnp_algo.pxd +++ b/dpnp/dpnp_algo/dpnp_algo.pxd @@ -46,6 +46,7 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na DPNP_FN_ARGMAX DPNP_FN_ARGMIN DPNP_FN_ARGSORT + DPNP_FN_ASTYPE DPNP_FN_BITWISE_AND DPNP_FN_BITWISE_OR DPNP_FN_BITWISE_XOR diff --git a/dpnp/dpnp_algo/dpnp_algo.pyx b/dpnp/dpnp_algo/dpnp_algo.pyx index 40467cf231f7..f7e95b555ad2 100644 --- a/dpnp/dpnp_algo/dpnp_algo.pyx +++ b/dpnp/dpnp_algo/dpnp_algo.pyx @@ -69,6 +69,7 @@ include "dpnp_algo_trigonometric.pyx" ctypedef void(*fptr_dpnp_arange_t)(size_t, size_t, void *, size_t) +ctypedef void(*fptr_dpnp_astype_t)(const void *, void * , const size_t) ctypedef void(*fptr_dpnp_initval_t)(void *, void * , size_t) @@ -125,10 +126,16 @@ cpdef dparray dpnp_array(obj, dtype=None): cpdef dparray dpnp_astype(dparray array1, dtype_target): - cdef dparray result = dparray(array1.shape, dtype=dtype_target) + cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(array1.dtype) + cdef DPNPFuncType param2_type = dpnp_dtype_to_DPNPFuncType(dtype_target) - for i in range(result.size): - result[i] = array1[i] + cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_ASTYPE, param1_type, param2_type) + + result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type) + cdef dparray result = dparray(array1.shape, dtype=result_type) + + cdef fptr_dpnp_astype_t func = kernel_data.ptr + func(array1.get_data(), result.get_data(), array1.size) return result diff --git a/tests/skipped_tests.tbl b/tests/skipped_tests.tbl index aeebe807b318..7d1ddbbb0583 100644 --- a/tests/skipped_tests.tbl +++ b/tests/skipped_tests.tbl @@ -1,3 +1,21 @@ +tests/test_dparray.py::test_astype[[-2, -1, 0, 1, 2]-complex-float64] +tests/test_dparray.py::test_astype[[-2, -1, 0, 1, 2]-complex-float32] +tests/test_dparray.py::test_astype[[-2, -1, 0, 1, 2]-complex-int64] +tests/test_dparray.py::test_astype[[-2, -1, 0, 1, 2]-complex-int32] +tests/test_dparray.py::test_astype[[-2, -1, 0, 1, 2]-complex-bool] +tests/test_dparray.py::test_astype[[-2, -1, 0, 1, 2]-complex-bool_] +tests/test_dparray.py::test_astype[[[-2, -1], [1, 2]]-complex-float64] +tests/test_dparray.py::test_astype[[[-2, -1], [1, 2]]-complex-float32] +tests/test_dparray.py::test_astype[[[-2, -1], [1, 2]]-complex-int64] +tests/test_dparray.py::test_astype[[[-2, -1], [1, 2]]-complex-int32] +tests/test_dparray.py::test_astype[[[-2, -1], [1, 2]]-complex-bool] +tests/test_dparray.py::test_astype[[[-2, -1], [1, 2]]-complex-bool_] +tests/test_dparray.py::test_astype[[]-complex-float64] +tests/test_dparray.py::test_astype[[]-complex-float32] +tests/test_dparray.py::test_astype[[]-complex-int64] +tests/test_dparray.py::test_astype[[]-complex-int32] +tests/test_dparray.py::test_astype[[]-complex-bool] +tests/test_dparray.py::test_astype[[]-complex-bool_] tests/test_linalg.py::test_cond[-1-[[1, 2, 3], [4, 5, 6], [7, 8, 9]]] tests/test_linalg.py::test_cond[1-[[1, 2, 3], [4, 5, 6], [7, 8, 9]]] tests/test_linalg.py::test_cond[-2-[[1, 0, -1], [0, 1, 0], [1, 0, 1]]] @@ -154,7 +172,6 @@ tests/test_linalg.py::test_svd[(5,3)-complex128] tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[ lambda x: x] tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[ lambda x: dpnp.asarray(x).astype(dpnp.int8)] tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[ lambda x: dpnp.asarray(x).astype(dpnp.complex64)] -tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[ lambda x: dpnp.asarray(x).astype(object)] tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[ lambda x: [(i, i) for i in x]] tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[ lambda x: dpnp.vstack([x, x]).T] tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[ lambda x: (dpnp.asarray([(i, i) for i in x], [("a", int), ("b", int)]).view(dpnp.recarray))] diff --git a/tests/skipped_tests_gpu.tbl b/tests/skipped_tests_gpu.tbl index 634d05ee200b..a0ff058fb6a4 100644 --- a/tests/skipped_tests_gpu.tbl +++ b/tests/skipped_tests_gpu.tbl @@ -1,3 +1,21 @@ +tests/test_dparray.py::test_astype[[-2, -1, 0, 1, 2]-complex-float64] +tests/test_dparray.py::test_astype[[-2, -1, 0, 1, 2]-complex-float32] +tests/test_dparray.py::test_astype[[-2, -1, 0, 1, 2]-complex-int64] +tests/test_dparray.py::test_astype[[-2, -1, 0, 1, 2]-complex-int32] +tests/test_dparray.py::test_astype[[-2, -1, 0, 1, 2]-complex-bool] +tests/test_dparray.py::test_astype[[-2, -1, 0, 1, 2]-complex-bool_] +tests/test_dparray.py::test_astype[[[-2, -1], [1, 2]]-complex-float64] +tests/test_dparray.py::test_astype[[[-2, -1], [1, 2]]-complex-float32] +tests/test_dparray.py::test_astype[[[-2, -1], [1, 2]]-complex-int64] +tests/test_dparray.py::test_astype[[[-2, -1], [1, 2]]-complex-int32] +tests/test_dparray.py::test_astype[[[-2, -1], [1, 2]]-complex-bool] +tests/test_dparray.py::test_astype[[[-2, -1], [1, 2]]-complex-bool_] +tests/test_dparray.py::test_astype[[]-complex-float64] +tests/test_dparray.py::test_astype[[]-complex-float32] +tests/test_dparray.py::test_astype[[]-complex-int64] +tests/test_dparray.py::test_astype[[]-complex-int32] +tests/test_dparray.py::test_astype[[]-complex-bool] +tests/test_dparray.py::test_astype[[]-complex-bool_] tests/test_dot.py::test_dot_arange[float32] tests/test_dot.py::test_dot_arange[float64] tests/test_dot.py::test_dot_ones[float32] diff --git a/tests/test_dparray.py b/tests/test_dparray.py new file mode 100644 index 000000000000..84440cf97a79 --- /dev/null +++ b/tests/test_dparray.py @@ -0,0 +1,20 @@ +import dpnp +import numpy +import pytest + + +@pytest.mark.parametrize("res_dtype", + [numpy.float64, numpy.float32, numpy.int64, numpy.int32, numpy.bool, numpy.bool_, numpy.complex], + ids=['float64', 'float32', 'int64', 'int32', 'bool', 'bool_', 'complex']) +@pytest.mark.parametrize("arr_dtype", + [numpy.float64, numpy.float32, numpy.int64, numpy.int32, numpy.bool, numpy.bool_, numpy.complex], + ids=['float64', 'float32', 'int64', 'int32', 'bool', 'bool_', 'complex']) +@pytest.mark.parametrize("arr", + [[-2, -1, 0, 1, 2], [[-2, -1], [1, 2]], []], + ids=['[-2, -1, 0, 1, 2]', '[[-2, -1], [1, 2]]', '[]']) +def test_astype(arr, arr_dtype, res_dtype): + numpy_array = numpy.array(arr, dtype=arr_dtype) + dpnp_array = dpnp.array(numpy_array) + expected = numpy_array.astype(res_dtype) + result = dpnp_array.astype(res_dtype) + numpy.testing.assert_array_equal(expected, result) diff --git a/tests_external/skipped_tests_numpy.tbl b/tests_external/skipped_tests_numpy.tbl index a47af6acc654..798126595063 100644 --- a/tests_external/skipped_tests_numpy.tbl +++ b/tests_external/skipped_tests_numpy.tbl @@ -2881,6 +2881,7 @@ tests/test_umath.py::test_rint_big_int tests/test_umath.py::TestRoundingFunctions::test_object_direct tests/test_umath.py::TestRoundingFunctions::test_object_indirect tests/test_umath.py::test_signaling_nan_exceptions +tests/test_umath.py::TestSign::test_sign_dtype_nan_object tests/test_umath.py::TestSign::test_sign tests/test_umath.py::TestSign::test_sign_dtype_object tests/test_umath.py::test_spacing