Skip to content

Commit

Permalink
Elementwise functions cbrt, exp2, copysign, and rsqrt (#1443)
Browse files Browse the repository at this point in the history
* Implements dpctl.tensor.cbrt

* Implements copysign and exp2 elementwise funcs

* Adds tests for cbrt, copysign, exp2

* Implements rsqrt and tests for rsqrt

* Modified tests for cbrt, copysign, and rsqrt

Now test more type combinations/output types
  • Loading branch information
ndgrigorian authored Oct 17, 2023
1 parent af04d34 commit 2d2f235
Show file tree
Hide file tree
Showing 11 changed files with 1,608 additions and 1 deletion.
8 changes: 8 additions & 0 deletions dpctl/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,16 @@
bitwise_or,
bitwise_right_shift,
bitwise_xor,
cbrt,
ceil,
conj,
copysign,
cos,
cosh,
divide,
equal,
exp,
exp2,
expm1,
floor,
floor_divide,
Expand Down Expand Up @@ -149,6 +152,7 @@
real,
remainder,
round,
rsqrt,
sign,
signbit,
sin,
Expand Down Expand Up @@ -314,4 +318,8 @@
"argmax",
"argmin",
"prod",
"cbrt",
"exp2",
"copysign",
"rsqrt",
]
113 changes: 113 additions & 0 deletions dpctl/tensor/_elementwise_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1761,3 +1761,116 @@
hypot = BinaryElementwiseFunc(
"hypot", ti._hypot_result_type, ti._hypot, _hypot_docstring_
)


# U37: ==== CBRT (x)
_cbrt_docstring_ = """
cbrt(x, out=None, order='K')
Computes the cube-root for each element `x_i` for input array `x`.
Args:
x (usm_ndarray):
Input array, expected to have a real floating-point data type.
out ({None, usm_ndarray}, optional):
Output array to populate.
Array must have the correct shape and the expected data type.
order ("C","F","A","K", optional):
Memory layout of the new output array, if parameter `out` is `None`.
Default: "K".
Returns:
usm_ndarray:
An array containing the element-wise cube-root.
The data type of the returned array is determined by
the Type Promotion Rules.
"""

# Public function object for cbrt. It is built from the function name, two
# entry points taken from `ti` (presumably the result-type resolver and the
# kernel dispatcher -- confirm against the `ti` module), and the docstring.
cbrt = UnaryElementwiseFunc(
"cbrt", ti._cbrt_result_type, ti._cbrt, _cbrt_docstring_
)


# U38: ==== EXP2 (x)
_exp2_docstring_ = """
exp2(x, out=None, order='K')
Computes the base-2 exponential for each element `x_i` for input array `x`.
Args:
x (usm_ndarray):
Input array, expected to have a floating-point data type.
out ({None, usm_ndarray}, optional):
Output array to populate.
Array must have the correct shape and the expected data type.
order ("C","F","A","K", optional):
Memory layout of the new output array, if parameter `out` is `None`.
Default: "K".
Returns:
usm_ndarray:
An array containing the element-wise base-2 exponentials.
The data type of the returned array is determined by
the Type Promotion Rules.
"""

# Public function object for exp2. It is built from the function name, two
# entry points taken from `ti` (presumably the result-type resolver and the
# kernel dispatcher -- confirm against the `ti` module), and the docstring.
exp2 = UnaryElementwiseFunc(
"exp2", ti._exp2_result_type, ti._exp2, _exp2_docstring_
)


# B25: ==== COPYSIGN (x1, x2)
_copysign_docstring_ = """
copysign(x1, x2, out=None, order='K')
Composes a floating-point value with the magnitude of `x1_i` and the sign of
`x2_i` for each element of input arrays `x1` and `x2`.
Args:
x1 (usm_ndarray):
First input array, expected to have a real floating-point data type.
x2 (usm_ndarray):
Second input array, also expected to have a real floating-point data
type.
out ({None, usm_ndarray}, optional):
Output array to populate.
Array must have the correct shape and the expected data type.
order ("C","F","A","K", optional):
Memory layout of the new output array, if parameter `out` is `None`.
Default: "K".
Returns:
usm_ndarray:
An array containing the element-wise results. The data type
of the returned array is determined by the Type Promotion Rules.
"""
# Public function object for copysign. It is built from the function name,
# two entry points taken from `ti` (presumably the result-type resolver and
# the kernel dispatcher -- confirm against the `ti` module), and the docstring.
copysign = BinaryElementwiseFunc(
"copysign",
ti._copysign_result_type,
ti._copysign,
_copysign_docstring_,
)


# U39: ==== RSQRT (x)
_rsqrt_docstring_ = """
rsqrt(x, out=None, order='K')
Computes the reciprocal square-root for each element `x_i` for input array `x`.
Args:
x (usm_ndarray):
Input array, expected to have a real floating-point data type.
out ({None, usm_ndarray}, optional):
Output array to populate.
Array must have the correct shape and the expected data type.
order ("C","F","A","K", optional):
Memory layout of the new output array, if parameter `out` is `None`.
Default: "K".
Returns:
usm_ndarray:
An array containing the element-wise reciprocal square-root.
The data type of the returned array is determined by
the Type Promotion Rules.
"""

# Public function object for rsqrt. It is built from the function name, two
# entry points taken from `ti` (presumably the result-type resolver and the
# kernel dispatcher -- confirm against the `ti` module), and the docstring.
rsqrt = UnaryElementwiseFunc(
"rsqrt", ti._rsqrt_result_type, ti._rsqrt, _rsqrt_docstring_
)
172 changes: 172 additions & 0 deletions dpctl/tensor/libtensor/include/kernels/elementwise_functions/cbrt.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
//=== cbrt.hpp - Unary function CBRT ------ *-C++-*--/===//
//
// Data Parallel Control (dpctl)
//
// Copyright 2020-2023 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===---------------------------------------------------------------------===//
///
/// \file
/// This file defines kernels for elementwise evaluation of CBRT(x)
/// function that computes a cube root.
//===---------------------------------------------------------------------===//

#pragma once
#include <CL/sycl.hpp>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <type_traits>

#include "kernels/elementwise_functions/common.hpp"

#include "utils/offset_utils.hpp"
#include "utils/type_dispatch.hpp"
#include "utils/type_utils.hpp"
#include <pybind11/pybind11.h>

namespace dpctl
{
namespace tensor
{
namespace kernels
{
namespace cbrt
{

// Short namespace aliases used in this header:
//   py    -- pybind11; used below for the py::ssize_t offset/stride type
//   td_ns -- dpctl type-dispatch utilities; used below for the type-map
//            entries and for GetTypeid
namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;

/*! @brief Scalar functor computing the cube root of a single element.
 *
 *  The nested aliases are compile-time flags describing the functor; they are
 *  presumably inspected by the generic unary functors in elementwise_common
 *  (confirm against common.hpp).
 *
 *  @tparam argT  Type of the input element.
 *  @tparam resT  Type of the produced element; the value returned by
 *                sycl::cbrt is converted to it on return.
 */
template <typename argT, typename resT> struct CbrtFunctor
{
    // The result depends on the argument, i.e. the function is not constant
    // for any argT (hence no constant_value member is provided).
    using is_constant = std::false_type;

    // The function is not defined for sycl::vec arguments.
    using supports_vec = std::false_type;

    // Both argT and resT support sub-group load/store operations.
    using supports_sg_loadstore = std::true_type;

    /*! @brief Evaluate cbrt for one element.
     *  @param x  Input value.
     *  @return   sycl::cbrt(x), converted to resT.
     */
    resT operator()(const argT &x) const
    {
        return sycl::cbrt(x);
    }
};

// Contiguous-array functor for cbrt: binds CbrtFunctor<argTy, resTy> into the
// generic elementwise_common::UnaryContigFunctor.
//   argTy  -- element type of the input
//   resTy  -- element type of the output (defaults to argTy)
//   vec_sz -- forwarded to UnaryContigFunctor, default 4
//   n_vecs -- forwarded to UnaryContigFunctor, default 2
// NOTE(review): vec_sz / n_vecs presumably control the vector width and the
// number of vectors processed per work-item -- confirm in common.hpp.
template <typename argTy,
typename resTy = argTy,
unsigned int vec_sz = 4,
unsigned int n_vecs = 2>
using CbrtContigFunctor = elementwise_common::
UnaryContigFunctor<argTy, resTy, CbrtFunctor<argTy, resTy>, vec_sz, n_vecs>;

// Strided-array functor for cbrt: binds CbrtFunctor<argTy, resTy> into the
// generic elementwise_common::UnaryStridedFunctor.
//   argTy    -- element type of the input
//   resTy    -- element type of the output
//   IndexerT -- indexer type forwarded to UnaryStridedFunctor
// NOTE(review): IndexerT presumably maps a linear index to input/output
// offsets -- confirm in common.hpp / offset_utils.hpp.
template <typename argTy, typename resTy, typename IndexerT>
using CbrtStridedFunctor = elementwise_common::
UnaryStridedFunctor<argTy, resTy, IndexerT, CbrtFunctor<argTy, resTy>>;

template <typename T> struct CbrtOutputType
{
using value_type = typename std::disjunction< // disjunction is C++17
// feature, supported by DPC++
td_ns::TypeMapResultEntry<T, sycl::half, sycl::half>,
td_ns::TypeMapResultEntry<T, float, float>,
td_ns::TypeMapResultEntry<T, double, double>,
td_ns::DefaultResultEntry<void>>::result_type;
};

// Declaration-only class template passed to unary_contig_impl as the kernel
// name type for the contiguous cbrt kernel.
template <typename T1, typename T2, unsigned int vec_sz, unsigned int n_vecs>
class cbrt_contig_kernel;

/*! @brief Submit the cbrt computation for contiguous input and output.
 *
 *  Thin wrapper that instantiates elementwise_common::unary_contig_impl with
 *  the cbrt-specific output type map, contiguous functor and kernel name.
 *
 *  @tparam argTy   Element type of the input.
 *  @param exec_q   Queue the work is submitted to.
 *  @param nelems   Number of elements to process.
 *  @param arg_p    Type-erased pointer to the input data.
 *  @param res_p    Type-erased pointer to the output data.
 *  @param depends  Events forwarded to unary_contig_impl as dependencies
 *                  (empty by default).
 *  @return Event returned by unary_contig_impl.
 */
template <typename argTy>
sycl::event cbrt_contig_impl(sycl::queue &exec_q,
                             size_t nelems,
                             const char *arg_p,
                             char *res_p,
                             const std::vector<sycl::event> &depends = {})
{
    namespace ew = elementwise_common;

    return ew::unary_contig_impl<argTy,
                                 CbrtOutputType,
                                 CbrtContigFunctor,
                                 cbrt_contig_kernel>(exec_q, nelems, arg_p,
                                                     res_p, depends);
}

template <typename fnT, typename T> struct CbrtContigFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename CbrtOutputType<T>::value_type,
void>) {
fnT fn = nullptr;
return fn;
}
else {
fnT fn = cbrt_contig_impl<T>;
return fn;
}
}
};

template <typename fnT, typename T> struct CbrtTypeMapFactory
{
/*! @brief get typeid for output type of std::cbrt(T x) */
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
{
using rT = typename CbrtOutputType<T>::value_type;
return td_ns::GetTypeid<rT>{}.get();
}
};

// Declaration-only class template passed to unary_strided_impl as the kernel
// name type for the strided cbrt kernel.
template <typename T1, typename T2, typename T3> class cbrt_strided_kernel;

/*! @brief Submit the cbrt computation for strided input and output.
 *
 *  Thin wrapper that instantiates elementwise_common::unary_strided_impl with
 *  the cbrt-specific output type map, strided functor and kernel name; all
 *  arguments are forwarded unchanged, in order.
 *
 *  @tparam argTy              Element type of the input.
 *  @param exec_q              Queue the work is submitted to.
 *  @param nelems              Number of elements to process.
 *  @param nd                  Number of array dimensions.
 *  @param shape_and_strides   Packed shape/stride data; the exact layout is
 *                             defined by unary_strided_impl (not visible
 *                             here).
 *  @param arg_p               Type-erased pointer to the input data.
 *  @param arg_offset          Offset applied to the input.
 *  @param res_p               Type-erased pointer to the output data.
 *  @param res_offset          Offset applied to the output.
 *  @param depends             Events forwarded as dependencies.
 *  @param additional_depends  Further events forwarded as dependencies.
 *  @return Event returned by unary_strided_impl.
 */
template <typename argTy>
sycl::event
cbrt_strided_impl(sycl::queue &exec_q,
                  size_t nelems,
                  int nd,
                  const py::ssize_t *shape_and_strides,
                  const char *arg_p,
                  py::ssize_t arg_offset,
                  char *res_p,
                  py::ssize_t res_offset,
                  const std::vector<sycl::event> &depends,
                  const std::vector<sycl::event> &additional_depends)
{
    namespace ew = elementwise_common;

    return ew::unary_strided_impl<argTy,
                                  CbrtOutputType,
                                  CbrtStridedFunctor,
                                  cbrt_strided_kernel>(
        exec_q, nelems, nd, shape_and_strides, arg_p, arg_offset, res_p,
        res_offset, depends, additional_depends);
}

template <typename fnT, typename T> struct CbrtStridedFactory
{
fnT get()
{
if constexpr (std::is_same_v<typename CbrtOutputType<T>::value_type,
void>) {
fnT fn = nullptr;
return fn;
}
else {
fnT fn = cbrt_strided_impl<T>;
return fn;
}
}
};

} // namespace cbrt
} // namespace kernels
} // namespace tensor
} // namespace dpctl
Loading

0 comments on commit 2d2f235

Please sign in to comment.