From 8817b7c6261dfba2fb13f2287f34bce170662ce6 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Wed, 23 Nov 2022 16:15:25 +0800 Subject: [PATCH 01/12] [SYCL] Add bfloat16 comparison utils based on libdevice bfloat16 support. Signed-off-by: jinge90 --- sycl/include/sycl/ext/intel/math.hpp | 1 + sycl/include/sycl/ext/intel/math/imf_bf16.hpp | 205 ++++++++++++++++++ 2 files changed, 206 insertions(+) create mode 100644 sycl/include/sycl/ext/intel/math/imf_bf16.hpp diff --git a/sycl/include/sycl/ext/intel/math.hpp b/sycl/include/sycl/ext/intel/math.hpp index 6a394beb94d7f..d934e550f55ad 100644 --- a/sycl/include/sycl/ext/intel/math.hpp +++ b/sycl/include/sycl/ext/intel/math.hpp @@ -10,6 +10,7 @@ #pragma once #include +#include #include #include diff --git a/sycl/include/sycl/ext/intel/math/imf_bf16.hpp b/sycl/include/sycl/ext/intel/math/imf_bf16.hpp new file mode 100644 index 0000000000000..92dcbefc11084 --- /dev/null +++ b/sycl/include/sycl/ext/intel/math/imf_bf16.hpp @@ -0,0 +1,205 @@ +//==-------------------- imf_bf16.hpp - bfloat16 utils ---------------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// C++ APIs for bfloat16 util functions. +//===----------------------------------------------------------------------===// + +#pragma once +#include +#include + +using sycl_bfloat16 = sycl::ext::oneapi::experimental::bfloat16; +using _iml_bfloat16_internal = uint16_t; + +extern "C" { +float __imf_bfloat162float(_iml_bfloat16_internal); +_iml_bfloat16_internal __imf_float2bfloat16(float); +_iml_bfloat16_internal __imf_float2bfloat16_rd(float); +_iml_bfloat16_internal __imf_float2bfloat16_rn(float); +_iml_bfloat16_internal __imf_float2bfloat16_ru(float); +_iml_bfloat16_internal __imf_float2bfloat16_rz(float); +}; + +namespace sycl { +__SYCL_INLINE_VER_NAMESPACE(_V1) { +namespace ext { +namespace intel { +namespace math { + +// Need to ensure that sycl bfloat16 defined in bfloat16.hpp is compatible +// with uint16_t in layout. +#if __cplusplus >= 201703L +static_assert(sizeof(sycl_bfloat16) == sizeof(_iml_bfloat16_internal), + "sycl bfloat16 is not compatible with _iml_bfloat16_internal."); + +float bfloat162float(sycl_bfloat16 b) { + return __imf_bfloat162float(__builtin_bit_cast(_iml_bfloat16_internal, b)); +} + +sycl_bfloat16 float2bfloat16(float f) { + return __builtin_bit_cast(sycl_bfloat16, __imf_float2bfloat16(f)); +} + +sycl_bfloat16 float2bfloat16_rd(float f) { + return __builtin_bit_cast(sycl_bfloat16, __imf_float2bfloat16_rd(f)); +} + +sycl_bfloat16 float2bfloat16_rn(float f) { + return __builtin_bit_cast(sycl_bfloat16, __imf_float2bfloat16_rn(f)); +} + +sycl_bfloat16 float2bfloat16_ru(float f) { + return __builtin_bit_cast(sycl_bfloat16, __imf_float2bfloat16_ru(f)); +} + +sycl_bfloat16 float2bfloat16_rz(float f) { + return __builtin_bit_cast(sycl_bfloat16, __imf_float2bfloat16_rz(f)); +} + +bool hisnan(sycl_bfloat b) { return sycl::isnan(bfloat162float(b)); } + +bool hisinf(sycl_bfloat b) { return sycl::isinf(bfloat162float(b)); } + +bool heq(sycl_bfloat16 b1, sycl_bfloat16 b2) { + if (hisnan(b1) || hisnan(b2)) + return false; + return __builtin_bit_cast(_iml_bfloat16_internal, b1) == + __builtin_bit_cast(_iml_bfloat16_internal, b2); +} + +bool hequ(sycl_bfloat16 b1, sycl_bfloat16 b2) { + if (hisnan(b1) || hisnan(b1)) + return true; + return __builtin_bit_cast(_iml_bfloat16_internal, b1) == + __builtin_bit_cast(_iml_bfloat16_internal, b2); +} + +bool hge(sycl_bfloat16 b1, sycl_bfloat16 b2) { + if (hisnan(b1) || hisnan(b2)) + return false; + float bf1 = bfloat162float(b1); + float bf2 = bfloat162float(b2); + return (bf1 >= bf2); +} + +bool hgeu(sycl_bfloat16 b1, sycl_bfloat16 b2) { + if (hisnan(b1) || hisnan(b2)) + return true; + float bf1 = bfloat162float(b1); + float bf2 = bfloat162float(b2); + return (bf1 >= bf2); +} + +bool hgt(sycl_bfloat16 b1, sycl_bfloat16 b2) { + if (hisnan(b1) || hisnan(b2)) + return false; + float bf1 = bfloat162float(b1); + float bf2 = bfloat162float(b2); + return (bf1 > bf2); +} + +bool hgtu(sycl_bfloat16 b1, sycl_bfloat16 b2) { + if (hisnan(b1) || hisnan(b2)) + return true; + float bf1 = bfloat162float(b1); + float bf2 = bfloat162float(b2); + return (bf1 > bf2); +} + +bool hle(sycl_bfloat16 b1, sycl_bfloat16 b2) { + if (hisnan(b1) || hisnan(b2)) + return false; + float bf1 = bfloat162float(b1); + float bf2 = bfloat162float(b2); + return (bf1 <= bf2); +} + +bool hleu(sycl_bfloat16 b1, sycl_bfloat16 b2) { + if (hisnan(b1) || hisnan(b2)) + return true; + float bf1 = bfloat162float(b1); + float bf2 = bfloat162float(b2); + return (bf1 <= bf2); +} + +bool hlt(sycl_bfloat16 b1, sycl_bfloat16 b2) { + if (hisnan(b1) || hisnan(b2)) + return false; + float bf1 = bfloat162float(b1); + float bf2 = bfloat162float(b2); + return (bf1 < bf2); +} + +bool hltu(sycl_bfloat16 b1, sycl_bfloat16 b2) { + if (hisnan(b1) || hisnan(b2)) + return true; + float bf1 = bfloat162float(b1); + float bf2 = bfloat162float(b2); + return (bf1 < bf2); +} + +sycl_bfloat16 hmax(sycl_bfloat16 b1, sycl_bfloat16 b2) { + _iml_bfloat16_internal ibi = 0x7FC0; + if (hisnan(b1) && hisnan(b2)) + return __builtin_bit_cast(sycl_bfloat16, ibi); + if (hisnan(b1)) + return b2; + else if (hisnan(b2)) + return b1; + else { + return (hgt(b1, b2) ? b1 : b2); + } +} + +sycl_bfloat16 hmax_nan(sycl_bfloat16 b1, sycl_bfloat16 b2) { + _iml_bfloat16_internal ibi = 0x7FC0; + if (hisnan(b1) || hisnan(b2)) + return __builtin_bit_cast(sycl_bfloat16, ibi); + else + return (hgt(b1, b2) ? b1 : b2); +} + +sycl_bfloat16 hmin(sycl_bfloat16 b1, sycl_bfloat16 b2) { + _iml_bfloat16_internal ibi = 0x7FC0; + if (hisnan(b1) && hisnan(b2)) + return __builtin_bit_cast(sycl_bfloat16, ibi); + if (hisnan(b1)) + return b2; + else if (hisnan(b2)) + return b1; + else { + return (hlt(b1, b2) ? b1 : b2); + } +} + +sycl_bfloat16 hmin_nan(sycl_bfloat16 b1, sycl_bfloat16 b2) { + _iml_bfloat16_internal ibi = 0x7FC0; + if (hisnan(b1) || hisnan(b2)) + return __builtin_bit_cast(sycl_bfloat16, ibi); + else + return (hlt(b1, b2) ? b1 : b2); +} + +bool hne(sycl_bfloat16 b1, sycl_bfloat16 b2) { + if (hisnan(b1) || hisnan(b2)) + return false; + return __builtin_bit_cast(_iml_bfloat16_internal, b1) != + __builtin_bit_cast(_iml_bfloat16_internal, b2); +} + +bool hneu(sycl_bfloat16 b1, sycl_bfloat16 b2) { + if (hisnan(b1) || hisnan(b2)) + return true; + return __builtin_bit_cast(_iml_bfloat16_internal, b1) != + __builtin_bit_cast(_iml_bfloat16_internal, b2); +} +#endif +} // namespace math +} // namespace intel +} // namespace ext +} // __SYCL_INLINE_VER_NAMESPACE(_V1) +} // namespace sycl From c4c7da08bc54c09c0f857d256728edf56a6dd938 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Thu, 24 Nov 2022 16:09:17 +0800 Subject: [PATCH 02/12] remove unnecessary alias for uint16_t Signed-off-by: jinge90 --- sycl/include/sycl/ext/intel/math/imf_bf16.hpp | 90 +++++++++++-------- 1 file changed, 53 insertions(+), 37 deletions(-) diff --git a/sycl/include/sycl/ext/intel/math/imf_bf16.hpp b/sycl/include/sycl/ext/intel/math/imf_bf16.hpp index 92dcbefc11084..5e2d993a7a56b 100644 --- a/sycl/include/sycl/ext/intel/math/imf_bf16.hpp +++ b/sycl/include/sycl/ext/intel/math/imf_bf16.hpp @@ -13,15 +13,14 @@ #include using sycl_bfloat16 = sycl::ext::oneapi::experimental::bfloat16; -using _iml_bfloat16_internal = uint16_t; extern "C" { -float __imf_bfloat162float(_iml_bfloat16_internal); -_iml_bfloat16_internal __imf_float2bfloat16(float); -_iml_bfloat16_internal __imf_float2bfloat16_rd(float); -_iml_bfloat16_internal __imf_float2bfloat16_rn(float); -_iml_bfloat16_internal __imf_float2bfloat16_ru(float); -_iml_bfloat16_internal __imf_float2bfloat16_rz(float); +float __imf_bfloat162float(uint16_t); +uint16_t __imf_float2bfloat16(float); +uint16_t __imf_float2bfloat16_rd(float); +uint16_t __imf_float2bfloat16_rn(float); +uint16_t __imf_float2bfloat16_ru(float); +uint16_t __imf_float2bfloat16_rz(float); }; namespace sycl { @@ -33,11 +32,11 @@ namespace math { // Need to ensure that sycl bfloat16 defined in bfloat16.hpp is compatible // with uint16_t in layout. #if __cplusplus >= 201703L -static_assert(sizeof(sycl_bfloat16) == sizeof(_iml_bfloat16_internal), - "sycl bfloat16 is not compatible with _iml_bfloat16_internal."); +static_assert(sizeof(sycl_bfloat16) == sizeof(uint16_t), + "sycl bfloat16 is not compatible with uint16_t."); float bfloat162float(sycl_bfloat16 b) { - return __imf_bfloat162float(__builtin_bit_cast(_iml_bfloat16_internal, b)); + return __imf_bfloat162float(__builtin_bit_cast(uint16_t, b)); } sycl_bfloat16 float2bfloat16(float f) { @@ -60,22 +59,36 @@ sycl_bfloat16 float2bfloat16_rz(float f) { return __builtin_bit_cast(sycl_bfloat16, __imf_float2bfloat16_rz(f)); } -bool hisnan(sycl_bfloat b) { return sycl::isnan(bfloat162float(b)); } +bool hisnan(sycl_bfloat16 b) { return sycl::isnan(bfloat162float(b)); } -bool hisinf(sycl_bfloat b) { return sycl::isinf(bfloat162float(b)); } +bool hisinf(sycl_bfloat16 b) { return sycl::isinf(bfloat162float(b)); } bool heq(sycl_bfloat16 b1, sycl_bfloat16 b2) { if (hisnan(b1) || hisnan(b2)) return false; - return __builtin_bit_cast(_iml_bfloat16_internal, b1) == - __builtin_bit_cast(_iml_bfloat16_internal, b2); + return __builtin_bit_cast(uint16_t, b1) == + __builtin_bit_cast(uint16_t, b2); } bool hequ(sycl_bfloat16 b1, sycl_bfloat16 b2) { if (hisnan(b1) || hisnan(b1)) return true; - return __builtin_bit_cast(_iml_bfloat16_internal, b1) == - __builtin_bit_cast(_iml_bfloat16_internal, b2); + return __builtin_bit_cast(uint16_t, b1) == + __builtin_bit_cast(uint16_t, b2); +} + +bool hne(sycl_bfloat16 b1, sycl_bfloat16 b2) { + if (hisnan(b1) || hisnan(b2)) + return false; + return __builtin_bit_cast(uint16_t, b1) != + __builtin_bit_cast(uint16_t, b2); +} + +bool hneu(sycl_bfloat16 b1, sycl_bfloat16 b2) { + if (hisnan(b1) || hisnan(b2)) + return true; + return __builtin_bit_cast(uint16_t, b1) != + __builtin_bit_cast(uint16_t, b2); } bool hge(sycl_bfloat16 b1, sycl_bfloat16 b2) { @@ -143,60 +156,63 @@ bool hltu(sycl_bfloat16 b1, sycl_bfloat16 b2) { } sycl_bfloat16 hmax(sycl_bfloat16 b1, sycl_bfloat16 b2) { - _iml_bfloat16_internal ibi = 0x7FC0; + uint16_t canonical_nan = 0x7FC0; + uint16_t b1a = __builtin_bit_cast(uint16_t, b1); + uint16_t b2a = __builtin_bit_cast(uint16_t, b2); if (hisnan(b1) && hisnan(b2)) - return __builtin_bit_cast(sycl_bfloat16, ibi); + return __builtin_bit_cast(sycl_bfloat16, canonical_nan); if (hisnan(b1)) return b2; else if (hisnan(b2)) return b1; + else if (((b1a | b2a) == 0x8000) && ((b1a & b2a) == 0x0)) + return __builtin_bit_cast(sycl_bfloat16, static_cast(0x0)); else { return (hgt(b1, b2) ? b1 : b2); } } sycl_bfloat16 hmax_nan(sycl_bfloat16 b1, sycl_bfloat16 b2) { - _iml_bfloat16_internal ibi = 0x7FC0; + uint16_t canonical_nan = 0x7FC0; + uint16_t b1a = __builtin_bit_cast(uint16_t, b1); + uint16_t b2a = __builtin_bit_cast(uint16_t, b2); if (hisnan(b1) || hisnan(b2)) - return __builtin_bit_cast(sycl_bfloat16, ibi); + return __builtin_bit_cast(sycl_bfloat16, canonical_nan); + else if (((b1a | b2a) == 0x8000) && ((b1a & b2a) == 0x0)) + return __builtin_bit_cast(sycl_bfloat16, static_cast(0x0)); else return (hgt(b1, b2) ? b1 : b2); } sycl_bfloat16 hmin(sycl_bfloat16 b1, sycl_bfloat16 b2) { - _iml_bfloat16_internal ibi = 0x7FC0; + uint16_t canonical_nan = 0x7FC0; + uint16_t b1a = __builtin_bit_cast(uint16_t, b1); + uint16_t b2a = __builtin_bit_cast(uint16_t, b2); if (hisnan(b1) && hisnan(b2)) - return __builtin_bit_cast(sycl_bfloat16, ibi); + return __builtin_bit_cast(sycl_bfloat16, canonical_nan); if (hisnan(b1)) return b2; else if (hisnan(b2)) return b1; + else if (((b1a | b2a) == 0x8000) && ((b1a & b2a) == 0x0)) + return __builtin_bit_cast(sycl_bfloat16, static_cast(0x8000)); else { return (hlt(b1, b2) ? b1 : b2); } } sycl_bfloat16 hmin_nan(sycl_bfloat16 b1, sycl_bfloat16 b2) { - _iml_bfloat16_internal ibi = 0x7FC0; + uint16_t canonical_nan = 0x7FC0; + uint16_t b1a = __builtin_bit_cast(uint16_t, b1); + uint16_t b2a = __builtin_bit_cast(uint16_t, b2); if (hisnan(b1) || hisnan(b2)) - return __builtin_bit_cast(sycl_bfloat16, ibi); + return __builtin_bit_cast(sycl_bfloat16, canonical_nan); + else if (((b1a | b2a) == 0x8000) && ((b1a & b2a) == 0x0)) + return __builtin_bit_cast(sycl_bfloat16, static_cast(0x8000)); else return (hlt(b1, b2) ? b1 : b2); } -bool hne(sycl_bfloat16 b1, sycl_bfloat16 b2) { - if (hisnan(b1) || hisnan(b2)) - return false; - return __builtin_bit_cast(_iml_bfloat16_internal, b1) != - __builtin_bit_cast(_iml_bfloat16_internal, b2); -} - -bool hneu(sycl_bfloat16 b1, sycl_bfloat16 b2) { - if (hisnan(b1) || hisnan(b2)) - return true; - return __builtin_bit_cast(_iml_bfloat16_internal, b1) != - __builtin_bit_cast(_iml_bfloat16_internal, b2); -} #endif } // namespace math } // namespace intel From 44d0d6f835ffc0de1413626af587279463da7c66 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Tue, 6 Dec 2022 21:34:48 +0800 Subject: [PATCH 03/12] remove sycl_bfloat16 to avoid spoiling sycl namespace --- sycl/include/sycl/ext/intel/math/imf_bf16.hpp | 128 +++++++++++------- 1 file changed, 80 insertions(+), 48 deletions(-) diff --git a/sycl/include/sycl/ext/intel/math/imf_bf16.hpp b/sycl/include/sycl/ext/intel/math/imf_bf16.hpp index 5e2d993a7a56b..0aa25a969c443 100644 --- a/sycl/include/sycl/ext/intel/math/imf_bf16.hpp +++ b/sycl/include/sycl/ext/intel/math/imf_bf16.hpp @@ -12,8 +12,6 @@ #include #include -using sycl_bfloat16 = sycl::ext::oneapi::experimental::bfloat16; - extern "C" { float __imf_bfloat162float(uint16_t); uint16_t __imf_float2bfloat16(float); @@ -32,66 +30,77 @@ namespace math { // Need to ensure that sycl bfloat16 defined in bfloat16.hpp is compatible // with uint16_t in layout. #if __cplusplus >= 201703L -static_assert(sizeof(sycl_bfloat16) == sizeof(uint16_t), +static_assert(sizeof(sycl::ext::oneapi::experimental::bfloat16) == + sizeof(uint16_t), "sycl bfloat16 is not compatible with uint16_t."); -float bfloat162float(sycl_bfloat16 b) { +float bfloat162float(sycl::ext::oneapi::experimental::bfloat16 b) { return __imf_bfloat162float(__builtin_bit_cast(uint16_t, b)); } -sycl_bfloat16 float2bfloat16(float f) { - return __builtin_bit_cast(sycl_bfloat16, __imf_float2bfloat16(f)); +sycl::ext::oneapi::experimental::bfloat16 float2bfloat16(float f) { + return __builtin_bit_cast(sycl::ext::oneapi::experimental::bfloat16, + __imf_float2bfloat16(f)); } -sycl_bfloat16 float2bfloat16_rd(float f) { - return __builtin_bit_cast(sycl_bfloat16, __imf_float2bfloat16_rd(f)); +sycl::ext::oneapi::experimental::bfloat16 float2bfloat16_rd(float f) { + return __builtin_bit_cast(sycl::ext::oneapi::experimental::bfloat16, + __imf_float2bfloat16_rd(f)); } -sycl_bfloat16 float2bfloat16_rn(float f) { - return __builtin_bit_cast(sycl_bfloat16, __imf_float2bfloat16_rn(f)); +sycl::ext::oneapi::experimental::bfloat16 float2bfloat16_rn(float f) { + return __builtin_bit_cast(sycl::ext::oneapi::experimental::bfloat16, + __imf_float2bfloat16_rn(f)); } -sycl_bfloat16 float2bfloat16_ru(float f) { - return __builtin_bit_cast(sycl_bfloat16, __imf_float2bfloat16_ru(f)); +sycl::ext::oneapi::experimental::bfloat16 float2bfloat16_ru(float f) { + return __builtin_bit_cast(sycl::ext::oneapi::experimental::bfloat16, + __imf_float2bfloat16_ru(f)); } -sycl_bfloat16 float2bfloat16_rz(float f) { - return __builtin_bit_cast(sycl_bfloat16, __imf_float2bfloat16_rz(f)); +sycl::ext::oneapi::experimental::bfloat16 float2bfloat16_rz(float f) { + return __builtin_bit_cast(sycl::ext::oneapi::experimental::bfloat16, + __imf_float2bfloat16_rz(f)); } -bool hisnan(sycl_bfloat16 b) { return sycl::isnan(bfloat162float(b)); } +bool hisnan(sycl::ext::oneapi::experimental::bfloat16 b) { + return sycl::isnan(bfloat162float(b)); +} -bool hisinf(sycl_bfloat16 b) { return sycl::isinf(bfloat162float(b)); } +bool hisinf(sycl::ext::oneapi::experimental::bfloat16 b) { + return sycl::isinf(bfloat162float(b)); +} -bool heq(sycl_bfloat16 b1, sycl_bfloat16 b2) { +bool heq(sycl::ext::oneapi::experimental::bfloat16 b1, + sycl::ext::oneapi::experimental::bfloat16 b2) { if (hisnan(b1) || hisnan(b2)) return false; - return __builtin_bit_cast(uint16_t, b1) == - __builtin_bit_cast(uint16_t, b2); + return __builtin_bit_cast(uint16_t, b1) == __builtin_bit_cast(uint16_t, b2); } -bool hequ(sycl_bfloat16 b1, sycl_bfloat16 b2) { +bool hequ(sycl::ext::oneapi::experimental::bfloat16 b1, + sycl::ext::oneapi::experimental::bfloat16 b2) { if (hisnan(b1) || hisnan(b1)) return true; - return __builtin_bit_cast(uint16_t, b1) == - __builtin_bit_cast(uint16_t, b2); + return __builtin_bit_cast(uint16_t, b1) == __builtin_bit_cast(uint16_t, b2); } -bool hne(sycl_bfloat16 b1, sycl_bfloat16 b2) { +bool hne(sycl::ext::oneapi::experimental::bfloat16 b1, + sycl::ext::oneapi::experimental::bfloat16 b2) { if (hisnan(b1) || hisnan(b2)) return false; - return __builtin_bit_cast(uint16_t, b1) != - __builtin_bit_cast(uint16_t, b2); + return __builtin_bit_cast(uint16_t, b1) != __builtin_bit_cast(uint16_t, b2); } -bool hneu(sycl_bfloat16 b1, sycl_bfloat16 b2) { +bool hneu(sycl::ext::oneapi::experimental::bfloat16 b1, + sycl::ext::oneapi::experimental::bfloat16 b2) { if (hisnan(b1) || hisnan(b2)) return true; - return __builtin_bit_cast(uint16_t, b1) != - __builtin_bit_cast(uint16_t, b2); + return __builtin_bit_cast(uint16_t, b1) != __builtin_bit_cast(uint16_t, b2); } -bool hge(sycl_bfloat16 b1, sycl_bfloat16 b2) { +bool hge(sycl::ext::oneapi::experimental::bfloat16 b1, + sycl::ext::oneapi::experimental::bfloat16 b2) { if (hisnan(b1) || hisnan(b2)) return false; float bf1 = bfloat162float(b1); @@ -99,7 +108,8 @@ bool hge(sycl_bfloat16 b1, sycl_bfloat16 b2) { return (bf1 >= bf2); } -bool hgeu(sycl_bfloat16 b1, sycl_bfloat16 b2) { +bool hgeu(sycl::ext::oneapi::experimental::bfloat16 b1, + sycl::ext::oneapi::experimental::bfloat16 b2) { if (hisnan(b1) || hisnan(b2)) return true; float bf1 = bfloat162float(b1); @@ -107,7 +117,8 @@ bool hgeu(sycl_bfloat16 b1, sycl_bfloat16 b2) { return (bf1 >= bf2); } -bool hgt(sycl_bfloat16 b1, sycl_bfloat16 b2) { +bool hgt(sycl::ext::oneapi::experimental::bfloat16 b1, + sycl::ext::oneapi::experimental::bfloat16 b2) { if (hisnan(b1) || hisnan(b2)) return false; float bf1 = bfloat162float(b1); @@ -115,7 +126,8 @@ bool hgt(sycl_bfloat16 b1, sycl_bfloat16 b2) { return (bf1 > bf2); } -bool hgtu(sycl_bfloat16 b1, sycl_bfloat16 b2) { +bool hgtu(sycl::ext::oneapi::experimental::bfloat16 b1, + sycl::ext::oneapi::experimental::bfloat16 b2) { if (hisnan(b1) || hisnan(b2)) return true; float bf1 = bfloat162float(b1); @@ -123,7 +135,8 @@ bool hgtu(sycl_bfloat16 b1, sycl_bfloat16 b2) { return (bf1 > bf2); } -bool hle(sycl_bfloat16 b1, sycl_bfloat16 b2) { +bool hle(sycl::ext::oneapi::experimental::bfloat16 b1, + sycl::ext::oneapi::experimental::bfloat16 b2) { if (hisnan(b1) || hisnan(b2)) return false; float bf1 = bfloat162float(b1); @@ -131,7 +144,8 @@ bool hle(sycl_bfloat16 b1, sycl_bfloat16 b2) { return (bf1 <= bf2); } -bool hleu(sycl_bfloat16 b1, sycl_bfloat16 b2) { +bool hleu(sycl::ext::oneapi::experimental::bfloat16 b1, + sycl::ext::oneapi::experimental::bfloat16 b2) { if (hisnan(b1) || hisnan(b2)) return true; float bf1 = bfloat162float(b1); @@ -139,7 +153,8 @@ bool hleu(sycl_bfloat16 b1, sycl_bfloat16 b2) { return (bf1 <= bf2); } -bool hlt(sycl_bfloat16 b1, sycl_bfloat16 b2) { +bool hlt(sycl::ext::oneapi::experimental::bfloat16 b1, + sycl::ext::oneapi::experimental::bfloat16 b2) { if (hisnan(b1) || hisnan(b2)) return false; float bf1 = bfloat162float(b1); @@ -147,7 +162,8 @@ bool hlt(sycl_bfloat16 b1, sycl_bfloat16 b2) { return (bf1 < bf2); } -bool hltu(sycl_bfloat16 b1, sycl_bfloat16 b2) { +bool hltu(sycl::ext::oneapi::experimental::bfloat16 b1, + sycl::ext::oneapi::experimental::bfloat16 b2) { if (hisnan(b1) || hisnan(b2)) return true; float bf1 = bfloat162float(b1); @@ -155,60 +171,76 @@ bool hltu(sycl_bfloat16 b1, sycl_bfloat16 b2) { return (bf1 < bf2); } -sycl_bfloat16 hmax(sycl_bfloat16 b1, sycl_bfloat16 b2) { +sycl::ext::oneapi::experimental::bfloat16 +hmax(sycl::ext::oneapi::experimental::bfloat16 b1, + sycl::ext::oneapi::experimental::bfloat16 b2) { uint16_t canonical_nan = 0x7FC0; uint16_t b1a = __builtin_bit_cast(uint16_t, b1); uint16_t b2a = __builtin_bit_cast(uint16_t, b2); if (hisnan(b1) && hisnan(b2)) - return __builtin_bit_cast(sycl_bfloat16, canonical_nan); + return __builtin_bit_cast(sycl::ext::oneapi::experimental::bfloat16, + canonical_nan); if (hisnan(b1)) return b2; else if (hisnan(b2)) return b1; else if (((b1a | b2a) == 0x8000) && ((b1a & b2a) == 0x0)) - return __builtin_bit_cast(sycl_bfloat16, static_cast(0x0)); + return __builtin_bit_cast(sycl::ext::oneapi::experimental::bfloat16, + static_cast(0x0)); else { return (hgt(b1, b2) ? b1 : b2); } } -sycl_bfloat16 hmax_nan(sycl_bfloat16 b1, sycl_bfloat16 b2) { +sycl::ext::oneapi::experimental::bfloat16 +hmax_nan(sycl::ext::oneapi::experimental::bfloat16 b1, + sycl::ext::oneapi::experimental::bfloat16 b2) { uint16_t canonical_nan = 0x7FC0; uint16_t b1a = __builtin_bit_cast(uint16_t, b1); uint16_t b2a = __builtin_bit_cast(uint16_t, b2); if (hisnan(b1) || hisnan(b2)) - return __builtin_bit_cast(sycl_bfloat16, canonical_nan); + return __builtin_bit_cast(sycl::ext::oneapi::experimental::bfloat16, + canonical_nan); else if (((b1a | b2a) == 0x8000) && ((b1a & b2a) == 0x0)) - return __builtin_bit_cast(sycl_bfloat16, static_cast(0x0)); + return __builtin_bit_cast(sycl::ext::oneapi::experimental::bfloat16, + static_cast(0x0)); else return (hgt(b1, b2) ? b1 : b2); } -sycl_bfloat16 hmin(sycl_bfloat16 b1, sycl_bfloat16 b2) { +sycl::ext::oneapi::experimental::bfloat16 +hmin(sycl::ext::oneapi::experimental::bfloat16 b1, + sycl::ext::oneapi::experimental::bfloat16 b2) { uint16_t canonical_nan = 0x7FC0; uint16_t b1a = __builtin_bit_cast(uint16_t, b1); uint16_t b2a = __builtin_bit_cast(uint16_t, b2); if (hisnan(b1) && hisnan(b2)) - return __builtin_bit_cast(sycl_bfloat16, canonical_nan); + return __builtin_bit_cast(sycl::ext::oneapi::experimental::bfloat16, + canonical_nan); if (hisnan(b1)) return b2; else if (hisnan(b2)) return b1; else if (((b1a | b2a) == 0x8000) && ((b1a & b2a) == 0x0)) - return __builtin_bit_cast(sycl_bfloat16, static_cast(0x8000)); + return __builtin_bit_cast(sycl::ext::oneapi::experimental::bfloat16, + static_cast(0x8000)); else { return (hlt(b1, b2) ? b1 : b2); } } -sycl_bfloat16 hmin_nan(sycl_bfloat16 b1, sycl_bfloat16 b2) { +sycl::ext::oneapi::experimental::bfloat16 +hmin_nan(sycl::ext::oneapi::experimental::bfloat16 b1, + sycl::ext::oneapi::experimental::bfloat16 b2) { uint16_t canonical_nan = 0x7FC0; uint16_t b1a = __builtin_bit_cast(uint16_t, b1); uint16_t b2a = __builtin_bit_cast(uint16_t, b2); if (hisnan(b1) || hisnan(b2)) - return __builtin_bit_cast(sycl_bfloat16, canonical_nan); + return __builtin_bit_cast(sycl::ext::oneapi::experimental::bfloat16, + canonical_nan); else if (((b1a | b2a) == 0x8000) && ((b1a & b2a) == 0x0)) - return __builtin_bit_cast(sycl_bfloat16, static_cast(0x8000)); + return __builtin_bit_cast(sycl::ext::oneapi::experimental::bfloat16, + static_cast(0x8000)); else return (hlt(b1, b2) ? b1 : b2); } From b514825117deb59f7e6da123bdf4c28142d5d5a0 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Tue, 6 Dec 2022 21:50:23 +0800 Subject: [PATCH 04/12] fix clang format --- sycl/include/sycl/ext/intel/math.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sycl/include/sycl/ext/intel/math.hpp b/sycl/include/sycl/ext/intel/math.hpp index d934e550f55ad..c9f06f8b3b0ec 100644 --- a/sycl/include/sycl/ext/intel/math.hpp +++ b/sycl/include/sycl/ext/intel/math.hpp @@ -9,8 +9,8 @@ //===----------------------------------------------------------------------===// #pragma once -#include #include +#include #include #include From 7e36261a0e4fd3a9d07a4e5b7563d8c28903d460 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Tue, 6 Dec 2022 22:39:07 +0800 Subject: [PATCH 05/12] fix bfloat16 header --- sycl/include/sycl/ext/intel/math/imf_bf16.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sycl/include/sycl/ext/intel/math/imf_bf16.hpp b/sycl/include/sycl/ext/intel/math/imf_bf16.hpp index 0aa25a969c443..4423d65025f72 100644 --- a/sycl/include/sycl/ext/intel/math/imf_bf16.hpp +++ b/sycl/include/sycl/ext/intel/math/imf_bf16.hpp @@ -9,7 +9,7 @@ //===----------------------------------------------------------------------===// #pragma once -#include +#include #include extern "C" { From 4b313ae676a5075eb2bf733734b7d68aace2d5cf Mon Sep 17 00:00:00 2001 From: jinge90 Date: Wed, 7 Dec 2022 00:44:40 +0800 Subject: [PATCH 06/12] fix bfloat16 type namespace --- sycl/include/sycl/ext/intel/math/imf_bf16.hpp | 105 +++++++----------- 1 file changed, 42 insertions(+), 63 deletions(-) diff --git a/sycl/include/sycl/ext/intel/math/imf_bf16.hpp b/sycl/include/sycl/ext/intel/math/imf_bf16.hpp index 4423d65025f72..ed79961219949 100644 --- a/sycl/include/sycl/ext/intel/math/imf_bf16.hpp +++ b/sycl/include/sycl/ext/intel/math/imf_bf16.hpp @@ -30,77 +30,71 @@ namespace math { // Need to ensure that sycl bfloat16 defined in bfloat16.hpp is compatible // with uint16_t in layout. #if __cplusplus >= 201703L -static_assert(sizeof(sycl::ext::oneapi::experimental::bfloat16) == - sizeof(uint16_t), +static_assert(sizeof(sycl::ext::oneapi::bfloat16) == sizeof(uint16_t), "sycl bfloat16 is not compatible with uint16_t."); -float bfloat162float(sycl::ext::oneapi::experimental::bfloat16 b) { +float bfloat162float(sycl::ext::oneapi::bfloat16 b) { return __imf_bfloat162float(__builtin_bit_cast(uint16_t, b)); } -sycl::ext::oneapi::experimental::bfloat16 float2bfloat16(float f) { - return __builtin_bit_cast(sycl::ext::oneapi::experimental::bfloat16, +sycl::ext::oneapi::bfloat16 float2bfloat16(float f) { + return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, __imf_float2bfloat16(f)); } -sycl::ext::oneapi::experimental::bfloat16 float2bfloat16_rd(float f) { - return __builtin_bit_cast(sycl::ext::oneapi::experimental::bfloat16, +sycl::ext::oneapi::bfloat16 float2bfloat16_rd(float f) { + return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, __imf_float2bfloat16_rd(f)); } -sycl::ext::oneapi::experimental::bfloat16 float2bfloat16_rn(float f) { - return __builtin_bit_cast(sycl::ext::oneapi::experimental::bfloat16, +sycl::ext::oneapi::bfloat16 float2bfloat16_rn(float f) { + return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, __imf_float2bfloat16_rn(f)); } -sycl::ext::oneapi::experimental::bfloat16 float2bfloat16_ru(float f) { - return __builtin_bit_cast(sycl::ext::oneapi::experimental::bfloat16, +sycl::ext::oneapi::bfloat16 float2bfloat16_ru(float f) { + return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, __imf_float2bfloat16_ru(f)); } -sycl::ext::oneapi::experimental::bfloat16 float2bfloat16_rz(float f) { - return __builtin_bit_cast(sycl::ext::oneapi::experimental::bfloat16, +sycl::ext::oneapi::bfloat16 float2bfloat16_rz(float f) { + return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, __imf_float2bfloat16_rz(f)); } -bool hisnan(sycl::ext::oneapi::experimental::bfloat16 b) { +bool hisnan(sycl::ext::oneapi::bfloat16 b) { return sycl::isnan(bfloat162float(b)); } -bool hisinf(sycl::ext::oneapi::experimental::bfloat16 b) { +bool hisinf(sycl::ext::oneapi::bfloat16 b) { return sycl::isinf(bfloat162float(b)); } -bool heq(sycl::ext::oneapi::experimental::bfloat16 b1, - sycl::ext::oneapi::experimental::bfloat16 b2) { +bool heq(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { if (hisnan(b1) || hisnan(b2)) return false; return __builtin_bit_cast(uint16_t, b1) == __builtin_bit_cast(uint16_t, b2); } -bool hequ(sycl::ext::oneapi::experimental::bfloat16 b1, - sycl::ext::oneapi::experimental::bfloat16 b2) { +bool hequ(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { if (hisnan(b1) || hisnan(b1)) return true; return __builtin_bit_cast(uint16_t, b1) == __builtin_bit_cast(uint16_t, b2); } -bool hne(sycl::ext::oneapi::experimental::bfloat16 b1, - sycl::ext::oneapi::experimental::bfloat16 b2) { +bool hne(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { if (hisnan(b1) || hisnan(b2)) return false; return __builtin_bit_cast(uint16_t, b1) != __builtin_bit_cast(uint16_t, b2); } -bool hneu(sycl::ext::oneapi::experimental::bfloat16 b1, - sycl::ext::oneapi::experimental::bfloat16 b2) { +bool hneu(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { if (hisnan(b1) || hisnan(b2)) return true; return __builtin_bit_cast(uint16_t, b1) != __builtin_bit_cast(uint16_t, b2); } -bool hge(sycl::ext::oneapi::experimental::bfloat16 b1, - sycl::ext::oneapi::experimental::bfloat16 b2) { +bool hge(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { if (hisnan(b1) || hisnan(b2)) return false; float bf1 = bfloat162float(b1); @@ -108,8 +102,7 @@ bool hge(sycl::ext::oneapi::experimental::bfloat16 b1, return (bf1 >= bf2); } -bool hgeu(sycl::ext::oneapi::experimental::bfloat16 b1, - sycl::ext::oneapi::experimental::bfloat16 b2) { +bool hgeu(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { if (hisnan(b1) || hisnan(b2)) return true; float bf1 = bfloat162float(b1); @@ -117,8 +110,7 @@ bool hgeu(sycl::ext::oneapi::experimental::bfloat16 b1, return (bf1 >= bf2); } -bool hgt(sycl::ext::oneapi::experimental::bfloat16 b1, - sycl::ext::oneapi::experimental::bfloat16 b2) { +bool hgt(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { if (hisnan(b1) || hisnan(b2)) return false; float bf1 = bfloat162float(b1); @@ -126,8 +118,7 @@ bool hgt(sycl::ext::oneapi::experimental::bfloat16 b1, return (bf1 > bf2); } -bool hgtu(sycl::ext::oneapi::experimental::bfloat16 b1, - sycl::ext::oneapi::experimental::bfloat16 b2) { +bool hgtu(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { if (hisnan(b1) || hisnan(b2)) return true; float bf1 = bfloat162float(b1); @@ -135,8 +126,7 @@ bool hgtu(sycl::ext::oneapi::experimental::bfloat16 b1, return (bf1 > bf2); } -bool hle(sycl::ext::oneapi::experimental::bfloat16 b1, - sycl::ext::oneapi::experimental::bfloat16 b2) { +bool hle(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { if (hisnan(b1) || hisnan(b2)) return false; float bf1 = bfloat162float(b1); @@ -144,8 +134,7 @@ bool hle(sycl::ext::oneapi::experimental::bfloat16 b1, return (bf1 <= bf2); } -bool hleu(sycl::ext::oneapi::experimental::bfloat16 b1, - sycl::ext::oneapi::experimental::bfloat16 b2) { +bool hleu(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { if (hisnan(b1) || hisnan(b2)) return true; float bf1 = bfloat162float(b1); @@ -153,8 +142,7 @@ bool hleu(sycl::ext::oneapi::experimental::bfloat16 b1, return (bf1 <= bf2); } -bool hlt(sycl::ext::oneapi::experimental::bfloat16 b1, - sycl::ext::oneapi::experimental::bfloat16 b2) { +bool hlt(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { if (hisnan(b1) || hisnan(b2)) return false; float bf1 = bfloat162float(b1); @@ -162,8 +150,7 @@ bool hlt(sycl::ext::oneapi::experimental::bfloat16 b1, return (bf1 < bf2); } -bool hltu(sycl::ext::oneapi::experimental::bfloat16 b1, - sycl::ext::oneapi::experimental::bfloat16 b2) { +bool hltu(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { if (hisnan(b1) || hisnan(b2)) return true; float bf1 = bfloat162float(b1); @@ -171,75 +158,67 @@ bool hltu(sycl::ext::oneapi::experimental::bfloat16 b1, return (bf1 < bf2); } -sycl::ext::oneapi::experimental::bfloat16 -hmax(sycl::ext::oneapi::experimental::bfloat16 b1, - sycl::ext::oneapi::experimental::bfloat16 b2) { +sycl::ext::oneapi::bfloat16 hmax(sycl::ext::oneapi::bfloat16 b1, + sycl::ext::oneapi::bfloat16 b2) { uint16_t canonical_nan = 0x7FC0; uint16_t b1a = __builtin_bit_cast(uint16_t, b1); uint16_t b2a = __builtin_bit_cast(uint16_t, b2); if (hisnan(b1) && hisnan(b2)) - return __builtin_bit_cast(sycl::ext::oneapi::experimental::bfloat16, - canonical_nan); + return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, canonical_nan); if (hisnan(b1)) return b2; else if (hisnan(b2)) return b1; else if (((b1a | b2a) == 0x8000) && ((b1a & b2a) == 0x0)) - return __builtin_bit_cast(sycl::ext::oneapi::experimental::bfloat16, + return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, static_cast(0x0)); else { return (hgt(b1, b2) ? b1 : b2); } } -sycl::ext::oneapi::experimental::bfloat16 -hmax_nan(sycl::ext::oneapi::experimental::bfloat16 b1, - sycl::ext::oneapi::experimental::bfloat16 b2) { +sycl::ext::oneapi::bfloat16 hmax_nan(sycl::ext::oneapi::bfloat16 b1, + sycl::ext::oneapi::bfloat16 b2) { uint16_t canonical_nan = 0x7FC0; uint16_t b1a = __builtin_bit_cast(uint16_t, b1); uint16_t b2a = __builtin_bit_cast(uint16_t, b2); if (hisnan(b1) || hisnan(b2)) - return __builtin_bit_cast(sycl::ext::oneapi::experimental::bfloat16, - canonical_nan); + return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, canonical_nan); else if (((b1a | b2a) == 0x8000) && ((b1a & b2a) == 0x0)) - return __builtin_bit_cast(sycl::ext::oneapi::experimental::bfloat16, + return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, static_cast(0x0)); else return (hgt(b1, b2) ? b1 : b2); } -sycl::ext::oneapi::experimental::bfloat16 -hmin(sycl::ext::oneapi::experimental::bfloat16 b1, - sycl::ext::oneapi::experimental::bfloat16 b2) { +sycl::ext::oneapi::bfloat16 hmin(sycl::ext::oneapi::bfloat16 b1, + sycl::ext::oneapi::bfloat16 b2) { uint16_t canonical_nan = 0x7FC0; uint16_t b1a = __builtin_bit_cast(uint16_t, b1); uint16_t b2a = __builtin_bit_cast(uint16_t, b2); if (hisnan(b1) && hisnan(b2)) - return __builtin_bit_cast(sycl::ext::oneapi::experimental::bfloat16, - canonical_nan); + return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, canonical_nan); if (hisnan(b1)) return b2; else if (hisnan(b2)) return b1; else if (((b1a | b2a) == 0x8000) && ((b1a & b2a) == 0x0)) - return __builtin_bit_cast(sycl::ext::oneapi::experimental::bfloat16, + return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, static_cast(0x8000)); else { return (hlt(b1, b2) ? b1 : b2); } } -sycl::ext::oneapi::experimental::bfloat16 -hmin_nan(sycl::ext::oneapi::experimental::bfloat16 b1, - sycl::ext::oneapi::experimental::bfloat16 b2) { +sycl::ext::oneapi::bfloat16 hmin_nan(sycl::ext::oneapi::bfloat16 b1, + sycl::ext::oneapi::bfloat16 b2) { uint16_t canonical_nan = 0x7FC0; uint16_t b1a = __builtin_bit_cast(uint16_t, b1); uint16_t b2a = __builtin_bit_cast(uint16_t, b2); if (hisnan(b1) || hisnan(b2)) - return __builtin_bit_cast(sycl::ext::oneapi::experimental::bfloat16, - canonical_nan); + return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, canonical_nan); else if (((b1a | b2a) == 0x8000) && ((b1a & b2a) == 0x0)) - return __builtin_bit_cast(sycl::ext::oneapi::experimental::bfloat16, + return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, static_cast(0x8000)); else return (hlt(b1, b2) ? b1 : b2); From 418ff33b1cf95a89369b8165f81eb4a66fb82dd3 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Thu, 8 Dec 2022 21:53:21 +0800 Subject: [PATCH 07/12] Add bfloat16 arithmetic utils --- sycl/include/sycl/ext/intel/math/imf_bf16.hpp | 95 ++++++++++++++++++- 1 file changed, 92 insertions(+), 3 deletions(-) diff --git a/sycl/include/sycl/ext/intel/math/imf_bf16.hpp b/sycl/include/sycl/ext/intel/math/imf_bf16.hpp index ed79961219949..840ba58fa287b 100644 --- a/sycl/include/sycl/ext/intel/math/imf_bf16.hpp +++ b/sycl/include/sycl/ext/intel/math/imf_bf16.hpp @@ -32,7 +32,7 @@ namespace math { #if __cplusplus >= 201703L static_assert(sizeof(sycl::ext::oneapi::bfloat16) == sizeof(uint16_t), "sycl bfloat16 is not compatible with uint16_t."); - +// Bfloat16 type cast utils float bfloat162float(sycl::ext::oneapi::bfloat16 b) { return __imf_bfloat162float(__builtin_bit_cast(uint16_t, b)); } @@ -62,12 +62,17 @@ sycl::ext::oneapi::bfloat16 float2bfloat16_rz(float f) { __imf_float2bfloat16_rz(f)); } +// Bfloat16 comparison utils bool hisnan(sycl::ext::oneapi::bfloat16 b) { - return sycl::isnan(bfloat162float(b)); + uint16_t bf16_bits = __builtin_bit_cast(uint16_t, b); + return (((bf16_bits & 0x7F80) == 0x7F80) && (bf16_bits & 0x7F)) ? true + : false; } bool hisinf(sycl::ext::oneapi::bfloat16 b) { - return sycl::isinf(bfloat162float(b)); + uint16_t bf16_bits = __builtin_bit_cast(uint16_t, b); + return (((bf16_bits & 0x7F80) == 0x7F80) && !(bf16_bits & 0x7F)) ? true + : false; } bool heq(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { @@ -224,6 +229,90 @@ sycl::ext::oneapi::bfloat16 hmin_nan(sycl::ext::oneapi::bfloat16 b1, return (hlt(b1, b2) ? b1 : b2); } +// Bfloat16 Arithmetic utils +sycl::ext::oneapi::bfloat16 hneg(sycl::ext::oneapi::bfloat16 b) { + uint16_t bf16_bits = __builtin_bit_cast(uint16_t, b); + return hisnan(b) ? b + : (__builtin_bit_cast(sycl::ext::oneapi::bfloat16, + (bf16_bits ^ 0x8000))); +} + +sycl::ext::oneapi::bfloat16 habs(sycl::ext::oneapi::bfloat16 b) { + uint16_t bf16_bits = __builtin_bit_cast(uint16_t, b); + return (hisnan(b) || !(bf_bits & 0x8000)) ? b : hneg(b); +} + +sycl::ext::oneapi::bfloat16 hadd(sycl::ext::oneapi::bfloat16 b1, + sycl::ext::oneapi::bfloat16 b2) { + return b1 + b2; +} + +sycl::ext::oneapi::bfloat16 hadd_sat(sycl::ext::oneapi::bfloat16 b1, + sycl::ext::oneapi::bfloat16 b2) { + float f1 = bfloat162float(b1); + float f2 = bfloat162float(b2); + return float2bfloat16(sycl::clamp((f1 + f2), 0.f, 1.0f)); +} + +sycl::ext::oneapi::bfloat16 hsub(sycl::ext::oneapi::bfloat16 b1, + sycl::ext::oneapi::bfloat16 b2) { + return b1 - b2; +} + +sycl::ext::oneapi::bfloat16 hsub_sat(sycl::ext::oneapi::bfloat16 b1, + sycl::ext::oneapi::bfloat16 b2) { + float f1 = bfloat162float(b1); + float f2 = bfloat162float(b2); + return float2bfloat16(sycl::clamp((f1 - f2), 0.f, 1.0f)); +} + +sycl::ext::oneapi::bfloat16 hmul(sycl::ext::oneapi::bfloat16 b1, + sycl::ext::oneapi::bfloat16 b2) { + return b1 * b2; +} + +sycl::ext::oneapi::bfloat16 hadd_sat(sycl::ext::oneapi::bfloat16 b1, + sycl::ext::oneapi::bfloat16 b2) { + float f1 = bfloat162float(b1); + float f2 = bfloat162float(b2); + return float2bfloat16(sycl::clamp((f1 * f2), 0.f, 1.0f)); +} + +sycl::ext::oneapi::bfloat16 hdiv(sycl::ext::oneapi::bfloat16 b1, + sycl::ext::oneapi::bfloat16 b2) { + return b1 / b2; +} + +sycl::ext::oneapi::bfloat16 hfma(sycl::ext::oneapi::bfloat16 b1, + sycl::ext::oneapi::bfloat16 b2, + sycl::ext::oneapi::bfloat16 b3) { + float f1 = bfloat162float(b1); + float f2 = bfloat162float(b2); + float f3 = bfloat162float(b3); + return float2bfloat16(sycl::fma(f1, f2, f3)); +} + +sycl::ext::oneapi::bfloat16 hfma_sat(sycl::ext::oneapi::bfloat16 b1, + sycl::ext::oneapi::bfloat16 b2, + sycl::ext::oneapi::bfloat16 b3) { + float f1 = bfloat162float(b1); + float f2 = bfloat162float(b2); + float f3 = bfloat162float(b3); + return float2bfloat16(sycl::clamp(sycl::fma(f1, f2, f3), 0.f, 1.0f)); +} + +sycl::ext::oneapi::bfloat16 hfma_relu(sycl::ext::oneapi::bfloat16 b1, + sycl::ext::oneapi::bfloat16 b2, + sycl::ext::oneapi::bfloat16 b3) { + float f1 = bfloat162float(b1); + float f2 = bfloat162float(b2); + float f3 = bfloat162float(b3); + float f4 = sycl::fma(f1, f2, f3); + if (sycl::isnan(f4)) + return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, + static_cast(0x7FC0)); + return (f4 < 0.f) ? float2bfloat16(0.f) : float2bfloat16(f4); +} #endif } // namespace math } // namespace intel From 4032477a6bfcc1caa1d77034f382eb4d42ba59dd Mon Sep 17 00:00:00 2001 From: jinge90 Date: Fri, 9 Dec 2022 13:19:27 +0800 Subject: [PATCH 08/12] Add arithemtic utils for bfloat16 --- sycl/include/sycl/ext/intel/math/imf_bf16.hpp | 142 +++++++++++++++--- 1 file changed, 120 insertions(+), 22 deletions(-) diff --git a/sycl/include/sycl/ext/intel/math/imf_bf16.hpp b/sycl/include/sycl/ext/intel/math/imf_bf16.hpp index 840ba58fa287b..2527a5a63d7ce 100644 --- a/sycl/include/sycl/ext/intel/math/imf_bf16.hpp +++ b/sycl/include/sycl/ext/intel/math/imf_bf16.hpp @@ -174,7 +174,8 @@ sycl::ext::oneapi::bfloat16 hmax(sycl::ext::oneapi::bfloat16 b1, return b2; else if (hisnan(b2)) return b1; - else if (((b1a | b2a) == 0x8000) && ((b1a & b2a) == 0x0)) + else if (((b1a | b2a) == static_cast(0x8000)) && + ((b1a & b2a) == 0x0)) return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, static_cast(0x0)); else { @@ -189,7 +190,8 @@ sycl::ext::oneapi::bfloat16 hmax_nan(sycl::ext::oneapi::bfloat16 b1, uint16_t b2a = __builtin_bit_cast(uint16_t, b2); if (hisnan(b1) || hisnan(b2)) return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, canonical_nan); - else if (((b1a | b2a) == 0x8000) && ((b1a & b2a) == 0x0)) + else if (((b1a | b2a) == static_cast(0x8000)) && + ((b1a & b2a) == 0x0)) return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, static_cast(0x0)); else @@ -207,7 +209,8 @@ sycl::ext::oneapi::bfloat16 hmin(sycl::ext::oneapi::bfloat16 b1, return b2; else if (hisnan(b2)) return b1; - else if (((b1a | b2a) == 0x8000) && ((b1a & b2a) == 0x0)) + else if (((b1a | b2a) == static_cast(0x8000)) && + ((b1a & b2a) == 0x0)) return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, static_cast(0x8000)); else { @@ -222,7 +225,8 @@ sycl::ext::oneapi::bfloat16 hmin_nan(sycl::ext::oneapi::bfloat16 b1, uint16_t b2a = __builtin_bit_cast(uint16_t, b2); if (hisnan(b1) || hisnan(b2)) return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, canonical_nan); - else if (((b1a | b2a) == 0x8000) && ((b1a & b2a) == 0x0)) + else if (((b1a | b2a) == static_cast(0x8000)) && + ((b1a & b2a) == 0x0)) return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, static_cast(0x8000)); else @@ -232,14 +236,16 @@ sycl::ext::oneapi::bfloat16 hmin_nan(sycl::ext::oneapi::bfloat16 b1, // Bfloat16 Arithmetic utils sycl::ext::oneapi::bfloat16 hneg(sycl::ext::oneapi::bfloat16 b) { uint16_t bf16_bits = __builtin_bit_cast(uint16_t, b); - return hisnan(b) ? b - : (__builtin_bit_cast(sycl::ext::oneapi::bfloat16, - (bf16_bits ^ 0x8000))); + uint16_t bf16_bits_n = bf16_bits ^ static_cast(0x8000); + return hisnan(b) + ? b + : (__builtin_bit_cast(sycl::ext::oneapi::bfloat16, bf16_bits_n)); } sycl::ext::oneapi::bfloat16 habs(sycl::ext::oneapi::bfloat16 b) { uint16_t bf16_bits = __builtin_bit_cast(uint16_t, b); - return (hisnan(b) || !(bf_bits & 0x8000)) ? b : hneg(b); + return (hisnan(b) || !(bf16_bits & static_cast(0x8000))) ? b + : hneg(b); } sycl::ext::oneapi::bfloat16 hadd(sycl::ext::oneapi::bfloat16 b1, @@ -249,9 +255,10 @@ sycl::ext::oneapi::bfloat16 hadd(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 hadd_sat(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { - float f1 = bfloat162float(b1); - float f2 = bfloat162float(b2); - return float2bfloat16(sycl::clamp((f1 + f2), 0.f, 1.0f)); + float f = bfloat162float(b1) + bfloat162float(b2); + return sycl::isnan(f) ? __builtin_bit_cast(sycl::ext::oneapi::bfloat16, + static_cast(0x0)) + : float2bfloat16(sycl::clamp(f, 0.f, 1.0f)); } sycl::ext::oneapi::bfloat16 hsub(sycl::ext::oneapi::bfloat16 b1, @@ -261,9 +268,10 @@ sycl::ext::oneapi::bfloat16 hsub(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 hsub_sat(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { - float f1 = bfloat162float(b1); - float f2 = bfloat162float(b2); - return float2bfloat16(sycl::clamp((f1 - f2), 0.f, 1.0f)); + float f = bfloat162float(b1) - bfloat162float(b2); + return sycl::isnan(f) ? __builtin_bit_cast(sycl::ext::oneapi::bfloat16, + static_cast(0x0)) + : float2bfloat16(sycl::clamp(f, 0.f, 1.0f)); } sycl::ext::oneapi::bfloat16 hmul(sycl::ext::oneapi::bfloat16 b1, @@ -271,11 +279,12 @@ sycl::ext::oneapi::bfloat16 hmul(sycl::ext::oneapi::bfloat16 b1, return b1 * b2; } -sycl::ext::oneapi::bfloat16 hadd_sat(sycl::ext::oneapi::bfloat16 b1, +sycl::ext::oneapi::bfloat16 hmul_sat(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2) { - float f1 = bfloat162float(b1); - float f2 = bfloat162float(b2); - return float2bfloat16(sycl::clamp((f1 * f2), 0.f, 1.0f)); + float f = bfloat162float(b1) * bfloat162float(b2); + return sycl::isnan(f) ? __builtin_bit_cast(sycl::ext::oneapi::bfloat16, + static_cast(0x0)) + : float2bfloat16(sycl::clamp(f, 0.f, 1.0f)); } sycl::ext::oneapi::bfloat16 hdiv(sycl::ext::oneapi::bfloat16 b1, @@ -295,10 +304,11 @@ sycl::ext::oneapi::bfloat16 hfma(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 hfma_sat(sycl::ext::oneapi::bfloat16 b1, sycl::ext::oneapi::bfloat16 b2, sycl::ext::oneapi::bfloat16 b3) { - float f1 = bfloat162float(b1); - float f2 = bfloat162float(b2); - float f3 = bfloat162float(b3); - return float2bfloat16(sycl::clamp(sycl::fma(f1, f2, f3), 0.f, 1.0f)); + float f = + sycl::fma(bfloat162float(b1), bfloat162float(b2), bfloat162float(b3)); + return sycl::isnan(f) ? __builtin_bit_cast(sycl::ext::oneapi::bfloat16, + static_cast(0)) + : float2bfloat16(sycl::clamp(f, 0.f, 1.0f)); } sycl::ext::oneapi::bfloat16 hfma_relu(sycl::ext::oneapi::bfloat16 b1, @@ -313,6 +323,94 @@ sycl::ext::oneapi::bfloat16 hfma_relu(sycl::ext::oneapi::bfloat16 b1, static_cast(0x7FC0)); return (f4 < 0.f) ? float2bfloat16(0.f) : float2bfloat16(f4); } + +sycl::marray +habs2(sycl::marray b) { + sycl::marray res{habs(b[0]), habs(b[1])}; + return res; +} + +sycl::marray +hadd2(sycl::marray b1, + sycl::marray b2) { + return b1 + b2; +} + +sycl::marray +hadd2_sat(sycl::marray b1, + sycl::marray b2) { + sycl::marray res{hadd_sat(b1[0], b2[0]), + hadd_sat(b1[1], b2[1])}; + return res; +} + +sycl::marray +hsub2(sycl::marray b1, + sycl::marray b2) { + return b1 - b2; +} + +sycl::marray +hsub2_sat(sycl::marray b1, + sycl::marray b2) { + sycl::marray res{hsub_sat(b1[0], b2[0]), + hsub_sat(b1[1], b2[1])}; + return res; +} + +sycl::marray +hmul2(sycl::marray b1, + sycl::marray b2) { + return b1 * b2; +} + +sycl::marray +hmul2_sat(sycl::marray b1, + sycl::marray b2) { + sycl::marray res{hmul_sat(b1[0], b2[0]), + hmul_sat(b1[1], b2[1])}; + return res; +} + +sycl::marray +hdiv2(sycl::marray b1, + sycl::marray b2) { + return b1 / b2; +} + +sycl::marray +hneg2(sycl::marray b) { + sycl::marray res{hneg(b[0]), hneg(b[1])}; + return res; +} + +sycl::marray +hfma2(sycl::marray b1, + sycl::marray b2, + sycl::marray b3) { + sycl::marray res{hfma(b1[0], b2[0], b3[0]), + hfma(b1[1], b2[1], b3[1])}; + return res; +} + +sycl::marray +hfma2_sat(sycl::marray b1, + sycl::marray b2, + sycl::marray b3) { + sycl::marray res{ + hfma_sat(b1[0], b2[0], b3[0]), hfma_sat(b1[1], b2[1], b3[1])}; + return res; +} + +sycl::marray +hfma2_relu(sycl::marray b1, + sycl::marray b2, + sycl::marray b3) { + sycl::marray res{ + hfma_relu(b1[0], b2[0], b3[0]), hfma_relu(b1[1], b2[1], b3[1])}; + return res; +} + #endif } // namespace math } // namespace intel From 94b47e4e47f1e3c57efd18dc46b5c8100855dd50 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Fri, 9 Dec 2022 16:50:22 +0800 Subject: [PATCH 09/12] Add comparison utils for sycl::marray, 1st part --- sycl/include/sycl/ext/intel/math/imf_bf16.hpp | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/sycl/include/sycl/ext/intel/math/imf_bf16.hpp b/sycl/include/sycl/ext/intel/math/imf_bf16.hpp index 2527a5a63d7ce..666da9195610b 100644 --- a/sycl/include/sycl/ext/intel/math/imf_bf16.hpp +++ b/sycl/include/sycl/ext/intel/math/imf_bf16.hpp @@ -233,6 +233,66 @@ sycl::ext::oneapi::bfloat16 hmin_nan(sycl::ext::oneapi::bfloat16 b1, return (hlt(b1, b2) ? b1 : b2); } +bool hbeq2(sycl::marray b1, + sycl::marray b2) { + return heq(b1[0], b2[0]) && heq(b1[1], b2[1]); +} + +bool hbequ2(sycl::marray b1, + sycl::marray b2) { + return hequ(b1[0], b2[0]) && hequ(b1[1], b2[1]); +} + +bool hbge2(sycl::marray b1, + sycl::marray b2) { + return hge(b1[0], b2[0]) && hge(b1[1], b2[1]); +} + +bool hbgeu2(sycl::marray b1, + sycl::marray b2) { + return hgeu(b1[0], b2[0]) && hgeu(b1[1], b2[1]); +} + +bool hbgt2(sycl::marray b1, + sycl::marray b2) { + return hgt(b1[0], b2[0]) && hgt(b1[1], b2[1]); +} + +bool hbgtu2(sycl::marray b1, + sycl::marray b2) { + return hgtu(b1[0], b2[0]) && hgtu(b1[1], b2[1]); +} + +bool hble2(sycl::marray b1, + sycl::marray b2) { + return hle(b1[0], b2[0]) && hle(b1[1], b2[1]); +} + +bool hbleu2(sycl::marray b1, + sycl::marray b2) { + return hleu(b1[0], b2[0]) && hleu(b1[1], b2[1]); +} + +bool hblt2(sycl::marray b1, + sycl::marray b2) { + return hlt(b1[0], b2[0]) && hlt(b1[1], b2[1]); +} + +bool hbltu2(sycl::marray b1, + sycl::marray b2) { + return hltu(b1[0], b2[0]) && hltu(b1[1], b2[1]); +} + +bool hbne2(sycl::marray b1, + sycl::marray b2) { + return hne(b1[0], b2[0]) && hne(b1[1], b2[1]); +} + +bool hbneu2(sycl::marray b1, + sycl::marray b2) { + return hneu(b1[0], b2[0]) && hneu(b1[1], b2[1]); +} + // Bfloat16 Arithmetic utils sycl::ext::oneapi::bfloat16 hneg(sycl::ext::oneapi::bfloat16 b) { uint16_t bf16_bits = __builtin_bit_cast(uint16_t, b); From 2afd5ed27d4c29bd0266c90dc40ebfd7795bba69 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Wed, 14 Dec 2022 11:10:15 +0800 Subject: [PATCH 10/12] Add bfloat162 comparison functions --- sycl/include/sycl/ext/intel/math/imf_bf16.hpp | 273 ++++++++++++++++++ 1 file changed, 273 insertions(+) diff --git a/sycl/include/sycl/ext/intel/math/imf_bf16.hpp b/sycl/include/sycl/ext/intel/math/imf_bf16.hpp index 666da9195610b..44b59efec08ca 100644 --- a/sycl/include/sycl/ext/intel/math/imf_bf16.hpp +++ b/sycl/include/sycl/ext/intel/math/imf_bf16.hpp @@ -293,6 +293,279 @@ bool hbneu2(sycl::marray b1, return hneu(b1[0], b2[0]) && hneu(b1[1], b2[1]); } +sycl::marray +heq2(sycl::marray b1, + sycl::marray b2) { + sycl::marray res; + res[0] = heq(b1[0], b2[0]) ? 1.0f : 0.f; + res[1] = heq(b1[1], b2[0]) ? 1.0f : 0.f; + return res; +} + +unsigned heq2_mask(sycl::marray b1, + sycl::marray b2) { + unsigned res = 0; + if (heq(b1[0], b2[0])) + res |= 0xFFFF; + if (heq(b1[1], b2[1])) + res |= 0xFFFF0000; + return res; +} + +sycl::marray +hequ2(sycl::marray b1, + sycl::marray b2) { + sycl::marray res; + res[0] = hequ(b1[0], b2[0]) ? 1.0f : 0.f; + res[1] = hequ(b1[1], b2[0]) ? 1.0f : 0.f; + return res; +} + +unsigned hequ2_mask(sycl::marray b1, + sycl::marray b2) { + unsigned res = 0; + if (hequ(b1[0], b2[0])) + res |= 0xFFFF; + if (hequ(b1[1], b2[1])) + res |= 0xFFFF0000; + return res; +} + +sycl::marray +hne2(sycl::marray b1, + sycl::marray b2) { + sycl::marray res; + res[0] = hne(b1[0], b2[0]) ? 1.0f : 0.f; + res[1] = hne(b1[1], b2[0]) ? 1.0f : 0.f; + return res; +} + +unsigned hne2_mask(sycl::marray b1, + sycl::marray b2) { + unsigned res = 0; + if (hne2(b1[0], b2[0])) + res |= 0xFFFF; + if (hne2(b1[1], b2[1])) + res |= 0xFFFF0000; + return res; +} + +sycl::marray +hneu2(sycl::marray b1, + sycl::marray b2) { + sycl::marray res; + res[0] = hneu(b1[0], b2[0]) ? 1.0f : 0.f; + res[1] = hneu(b1[1], b2[0]) ? 1.0f : 0.f; + return res; +} + +unsigned hneu2_mask(sycl::marray b1, + sycl::marray b2) { + unsigned res = 0; + if (hneu(b1[0], b2[0])) + res |= 0xFFFF; + if (hneu(b1[1], b2[1])) + res |= 0xFFFF0000; + return res; +} + +sycl::marray +hge2(sycl::marray b1, + sycl::marray b2) { + sycl::marray res; + res[0] = hge(b1[0], b2[0]) ? 1.0f : 0.f; + res[1] = hge(b1[1], b2[0]) ? 1.0f : 0.f; + return res; +} + +unsigned hge2_mask(sycl::marray b1, + sycl::marray b2) { + unsigned res = 0; + if (hge(b1[0], b2[0])) + res |= 0xFFFF; + if (hge(b1[1], b2[1])) + res |= 0xFFFF0000; + return res; +} + +sycl::marray +hgeu2(sycl::marray b1, + sycl::marray b2) { + sycl::marray res; + res[0] = hgeu(b1[0], b2[0]) ? 1.0f : 0.f; + res[1] = hgeu(b1[1], b2[0]) ? 1.0f : 0.f; + return res; +} + +unsigned hgeu2_mask(sycl::marray b1, + sycl::marray b2) { + unsigned res = 0; + if (hgeu(b1[0], b2[0])) + res |= 0xFFFF; + if (hgeu(b1[1], b2[1])) + res |= 0xFFFF0000; + return res; +} + +sycl::marray +hgt2(sycl::marray b1, + sycl::marray b2) { + sycl::marray res; + res[0] = hgt(b1[0], b2[0]) ? 1.0f : 0.f; + res[1] = hgt(b1[1], b2[0]) ? 1.0f : 0.f; + return res; +} + +unsigned hgt2_mask(sycl::marray b1, + sycl::marray b2) { + unsigned res = 0; + if (hgt(b1[0], b2[0])) + res |= 0xFFFF; + if (hgt(b1[1], b2[1])) + res |= 0xFFFF0000; + return res; +} + +sycl::marray +hgtu2(sycl::marray b1, + sycl::marray b2) { + sycl::marray res; + res[0] = hgtu(b1[0], b2[0]) ? 1.0f : 0.f; + res[1] = hgtu(b1[1], b2[0]) ? 1.0f : 0.f; + return res; +} + +unsigned hgtu2_mask(sycl::marray b1, + sycl::marray b2) { + unsigned res = 0; + if (hgtu(b1[0], b2[0])) + res |= 0xFFFF; + if (hgtu(b1[1], b2[1])) + res |= 0xFFFF0000; + return res; +} + +sycl::marray +hisnan2(sycl::marray b1, + sycl::marray b2) { + sycl::marray res; + res[0] = hisnan(b1[0], b2[0]) ? 1.0f : 0.f; + res[1] = hisnan(b1[1], b2[0]) ? 1.0f : 0.f; + return res; +} + +sycl::marray +hle2(sycl::marray b1, + sycl::marray b2) { + sycl::marray res; + res[0] = hle(b1[0], b2[0]) ? 1.0f : 0.f; + res[1] = hle(b1[1], b2[0]) ? 1.0f : 0.f; + return res; +} + +unsigned hle2_mask(sycl::marray b1, + sycl::marray b2) { + unsigned res = 0; + if (hle(b1[0], b2[0])) + res |= 0xFFFF; + if (hle(b1[1], b2[1])) + res |= 0xFFFF0000; + return res; +} + +sycl::marray +hleu2(sycl::marray b1, + sycl::marray b2) { + sycl::marray res; + res[0] = hleu(b1[0], b2[0]) ? 1.0f : 0.f; + res[1] = hleu(b1[1], b2[0]) ? 1.0f : 0.f; + return res; +} + +unsigned hleu2_mask(sycl::marray b1, + sycl::marray b2) { + unsigned res = 0; + if (hleu(b1[0], b2[0])) + res |= 0xFFFF; + if (hleu(b1[1], b2[1])) + res |= 0xFFFF0000; + return res; +} + +sycl::marray +hlt2(sycl::marray b1, + sycl::marray b2) { + sycl::marray res; + res[0] = hlt(b1[0], b2[0]) ? 1.0f : 0.f; + res[1] = hlt(b1[1], b2[0]) ? 1.0f : 0.f; + return res; +} + +unsigned hlt2_mask(sycl::marray b1, + sycl::marray b2) { + unsigned res = 0; + if (hlt(b1[0], b2[0])) + res |= 0xFFFF; + if (hlt(b1[1], b2[1])) + res |= 0xFFFF0000; + return res; +} + +sycl::marray +hltu2(sycl::marray b1, + sycl::marray b2) { + sycl::marray res; + res[0] = hltu(b1[0], b2[0]) ? 1.0f : 0.f; + res[1] = hltu(b1[1], b2[0]) ? 1.0f : 0.f; + return res; +} + +unsigned hltu2_mask(sycl::marray b1, + sycl::marray b2) { + unsigned res = 0; + if (hltu(b1[0], b2[0])) + res |= 0xFFFF; + if (hltu(b1[1], b2[1])) + res |= 0xFFFF0000; + return res; +} + +sycl::marray +hmax2(sycl::marray b1, + sycl::marray b2) { + sycl::marray res; + res[0] = hmax(b1[0], b2[0]); + res[1] = hmax(b1[1], b2[0]); + return res; +} + +sycl::marray +hmax2_nan(sycl::marray b1, + sycl::marray b2) { + sycl::marray res; + res[0] = hmax_nan(b1[0], b2[0]); + res[1] = hmax_nan(b1[1], b2[0]); + return res; +} + +sycl::marray +hmin2(sycl::marray b1, + sycl::marray b2) { + sycl::marray res; + res[0] = hmin(b1[0], b2[0]); + res[1] = hmin(b1[1], b2[0]); + return res; +} + +sycl::marray +hmin2_nan(sycl::marray b1, + sycl::marray b2) { + sycl::marray res; + res[0] = hmin_nan(b1[0], b2[0]); + res[1] = hmin_nan(b1[1], b2[0]); + return res; +} + // Bfloat16 Arithmetic utils sycl::ext::oneapi::bfloat16 hneg(sycl::ext::oneapi::bfloat16 b) { uint16_t bf16_bits = __builtin_bit_cast(uint16_t, b); From 46a26222deea05b2eb0d30bf4502684c5dc7dcb6 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Wed, 14 Dec 2022 15:51:19 +0800 Subject: [PATCH 11/12] add conversion utils between float2 and bfloat162 --- sycl/include/sycl/ext/intel/math/imf_bf16.hpp | 52 ++++++++++++++++--- 1 file changed, 46 insertions(+), 6 deletions(-) diff --git a/sycl/include/sycl/ext/intel/math/imf_bf16.hpp b/sycl/include/sycl/ext/intel/math/imf_bf16.hpp index 44b59efec08ca..6e903a39ace89 100644 --- a/sycl/include/sycl/ext/intel/math/imf_bf16.hpp +++ b/sycl/include/sycl/ext/intel/math/imf_bf16.hpp @@ -62,6 +62,47 @@ sycl::ext::oneapi::bfloat16 float2bfloat16_rz(float f) { __imf_float2bfloat16_rz(f)); } +sycl::float2 bfloat1622float2(sycl::marray b) { + return sycl::float2{bfloat162float(b[0]), bfloat162float(b[1])}; +} + +sycl::marray +bfloat162bfloat162(sycl::ext::oneapi::bfloat16 b) { + sycl::marray res; + res[0] = res[1] = b; + return res; +} + +sycl::marray +halves2bfloat162(sycl::ext::oneapi::bfloat16 a, sycl::ext::oneapi::bfloat16 b) { + sycl::marray res; + res[0] = a; + res[1] = b; + return res; +} + +sycl::marray +float22bfloat162_rn(sycl::float2 f) { + sycl::marray res; + res[0] = float2bfloat16_rn(f.s0()); + res[1] = float2bfloat16_rn(f.s1()); + return res; +} + +sycl::marray float2bfloat162_rn(float f) { + sycl::marray res; + res[0] = res[1] = float2bfloat16_rn(f); + return res; +} + +sycl::marray floats2bfloat162_rn(float a, + float b) { + sycl::marray res; + res[0] = float2bfloat16_rn(a); + res[1] = float2bfloat16_rn(b); + return res; +} + // Bfloat16 comparison utils bool hisnan(sycl::ext::oneapi::bfloat16 b) { uint16_t bf16_bits = __builtin_bit_cast(uint16_t, b); @@ -343,9 +384,9 @@ hne2(sycl::marray b1, unsigned hne2_mask(sycl::marray b1, sycl::marray b2) { unsigned res = 0; - if (hne2(b1[0], b2[0])) + if (hne(b1[0], b2[0])) res |= 0xFFFF; - if (hne2(b1[1], b2[1])) + if (hne(b1[1], b2[1])) res |= 0xFFFF0000; return res; } @@ -446,11 +487,10 @@ unsigned hgtu2_mask(sycl::marray b1, } sycl::marray -hisnan2(sycl::marray b1, - sycl::marray b2) { +hisnan2(sycl::marray b) { sycl::marray res; - res[0] = hisnan(b1[0], b2[0]) ? 1.0f : 0.f; - res[1] = hisnan(b1[1], b2[0]) ? 1.0f : 0.f; + res[0] = hisnan(b[0]) ? 1.0f : 0.f; + res[1] = hisnan(b[1]) ? 1.0f : 0.f; return res; } From c3d9bb371f74020bedc454ffb7c3430a9dbbf9a4 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Mon, 19 Dec 2022 11:40:33 +0800 Subject: [PATCH 12/12] add hceil, hfloor, htrunc, hsqrt, hrsqrt, hrint --- sycl/include/sycl/ext/intel/math/imf_bf16.hpp | 85 +++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/sycl/include/sycl/ext/intel/math/imf_bf16.hpp b/sycl/include/sycl/ext/intel/math/imf_bf16.hpp index 6e903a39ace89..5ab14ff5285a6 100644 --- a/sycl/include/sycl/ext/intel/math/imf_bf16.hpp +++ b/sycl/include/sycl/ext/intel/math/imf_bf16.hpp @@ -19,6 +19,12 @@ uint16_t __imf_float2bfloat16_rd(float); uint16_t __imf_float2bfloat16_rn(float); uint16_t __imf_float2bfloat16_ru(float); uint16_t __imf_float2bfloat16_rz(float); +uint16_t __imf_ceilbf16(uint16_t); +uint16_t __imf_floorbf16(uint16_t); +uint16_t __imf_truncbf16(uint16_t); +uint16_t __imf_rintbf16(uint16_t); +uint16_t __imf_sqrtbf16(uint16_t); +uint16_t __imf_rsqrtbf16(uint16_t); }; namespace sycl { @@ -784,6 +790,85 @@ hfma2_relu(sycl::marray b1, return res; } +// Bfloat16 math utils +sycl::ext::oneapi::bfloat16 hceil(sycl::ext::oneapi::bfloat16 b) { + return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, + __imf_ceilbf16(__builtin_bit_cast(uint16_t, b))); +} + +sycl::ext::oneapi::bfloat16 hfloor(sycl::ext::oneapi::bfloat16 b) { + return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, + __imf_floorbf16(__builtin_bit_cast(uint16_t, b))); +} + +sycl::ext::oneapi::bfloat16 htrunc(sycl::ext::oneapi::bfloat16 b) { + return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, + __imf_truncbf16(__builtin_bit_cast(uint16_t, b))); +} + +sycl::ext::oneapi::bfloat16 hrint(sycl::ext::oneapi::bfloat16 b) { + return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, + __imf_rintbf16(__builtin_bit_cast(uint16_t, b))); +} + +sycl::ext::oneapi::bfloat16 hsqrt(sycl::ext::oneapi::bfloat16 b) { + return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, + __imf_sqrtbf16(__builtin_bit_cast(uint16_t, b))); +} + +sycl::ext::oneapi::bfloat16 hrsqrt(sycl::ext::oneapi::bfloat16 b) { + return __builtin_bit_cast(sycl::ext::oneapi::bfloat16, + __imf_rsqrtbf16(__builtin_bit_cast(uint16_t, b))); +} + +sycl::marray +h2ceil(sycl::marray b) { + sycl::marray res; + res[0] = hceil(b[0]); + res[1] = hceil(b[1]); + return res; +} + +sycl::marray +h2floor(sycl::marray b) { + sycl::marray res; + res[0] = hfloor(b[0]); + res[1] = hfloor(b[1]); + return res; +} + +sycl::marray +h2trunc(sycl::marray b) { + sycl::marray res; + res[0] = htrunc(b[0]); + res[1] = htrunc(b[1]); + return res; +} + +sycl::marray +h2rint(sycl::marray b) { + sycl::marray res; + res[0] = hrint(b[0]); + res[1] = hrint(b[1]); + return res; +} + +sycl::marray +h2sqrt(sycl::marray b) { + sycl::marray res; + res[0] = hsqrt(b[0]); + res[1] = hsqrt(b[1]); + return res; +} + +sycl::marray +h2rsqrt(sycl::marray b) { + sycl::marray res; + res[0] = hrsqrt(b[0]); + res[1] = hrsqrt(b[1]); + return res; +} + #endif } // namespace math } // namespace intel