Skip to content

Commit

Permalink
Add support of complex types for dpnp.abs() (#1324)
Browse files Browse the repository at this point in the history
* Add support of complex types for dpnp.abs()

* Add test coverage

* State support of :class: in descriptions
  • Loading branch information
antonwolfy authored Mar 3, 2023
1 parent 4b9a5cd commit 648612d
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 117 deletions.
90 changes: 54 additions & 36 deletions dpnp/backend/kernels/dpnp_krnl_mathematical.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//*****************************************************************************
// Copyright (c) 2016-2020, Intel Corporation
// Copyright (c) 2016-2023, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -114,10 +114,10 @@ DPCTLSyclEventRef (*dpnp_around_ext_c)(DPCTLSyclQueueRef,
const int,
const DPCTLEventVectorRef) = dpnp_around_c<_DataType>;

template <typename _KernelNameSpecialization>
template <typename _KernelNameSpecialization1, typename _KernelNameSpecialization2>
class dpnp_elemwise_absolute_c_kernel;

template <typename _DataType>
template <typename _DataType_input, typename _DataType_output>
DPCTLSyclEventRef dpnp_elemwise_absolute_c(DPCTLSyclQueueRef q_ref,
const void* input1_in,
void* result1,
Expand All @@ -137,43 +137,63 @@ DPCTLSyclEventRef dpnp_elemwise_absolute_c(DPCTLSyclQueueRef q_ref,
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
sycl::event event;

DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, input1_in, size);
_DataType* array1 = input1_ptr.get_ptr();
DPNPC_ptr_adapter<_DataType> result1_ptr(q_ref, result1, size, false, true);
_DataType* result = result1_ptr.get_ptr();
_DataType_input* array1 = static_cast<_DataType_input*>(const_cast<void*>(input1_in));
_DataType_output* result = static_cast<_DataType_output*>(result1);

if constexpr (std::is_same<_DataType, double>::value || std::is_same<_DataType, float>::value)
if constexpr (is_any_v<_DataType_input, float, double, std::complex<float>, std::complex<double>>)
{
// https://docs.oneapi.com/versions/latest/onemkl/abs.html
event = oneapi::mkl::vm::abs(q, size, array1, result);
}
else
{
sycl::range<1> gws(size);
auto kernel_parallel_for_func = [=](sycl::id<1> global_id) {
const size_t idx = global_id[0];
static_assert(is_any_v<_DataType_input, int32_t, int64_t>,
"Integer types are only expected to pass in 'abs' kernel");
static_assert(std::is_same_v<_DataType_input, _DataType_output>, "Result type must match a type of input data");

constexpr size_t lws = 64;
constexpr unsigned int vec_sz = 8;
constexpr sycl::access::address_space global_space = sycl::access::address_space::global_space;

auto gws_range = sycl::range<1>(((size + lws * vec_sz - 1) / (lws * vec_sz)) * lws);
auto lws_range = sycl::range<1>(lws);

if (array1[idx] >= 0)
auto kernel_parallel_for_func = [=](sycl::nd_item<1> nd_it) {
auto sg = nd_it.get_sub_group();
const auto max_sg_size = sg.get_max_local_range()[0];
const size_t start =
vec_sz * (nd_it.get_group(0) * nd_it.get_local_range(0) + sg.get_group_id()[0] * max_sg_size);

if (start + static_cast<size_t>(vec_sz) * max_sg_size < size)
{
result[idx] = array1[idx];
using input_ptrT = sycl::multi_ptr<_DataType_input, global_space>;
using result_ptrT = sycl::multi_ptr<_DataType_output, global_space>;

sycl::vec<_DataType_input, vec_sz> data_vec = sg.load<vec_sz>(input_ptrT(&array1[start]));

// sycl::abs() returns unsigned integers only, so explicit casting to signed ones is required
using result_absT = typename cl::sycl::detail::make_unsigned<_DataType_output>::type;
sycl::vec<_DataType_output, vec_sz> res_vec =
dpnp_vec_cast<_DataType_output, result_absT, vec_sz>(sycl::abs(data_vec));

sg.store<vec_sz>(result_ptrT(&result[start]), res_vec);
}
else
{
result[idx] = -1 * array1[idx];
for (size_t k = start + sg.get_local_id()[0]; k < size; k += max_sg_size)
{
result[k] = std::abs(array1[k]);
}
}
};

auto kernel_func = [&](sycl::handler& cgh) {
cgh.parallel_for<class dpnp_elemwise_absolute_c_kernel<_DataType>>(gws, kernel_parallel_for_func);
cgh.parallel_for<class dpnp_elemwise_absolute_c_kernel<_DataType_input, _DataType_output>>(
sycl::nd_range<1>(gws_range, lws_range), kernel_parallel_for_func);
};

event = q.submit(kernel_func);
}

input1_ptr.depends_on(event);
result1_ptr.depends_on(event);
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);

return DPCTLEvent_Copy(event_ref);
}

Expand All @@ -182,28 +202,24 @@ void dpnp_elemwise_absolute_c(const void* input1_in, void* result1, size_t size)
{
DPCTLSyclQueueRef q_ref = reinterpret_cast<DPCTLSyclQueueRef>(&DPNP_QUEUE);
DPCTLEventVectorRef dep_event_vec_ref = nullptr;
DPCTLSyclEventRef event_ref = dpnp_elemwise_absolute_c<_DataType>(q_ref,
input1_in,
result1,
size,
dep_event_vec_ref);
DPCTLSyclEventRef event_ref = dpnp_elemwise_absolute_c<_DataType, _DataType>(q_ref,
input1_in,
result1,
size,
dep_event_vec_ref);
DPCTLEvent_WaitAndThrow(event_ref);
DPCTLEvent_Delete(event_ref);
}

template <typename _DataType>
void (*dpnp_elemwise_absolute_default_c)(const void*, void*, size_t) = dpnp_elemwise_absolute_c<_DataType>;

template <typename _DataType>
template <typename _DataType_input, typename _DataType_output = _DataType_input>
DPCTLSyclEventRef (*dpnp_elemwise_absolute_ext_c)(DPCTLSyclQueueRef,
const void*,
void*,
size_t,
const DPCTLEventVectorRef) = dpnp_elemwise_absolute_c<_DataType>;

// template void dpnp_elemwise_absolute_c<double>(void* array1_in, void* result1, size_t size);
// template void dpnp_elemwise_absolute_c<float>(void* array1_in, void* result1, size_t size);
// template void dpnp_elemwise_absolute_c<long>(void* array1_in, void* result1, size_t size);
// template void dpnp_elemwise_absolute_c<int>(void* array1_in, void* result1, size_t size);
const DPCTLEventVectorRef) = dpnp_elemwise_absolute_c<_DataType_input, _DataType_output>;

template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
DPCTLSyclEventRef dpnp_cross_c(DPCTLSyclQueueRef q_ref,
Expand Down Expand Up @@ -1085,10 +1101,12 @@ void func_map_init_mathematical(func_map_t& fmap)
(void*)dpnp_elemwise_absolute_ext_c<int32_t>};
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_LNG][eft_LNG] = {eft_LNG,
(void*)dpnp_elemwise_absolute_ext_c<int64_t>};
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_FLT][eft_FLT] = {eft_FLT,
(void*)dpnp_elemwise_absolute_ext_c<float>};
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_DBL][eft_DBL] = {eft_DBL,
(void*)dpnp_elemwise_absolute_ext_c<double>};
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_elemwise_absolute_ext_c<float>};
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_elemwise_absolute_ext_c<double>};
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_C64][eft_C64] = {
eft_FLT, (void*)dpnp_elemwise_absolute_ext_c<std::complex<float>, float>};
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_C128][eft_C128] = {
eft_DBL, (void*)dpnp_elemwise_absolute_ext_c<std::complex<double>, double>};

fmap[DPNPFuncName::DPNP_FN_AROUND][eft_INT][eft_INT] = {eft_INT, (void*)dpnp_around_default_c<int32_t>};
fmap[DPNPFuncName::DPNP_FN_AROUND][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_around_default_c<int64_t>};
Expand Down
6 changes: 6 additions & 0 deletions dpnp/backend/src/dpnp_fptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,12 @@ struct is_any : std::disjunction<std::is_same<T, Ts>...> {};
template <typename T, typename... Ts>
struct are_same : std::conjunction<std::is_same<T, Ts>...> {};

/**
* A template constant to check if type T matces any type from Ts.
*/
template <typename T, typename... Ts>
constexpr auto is_any_v = is_any<T, Ts...>::value;

/**
* A template constat to check if both types T1 and T2 match every type from Ts sequence.
*/
Expand Down
59 changes: 40 additions & 19 deletions dpnp/dpnp_iface_mathematical.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,42 +117,63 @@ def abs(*args, **kwargs):
return dpnp.absolute(*args, **kwargs)


def absolute(x1, **kwargs):
def absolute(x,
/,
out=None,
*,
where=True,
dtype=None,
subok=True,
**kwargs):
"""
Calculate the absolute value element-wise.
For full documentation refer to :obj:`numpy.absolute`.
.. seealso:: :obj:`dpnp.abs` : Calculate the absolute value element-wise.
Returns
-------
y : dpnp.ndarray
An array containing the absolute value of each element in `x`.
Limitations
-----------
Parameter ``x1`` is supported as :obj:`dpnp.ndarray`.
Dimension of input array is limited by ``x1.ndim != 0``.
Keyword arguments ``kwargs`` are currently unsupported.
Otherwise the functions will be executed sequentially on CPU.
Input array data types are limited by supported DPNP :ref:`Data types`.
Parameters `x` is only supported as either :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`.
Parameters `out`, `where`, `dtype` and `subok` are supported with their default values.
Keyword arguments ``kwargs`` are currently unsupported.
Otherwise the function will be executed sequentially on CPU.
Input array data types are limited by supported DPNP :ref:`Data types`.
Examples
--------
>>> import dpnp as np
>>> a = np.array([-1.2, 1.2])
>>> result = np.absolute(a)
>>> import dpnp as dp
>>> a = dp.array([-1.2, 1.2])
>>> result = dp.absolute(a)
>>> [x for x in result]
[1.2, 1.2]
"""

x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
if x1_desc and not kwargs:
if not x1_desc.ndim:
pass
else:
result = dpnp_absolute(x1_desc).get_pyobj()

return result

return call_origin(numpy.absolute, x1, **kwargs)
if out is not None:
pass
elif where is not True:
pass
elif dtype is not None:
pass
elif subok is not True:
pass
elif dpnp.isscalar(x):
pass
else:
x_desc = dpnp.get_dpnp_descriptor(x, copy_when_nondefault_queue=False)
if x_desc:
if x_desc.dtype == dpnp.bool:
# return a copy of input array "x"
return dpnp.array(x, dtype=x.dtype, sycl_queue=x.sycl_queue, usm_type=x.usm_type)
return dpnp_absolute(x_desc).get_pyobj()

return call_origin(numpy.absolute, x, out=out, where=where, dtype=dtype, subok=subok, **kwargs)


def add(x1,
Expand Down
53 changes: 44 additions & 9 deletions tests/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,48 @@
import dpnp


def get_complex_dtypes(device=None):
"""
Build a list of complex types supported by DPNP based on device capabilities.
"""

dev = dpctl.select_default_device() if device is None else device

# add complex types
dtypes = [dpnp.complex64]
if dev.has_aspect_fp64:
dtypes.append(dpnp.complex128)
return dtypes


def get_float_dtypes(no_float16=True,
device=None):
"""
Build a list of floating types supported by DPNP based on device capabilities.
"""

dev = dpctl.select_default_device() if device is None else device

# add floating types
dtypes = [dpnp.float16] if not no_float16 else []

dtypes.append(dpnp.float32)
if dev.has_aspect_fp64:
dtypes.append(dpnp.float64)
return dtypes


def get_float_complex_dtypes(no_float16=True,
device=None):
"""
Build a list of floating and complex types supported by DPNP based on device capabilities.
"""

dtypes = get_float_dtypes(no_float16, device)
dtypes.extend(get_complex_dtypes(device))
return dtypes


def get_all_dtypes(no_bool=False,
no_float16=True,
no_complex=False,
Expand All @@ -22,18 +64,11 @@ def get_all_dtypes(no_bool=False,
dtypes.extend([dpnp.int32, dpnp.int64])

# add floating types
if not no_float16 and dev.has_aspect_fp16:
dtypes.append(dpnp.float16)

dtypes.append(dpnp.float32)
if dev.has_aspect_fp64:
dtypes.append(dpnp.float64)
dtypes.extend(get_float_dtypes(dev))

# add complex types
if not no_complex:
dtypes.append(dpnp.complex64)
if dev.has_aspect_fp64:
dtypes.append(dpnp.complex128)
dtypes.extend(get_complex_dtypes(dev))

# add None value to validate a default dtype
if not no_none:
Expand Down
2 changes: 0 additions & 2 deletions tests/skipped_tests_gpu.tbl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ tests/test_random.py::TestPermutationsTestShuffle::test_no_miss_numbers[int64]
tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.array([])]
tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.astype(dpnp.asarray(x), dpnp.float32)]

tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-abs-data0]
tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-ceil-data1]
tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-conjugate-data2]
tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-copy-data3]
Expand All @@ -22,7 +21,6 @@ tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-ediff1d-data7]
tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-fabs-data8]
tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-floor-data9]

tests/test_sycl_queue.py::test_1in_1out[level_zero:gpu:0-abs-data0]
tests/test_sycl_queue.py::test_1in_1out[level_zero:gpu:0-ceil-data1]
tests/test_sycl_queue.py::test_1in_1out[level_zero:gpu:0-conjugate-data2]
tests/test_sycl_queue.py::test_1in_1out[level_zero:gpu:0-copy-data3]
Expand Down
Loading

0 comments on commit 648612d

Please sign in to comment.