Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 194 additions & 8 deletions sycl/include/CL/sycl/half_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -386,14 +386,200 @@ class half {
operator--();
return ret;
}
constexpr half &operator-() {
Data = -Data;
return *this;
}
constexpr half operator-() const {
half r = *this;
return -r;
}
__SYCL_CONSTEXPR_HALF friend half operator-(const half other) {
return half(-other.Data);
}

// Operator +, -, *, /
#define OP(op, op_eq) \
__SYCL_CONSTEXPR_HALF friend half operator op(const half lhs, \
const half rhs) { \
half rtn = lhs; \
rtn op_eq rhs; \
return rtn; \
} \
__SYCL_CONSTEXPR_HALF friend double operator op(const half lhs, \
const double rhs) { \
double rtn = lhs; \
rtn op_eq rhs; \
return rtn; \
} \
__SYCL_CONSTEXPR_HALF friend double operator op(const double lhs, \
const half rhs) { \
double rtn = lhs; \
rtn op_eq rhs; \
return rtn; \
} \
__SYCL_CONSTEXPR_HALF friend float operator op(const half lhs, \
const float rhs) { \
float rtn = lhs; \
rtn op_eq rhs; \
return rtn; \
} \
__SYCL_CONSTEXPR_HALF friend float operator op(const float lhs, \
const half rhs) { \
float rtn = lhs; \
rtn op_eq rhs; \
return rtn; \
} \
__SYCL_CONSTEXPR_HALF friend half operator op(const half lhs, \
const int rhs) { \
half rtn = lhs; \
rtn op_eq rhs; \
return rtn; \
} \
__SYCL_CONSTEXPR_HALF friend half operator op(const int lhs, \
const half rhs) { \
half rtn = lhs; \
rtn op_eq rhs; \
return rtn; \
} \
__SYCL_CONSTEXPR_HALF friend half operator op(const half lhs, \
const long rhs) { \
half rtn = lhs; \
rtn op_eq rhs; \
return rtn; \
} \
__SYCL_CONSTEXPR_HALF friend half operator op(const long lhs, \
const half rhs) { \
half rtn = lhs; \
rtn op_eq rhs; \
return rtn; \
} \
__SYCL_CONSTEXPR_HALF friend half operator op(const half lhs, \
const long long rhs) { \
half rtn = lhs; \
rtn op_eq rhs; \
return rtn; \
} \
__SYCL_CONSTEXPR_HALF friend half operator op(const long long lhs, \
const half rhs) { \
half rtn = lhs; \
rtn op_eq rhs; \
return rtn; \
} \
__SYCL_CONSTEXPR_HALF friend half operator op(const half &lhs, \
const unsigned int &rhs) { \
half rtn = lhs; \
rtn op_eq rhs; \
return rtn; \
} \
__SYCL_CONSTEXPR_HALF friend half operator op(const unsigned int &lhs, \
const half &rhs) { \
half rtn = lhs; \
rtn op_eq rhs; \
return rtn; \
} \
__SYCL_CONSTEXPR_HALF friend half operator op(const half &lhs, \
const unsigned long &rhs) { \
half rtn = lhs; \
rtn op_eq rhs; \
return rtn; \
} \
__SYCL_CONSTEXPR_HALF friend half operator op(const unsigned long &lhs, \
const half &rhs) { \
half rtn = lhs; \
rtn op_eq rhs; \
return rtn; \
} \
__SYCL_CONSTEXPR_HALF friend half operator op( \
const half &lhs, const unsigned long long &rhs) { \
half rtn = lhs; \
rtn op_eq rhs; \
return rtn; \
} \
__SYCL_CONSTEXPR_HALF friend half operator op(const unsigned long long &lhs, \
const half &rhs) { \
half rtn = lhs; \
rtn op_eq rhs; \
return rtn; \
}
OP(+, +=)
OP(-, -=)
OP(*, *=)
OP(/, /=)
Copy link
Contributor

@dkhaldi dkhaldi Apr 27, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also add support for:
OP(bool, ==)
OP(bool, !=)
OP(bool, <)
OP(bool, >)
OP(bool, <=)
OP(bool, >=)

?
Like this, we will have a more complete list similar to what we have for bfloat16 support currently.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added the logical operators and expanded the test


#undef OP

// Operator ==, !=, <, >, <=, >=
#define OP(op) \
__SYCL_CONSTEXPR_HALF friend bool operator op(const half &lhs, \
const half &rhs) { \
return lhs.Data op rhs.Data; \
} \
__SYCL_CONSTEXPR_HALF friend bool operator op(const half &lhs, \
const double &rhs) { \
return lhs.Data op rhs; \
} \
__SYCL_CONSTEXPR_HALF friend bool operator op(const double &lhs, \
const half &rhs) { \
return lhs op rhs.Data; \
} \
__SYCL_CONSTEXPR_HALF friend bool operator op(const half &lhs, \
const float &rhs) { \
return lhs.Data op rhs; \
} \
__SYCL_CONSTEXPR_HALF friend bool operator op(const float &lhs, \
const half &rhs) { \
return lhs op rhs.Data; \
} \
__SYCL_CONSTEXPR_HALF friend bool operator op(const half &lhs, \
const int &rhs) { \
return lhs.Data op rhs; \
} \
__SYCL_CONSTEXPR_HALF friend bool operator op(const int &lhs, \
const half &rhs) { \
return lhs op rhs.Data; \
} \
__SYCL_CONSTEXPR_HALF friend bool operator op(const half &lhs, \
const long &rhs) { \
return lhs.Data op rhs; \
} \
__SYCL_CONSTEXPR_HALF friend bool operator op(const long &lhs, \
const half &rhs) { \
return lhs op rhs.Data; \
} \
__SYCL_CONSTEXPR_HALF friend bool operator op(const half &lhs, \
const long long &rhs) { \
return lhs.Data op rhs; \
} \
__SYCL_CONSTEXPR_HALF friend bool operator op(const long long &lhs, \
const half &rhs) { \
return lhs op rhs.Data; \
} \
__SYCL_CONSTEXPR_HALF friend bool operator op(const half &lhs, \
const unsigned int &rhs) { \
return lhs.Data op rhs; \
} \
__SYCL_CONSTEXPR_HALF friend bool operator op(const unsigned int &lhs, \
const half &rhs) { \
return lhs op rhs.Data; \
} \
__SYCL_CONSTEXPR_HALF friend bool operator op(const half &lhs, \
const unsigned long &rhs) { \
return lhs.Data op rhs; \
} \
__SYCL_CONSTEXPR_HALF friend bool operator op(const unsigned long &lhs, \
const half &rhs) { \
return lhs op rhs.Data; \
} \
__SYCL_CONSTEXPR_HALF friend bool operator op( \
const half &lhs, const unsigned long long &rhs) { \
return lhs.Data op rhs; \
} \
__SYCL_CONSTEXPR_HALF friend bool operator op(const unsigned long long &lhs, \
const half &rhs) { \
return lhs op rhs.Data; \
}
OP(==)
OP(!=)
OP(<)
OP(>)
OP(<=)
OP(>=)

#undef OP

// Operator float
__SYCL_CONSTEXPR_HALF operator float() const {
return static_cast<float>(Data);
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/builtins_math.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ template <typename T> inline T __cospi(T x) { return std::cos(M_PI * x); }
template <typename T> T inline __fract(T x, T *iptr) {
T f = std::floor(x);
*(iptr) = f;
return std::fmin(x - f, nextafter(T(1.0), T(0.0)));
return std::fmin(x - f, std::nextafter(T(1.0), T(0.0)));
}

template <typename T> inline T __lgamma_r(T x, s::cl_int *signp) {
Expand Down
102 changes: 102 additions & 0 deletions sycl/test/type_traits/half_operator_types.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// RUN: %clangxx -fsycl %s -o %t.out
//==-------------- type_traits.cpp - SYCL type_traits test -----------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <CL/sycl.hpp>
using namespace std;

template <typename T1, typename T_rtn> void math_operator_helper() {
static_assert(
is_same_v<decltype(declval<T1>() + declval<sycl::half>()), T_rtn>);
static_assert(
is_same_v<decltype(declval<T1>() - declval<sycl::half>()), T_rtn>);
static_assert(
is_same_v<decltype(declval<T1>() * declval<sycl::half>()), T_rtn>);
static_assert(
is_same_v<decltype(declval<T1>() / declval<sycl::half>()), T_rtn>);

static_assert(
is_same_v<decltype(declval<sycl::half>() + declval<T1>()), T_rtn>);
static_assert(
is_same_v<decltype(declval<sycl::half>() - declval<T1>()), T_rtn>);
static_assert(
is_same_v<decltype(declval<sycl::half>() * declval<T1>()), T_rtn>);
static_assert(
is_same_v<decltype(declval<sycl::half>() / declval<T1>()), T_rtn>);
}

template <typename T1> void logical_operator_helper() {
static_assert(
is_same_v<decltype(declval<T1>() == declval<sycl::half>()), bool>);
static_assert(
is_same_v<decltype(declval<T1>() != declval<sycl::half>()), bool>);
static_assert(
is_same_v<decltype(declval<T1>() > declval<sycl::half>()), bool>);
static_assert(
is_same_v<decltype(declval<T1>() < declval<sycl::half>()), bool>);
static_assert(
is_same_v<decltype(declval<T1>() <= declval<sycl::half>()), bool>);
static_assert(
is_same_v<decltype(declval<T1>() >= declval<sycl::half>()), bool>);

static_assert(
is_same_v<decltype(declval<sycl::half>() == declval<T1>()), bool>);
static_assert(
is_same_v<decltype(declval<sycl::half>() != declval<T1>()), bool>);
static_assert(
is_same_v<decltype(declval<sycl::half>() > declval<T1>()), bool>);
static_assert(
is_same_v<decltype(declval<sycl::half>() < declval<T1>()), bool>);
static_assert(
is_same_v<decltype(declval<sycl::half>() <= declval<T1>()), bool>);
static_assert(
is_same_v<decltype(declval<sycl::half>() >= declval<T1>()), bool>);
}

template <typename T1, typename T_rtn>
void check_half_math_operator_types(sycl::queue &Queue) {

// Test on host
math_operator_helper<T1, T_rtn>();

// Test on device
Queue.submit([&](sycl::handler &cgh) {
cgh.single_task([=] { math_operator_helper<T1, T_rtn>(); });
});
}

template <typename T1>
void check_half_logical_operator_types(sycl::queue &Queue) {

// Test on host
logical_operator_helper<T1>();

// Test on device
Queue.submit([&](sycl::handler &cgh) {
cgh.single_task([=] { logical_operator_helper<T1>(); });
});
}

int main() {

sycl::queue Queue;

check_half_math_operator_types<sycl::half, sycl::half>(Queue);
check_half_math_operator_types<double, double>(Queue);
check_half_math_operator_types<float, float>(Queue);
check_half_math_operator_types<int, sycl::half>(Queue);
check_half_math_operator_types<long, sycl::half>(Queue);
check_half_math_operator_types<long long, sycl::half>(Queue);

check_half_logical_operator_types<sycl::half>(Queue);
check_half_logical_operator_types<double>(Queue);
check_half_logical_operator_types<float>(Queue);
check_half_logical_operator_types<int>(Queue);
check_half_logical_operator_types<long>(Queue);
check_half_logical_operator_types<long long>(Queue);
}