Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

re-write dpnp.abs #1575

Merged
merged 2 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions dpnp/backend/extensions/vm/abs.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
//*****************************************************************************
// Copyright (c) 2023, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#pragma once

#include <CL/sycl.hpp>

#include "common.hpp"
#include "types_matrix.hpp"

namespace dpnp
{
namespace backend
{
namespace ext
{
namespace vm
{
template <typename T>
sycl::event abs_contig_impl(sycl::queue exec_q,
const std::int64_t n,
const char *in_a,
char *out_y,
const std::vector<sycl::event> &depends)
{
type_utils::validate_type_for_device<T>(exec_q);

const T *a = reinterpret_cast<const T *>(in_a);
T *y = reinterpret_cast<T *>(out_y);

return mkl_vm::abs(exec_q,
n, // number of elements to be calculated
a, // pointer `a` containing input vector of size n
y, // pointer `y` to the output vector of size n
depends);
}

template <typename fnT, typename T>
struct AbsContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<
typename types::AbsOutputType<T>::value_type, void>)
{
return nullptr;
}
else {
return abs_contig_impl<T>;
}
}
};
} // namespace vm
} // namespace ext
} // namespace backend
} // namespace dpnp
17 changes: 17 additions & 0 deletions dpnp/backend/extensions/vm/types_matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,23 @@ namespace vm
{
namespace types
{
/**
* @brief A factory to define pairs of supported types for which
* MKL VM library provides support in oneapi::mkl::vm::abs<T> function.
*
* @tparam T Type of input vector `a` and of result vector `y`.
*/
template <typename T>
struct AbsOutputType
{
using value_type = typename std::disjunction<
// TODO: Add complex type here after updating the dispatching to allow
// output type to be different than input
dpctl_td_ns::TypeMapResultEntry<T, double>,
dpctl_td_ns::TypeMapResultEntry<T, float>,
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
vtavana marked this conversation as resolved.
Show resolved Hide resolved
};

/**
* @brief A factory to define pairs of supported types for which
* MKL VM library provides support in oneapi::mkl::vm::acos<T> function.
Expand Down
30 changes: 30 additions & 0 deletions dpnp/backend/extensions/vm/vm_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "abs.hpp"
#include "acos.hpp"
#include "acosh.hpp"
#include "add.hpp"
Expand Down Expand Up @@ -66,6 +67,7 @@ namespace vm_ext = dpnp::backend::ext::vm;
using vm_ext::binary_impl_fn_ptr_t;
using vm_ext::unary_impl_fn_ptr_t;

static unary_impl_fn_ptr_t abs_dispatch_vector[dpctl_td_ns::num_types];
static unary_impl_fn_ptr_t acos_dispatch_vector[dpctl_td_ns::num_types];
static unary_impl_fn_ptr_t acosh_dispatch_vector[dpctl_td_ns::num_types];
static binary_impl_fn_ptr_t add_dispatch_vector[dpctl_td_ns::num_types];
Expand Down Expand Up @@ -99,6 +101,34 @@ PYBIND11_MODULE(_vm_impl, m)
using arrayT = dpctl::tensor::usm_ndarray;
using event_vecT = std::vector<sycl::event>;

// UnaryUfunc: ==== Abs(x) ====
{
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
vm_ext::AbsContigFactory>(
abs_dispatch_vector);

auto abs_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
const event_vecT &depends = {}) {
return vm_ext::unary_ufunc(exec_q, src, dst, depends,
abs_dispatch_vector);
};
m.def("_abs", abs_pyapi,
"Call `abs` function from OneMKL VM library to compute "
"the absolute value of vector elements",
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"),
py::arg("depends") = py::list());

auto abs_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
arrayT dst) {
return vm_ext::need_to_call_unary_ufunc(exec_q, src, dst,
abs_dispatch_vector);
};
m.def("_mkl_abs_to_call", abs_need_to_call_pyapi,
"Check input arguments to answer if `abs` function from "
"OneMKL VM library can be used",
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
}

// UnaryUfunc: ==== Acos(x) ====
{
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
Expand Down
8 changes: 3 additions & 5 deletions dpnp/backend/include/dpnp_iface_fptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,9 @@
*/
enum class DPNPFuncName : size_t
{
DPNP_FN_NONE, /**< Very first element of the enumeration */
DPNP_FN_ABSOLUTE, /**< Used in numpy.absolute() impl */
DPNP_FN_ABSOLUTE_EXT, /**< Used in numpy.absolute() impl, requires extra
parameters */
DPNP_FN_ADD, /**< Used in numpy.add() impl */
DPNP_FN_NONE, /**< Very first element of the enumeration */
DPNP_FN_ABSOLUTE, /**< Used in numpy.absolute() impl */
DPNP_FN_ADD, /**< Used in numpy.add() impl */
DPNP_FN_ADD_EXT, /**< Used in numpy.add() impl, requires extra parameters */
DPNP_FN_ALL, /**< Used in numpy.all() impl */
DPNP_FN_ALLCLOSE, /**< Used in numpy.allclose() impl */
Expand Down
23 changes: 0 additions & 23 deletions dpnp/backend/kernels/dpnp_krnl_mathematical.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,14 +216,6 @@ template <typename _DataType>
void (*dpnp_elemwise_absolute_default_c)(const void *, void *, size_t) =
dpnp_elemwise_absolute_c<_DataType>;

template <typename _DataType_input, typename _DataType_output = _DataType_input>
DPCTLSyclEventRef (*dpnp_elemwise_absolute_ext_c)(DPCTLSyclQueueRef,
const void *,
void *,
size_t,
const DPCTLEventVectorRef) =
dpnp_elemwise_absolute_c<_DataType_input, _DataType_output>;

template <typename _DataType_output,
typename _DataType_input1,
typename _DataType_input2>
Expand Down Expand Up @@ -1151,21 +1143,6 @@ void func_map_init_mathematical(func_map_t &fmap)
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_elemwise_absolute_default_c<double>};

fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_INT][eft_INT] = {
eft_INT, (void *)dpnp_elemwise_absolute_ext_c<int32_t>};
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_LNG][eft_LNG] = {
eft_LNG, (void *)dpnp_elemwise_absolute_ext_c<int64_t>};
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_FLT][eft_FLT] = {
eft_FLT, (void *)dpnp_elemwise_absolute_ext_c<float>};
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_DBL][eft_DBL] = {
eft_DBL, (void *)dpnp_elemwise_absolute_ext_c<double>};
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_C64][eft_C64] = {
eft_FLT,
(void *)dpnp_elemwise_absolute_ext_c<std::complex<float>, float>};
fmap[DPNPFuncName::DPNP_FN_ABSOLUTE_EXT][eft_C128][eft_C128] = {
eft_DBL,
(void *)dpnp_elemwise_absolute_ext_c<std::complex<double>, double>};

fmap[DPNPFuncName::DPNP_FN_AROUND][eft_INT][eft_INT] = {
eft_INT, (void *)dpnp_around_default_c<int32_t>};
fmap[DPNPFuncName::DPNP_FN_AROUND][eft_LNG][eft_LNG] = {
Expand Down
2 changes: 0 additions & 2 deletions dpnp/dpnp_algo/dpnp_algo.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ from dpnp.dpnp_utils.dpnp_algo_utils cimport dpnp_descriptor

cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this namespace for Enum import
cdef enum DPNPFuncName "DPNPFuncName":
DPNP_FN_ABSOLUTE
DPNP_FN_ABSOLUTE_EXT
DPNP_FN_ALLCLOSE
DPNP_FN_ALLCLOSE_EXT
DPNP_FN_ARANGE
Expand Down
39 changes: 0 additions & 39 deletions dpnp/dpnp_algo/dpnp_algo_mathematical.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ and the rest of the library
# NO IMPORTs here. All imports must be placed into main "dpnp_algo.pyx" file

__all__ += [
"dpnp_absolute",
"dpnp_copysign",
"dpnp_cross",
"dpnp_cumprod",
Expand All @@ -59,9 +58,6 @@ __all__ += [
]


ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_custom_elemwise_absolute_1in_1out_t)(c_dpctl.DPCTLSyclQueueRef,
void * , void * , size_t,
const c_dpctl.DPCTLEventVectorRef)
ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_1in_2out_t)(c_dpctl.DPCTLSyclQueueRef,
void * , void * , void * , size_t,
const c_dpctl.DPCTLEventVectorRef)
Expand All @@ -70,41 +66,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*ftpr_custom_trapz_2in_1out_with_2size_t)(c_d
const c_dpctl.DPCTLEventVectorRef)


cpdef utils.dpnp_descriptor dpnp_absolute(utils.dpnp_descriptor x1):
cdef shape_type_c x1_shape = x1.shape
cdef size_t x1_shape_size = x1.ndim

# convert string type names (array.dtype) to C enum DPNPFuncType
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(x1.dtype)

# get the FPTR data structure
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_ABSOLUTE_EXT, param1_type, param1_type)

x1_obj = x1.get_array()

# ceate result array with type given by FPTR data
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(x1_shape,
kernel_data.return_type,
None,
device=x1_obj.sycl_device,
usm_type=x1_obj.usm_type,
sycl_queue=x1_obj.sycl_queue)

result_sycl_queue = result.get_array().sycl_queue

cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()

cdef fptr_custom_elemwise_absolute_1in_1out_t func = <fptr_custom_elemwise_absolute_1in_1out_t > kernel_data.ptr
# call FPTR function
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref, x1.get_data(), result.get_data(), x1.size, NULL)

with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
c_dpctl.DPCTLEvent_Delete(event_ref)

return result


cpdef utils.dpnp_descriptor dpnp_copysign(utils.dpnp_descriptor x1_obj,
utils.dpnp_descriptor x2_obj,
object dtype=None,
Expand Down
58 changes: 58 additions & 0 deletions dpnp/dpnp_algo/dpnp_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

__all__ = [
"check_nd_call_func",
"dpnp_abs",
"dpnp_acos",
"dpnp_acosh",
"dpnp_add",
Expand Down Expand Up @@ -169,6 +170,63 @@ def check_nd_call_func(
)


_abs_docstring = """
abs(x, out=None, order='K')

Calculates the absolute value for each element `x_i` of input array `x`.

Args:
x (dpnp.ndarray):
Input array, expected to have numeric data type.
out ({None, dpnp.ndarray}, optional):
Output array to populate. Array must have the correct
shape and the expected data type.
order ("C","F","A","K", optional): memory layout of the new
output array, if parameter `out` is `None`.
Default: "K".
Return:
dpnp.ndarray:
An array containing the element-wise absolute values.
For complex input, the absolute value is its magnitude.
If `x` has a real-valued data type, the returned array has the
same data type as `x`. If `x` has a complex floating-point data type,
the returned array has a real-valued floating-point data type whose
precision matches the precision of `x`.
"""


def _call_abs(src, dst, sycl_queue, depends=None):
"""A callback to register in UnaryElementwiseFunc class of dpctl.tensor"""

if depends is None:
depends = []

if vmi._mkl_abs_to_call(sycl_queue, src, dst):
# call pybind11 extension for abs() function from OneMKL VM
return vmi._abs(sycl_queue, src, dst, depends)
return ti._abs(src, dst, sycl_queue, depends)


abs_func = UnaryElementwiseFunc(
"abs", ti._abs_result_type, _call_abs, _abs_docstring
)


def dpnp_abs(x, out=None, order="K"):
"""
Invokes abs() function from pybind11 extension of OneMKL VM if possible.

Otherwise fully relies on dpctl.tensor implementation for abs() function.

"""
# dpctl.tensor only works with usm_ndarray
x1_usm = dpnp.get_usm_ndarray(x)
out_usm = None if out is None else dpnp.get_usm_ndarray(out)

res_usm = abs_func(x1_usm, out=out_usm, order=order)
return dpnp_array._create_from_usm_ndarray(res_usm)


_acos_docstring = """
acos(x, out=None, order='K')

Expand Down
Loading