From 4543d3ff9ccf48fd5bbe8aa95156f4764d556021 Mon Sep 17 00:00:00 2001 From: Alexey Bader Date: Fri, 28 Feb 2025 19:41:27 -0800 Subject: [PATCH 1/7] [SYCL][bfloat16] Simplify bfloat16 class This patch makes a few changes aiming to make the code more readable. * Move declaration of the bfloat16 class right on top of the header file and move all declarations with the implementation helpers after the class declaration. * Converted ConvertToBfloat16 class to namespace. The class doesn't represent any object, it's a collection of static methods. * Declare public API of the class first. * Define bfloat16 class as final. * Move Bfloat16StorageT declaration from the detail namespace directly into bfloat16 class definition. * Co-locate declaration of functions with similar functionality (e.g. conversion functions). * Simplified unary operator- implementation by removing additional branch for the SPIR target. Generic code for SPIR target produces exactly the same code. * Added compile-time checks for the array size inside conversion helper functions. Device-side implementations of these functions work correctly only for certain array sizes. * Include C++ header instead of C (i.e. cstdint instead of stdint.h). * Applied "no-else-after-return" and "no-braces-for-single-statement-if" coding style changes. 
--- .../sycl/detail/generic_type_traits.hpp | 2 +- sycl/include/sycl/ext/oneapi/bfloat16.hpp | 533 +++++++++--------- sycl/include/sycl/vector.hpp | 6 +- .../vector/vector_math_ops.cpp | 42 +- .../vector/vector_math_ops_preview.cpp | 4 +- 5 files changed, 296 insertions(+), 291 deletions(-) diff --git a/sycl/include/sycl/detail/generic_type_traits.hpp b/sycl/include/sycl/detail/generic_type_traits.hpp index 7670164f820a8..984a05ace06be 100644 --- a/sycl/include/sycl/detail/generic_type_traits.hpp +++ b/sycl/include/sycl/detail/generic_type_traits.hpp @@ -172,7 +172,7 @@ template auto convertToOpenCLType(T &&x) { } else if constexpr (std::is_same_v) { // On host, don't interpret BF16 as uint16. #ifdef __SYCL_DEVICE_ONLY__ - using OpenCLType = sycl::ext::oneapi::detail::Bfloat16StorageT; + using OpenCLType = sycl::ext::oneapi::bfloat16::Bfloat16StorageT; return sycl::bit_cast(x); #else return std::forward(x); diff --git a/sycl/include/sycl/ext/oneapi/bfloat16.hpp b/sycl/include/sycl/ext/oneapi/bfloat16.hpp index 77baf10d86f0b..c6f2575c33091 100644 --- a/sycl/include/sycl/ext/oneapi/bfloat16.hpp +++ b/sycl/include/sycl/ext/oneapi/bfloat16.hpp @@ -13,139 +13,22 @@ #include // for __DPCPP_SYCL_EXTERNAL #include // for half -#include // for uint16_t, uint32_t - -extern "C" __DPCPP_SYCL_EXTERNAL uint16_t -__devicelib_ConvertFToBF16INTEL(const float &) noexcept; -extern "C" __DPCPP_SYCL_EXTERNAL float -__devicelib_ConvertBF16ToFINTEL(const uint16_t &) noexcept; -extern "C" __DPCPP_SYCL_EXTERNAL void -__devicelib_ConvertFToBF16INTELVec1(const float *, uint16_t *) noexcept; -extern "C" __DPCPP_SYCL_EXTERNAL void -__devicelib_ConvertBF16ToFINTELVec1(const uint16_t *, float *) noexcept; -extern "C" __DPCPP_SYCL_EXTERNAL void -__devicelib_ConvertFToBF16INTELVec2(const float *, uint16_t *) noexcept; -extern "C" __DPCPP_SYCL_EXTERNAL void -__devicelib_ConvertBF16ToFINTELVec2(const uint16_t *, float *) noexcept; -extern "C" __DPCPP_SYCL_EXTERNAL void 
-__devicelib_ConvertFToBF16INTELVec3(const float *, uint16_t *) noexcept; -extern "C" __DPCPP_SYCL_EXTERNAL void -__devicelib_ConvertBF16ToFINTELVec3(const uint16_t *, float *) noexcept; -extern "C" __DPCPP_SYCL_EXTERNAL void -__devicelib_ConvertFToBF16INTELVec4(const float *, uint16_t *) noexcept; -extern "C" __DPCPP_SYCL_EXTERNAL void -__devicelib_ConvertBF16ToFINTELVec4(const uint16_t *, float *) noexcept; -extern "C" __DPCPP_SYCL_EXTERNAL void -__devicelib_ConvertFToBF16INTELVec8(const float *, uint16_t *) noexcept; -extern "C" __DPCPP_SYCL_EXTERNAL void -__devicelib_ConvertBF16ToFINTELVec8(const uint16_t *, float *) noexcept; -extern "C" __DPCPP_SYCL_EXTERNAL void -__devicelib_ConvertFToBF16INTELVec16(const float *, uint16_t *) noexcept; -extern "C" __DPCPP_SYCL_EXTERNAL void -__devicelib_ConvertBF16ToFINTELVec16(const uint16_t *, float *) noexcept; +#include // for uint16_t, uint32_t namespace sycl { inline namespace _V1 { namespace ext::oneapi { -class bfloat16; - -namespace detail { -using Bfloat16StorageT = uint16_t; - -template void BF16VecToFloatVec(const bfloat16 src[N], float dst[N]) { -#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) - const uint16_t *src_i16 = sycl::bit_cast(src); - if constexpr (N == 1) - __devicelib_ConvertBF16ToFINTELVec1(src_i16, dst); - else if constexpr (N == 2) - __devicelib_ConvertBF16ToFINTELVec2(src_i16, dst); - else if constexpr (N == 3) - __devicelib_ConvertBF16ToFINTELVec3(src_i16, dst); - else if constexpr (N == 4) - __devicelib_ConvertBF16ToFINTELVec4(src_i16, dst); - else if constexpr (N == 8) - __devicelib_ConvertBF16ToFINTELVec8(src_i16, dst); - else if constexpr (N == 16) - __devicelib_ConvertBF16ToFINTELVec16(src_i16, dst); -#else - for (int i = 0; i < N; ++i) { - dst[i] = (float)src[i]; - } -#endif -} -} // namespace detail - -class bfloat16 { -protected: - detail::Bfloat16StorageT value; - +class bfloat16 final { public: + using Bfloat16StorageT = uint16_t; + bfloat16() = default; 
~bfloat16() = default; constexpr bfloat16(const bfloat16 &) = default; constexpr bfloat16(bfloat16 &&) = default; constexpr bfloat16 &operator=(const bfloat16 &rhs) = default; -private: - static detail::Bfloat16StorageT from_float_fallback(const float &a) { - // We don't call sycl::isnan because we don't want a data type to depend on - // builtins. - if (a != a) - return 0xffc1; - - union { - uint32_t intStorage; - float floatValue; - }; - floatValue = a; - // Do RNE and truncate - uint32_t roundingBias = ((intStorage >> 16) & 0x1) + 0x00007FFF; - return static_cast((intStorage + roundingBias) >> 16); - } - - // Explicit conversion functions - static detail::Bfloat16StorageT from_float(const float &a) { -#if defined(__SYCL_DEVICE_ONLY__) -#if defined(__NVPTX__) -#if (__SYCL_CUDA_ARCH__ >= 800) - detail::Bfloat16StorageT res; - asm("cvt.rn.bf16.f32 %0, %1;" : "=h"(res) : "f"(a)); - return res; -#else - return from_float_fallback(a); -#endif -#elif defined(__AMDGCN__) - return from_float_fallback(a); -#else - return __devicelib_ConvertFToBF16INTEL(a); -#endif -#endif - return from_float_fallback(a); - } - - static float to_float(const detail::Bfloat16StorageT &a) { -#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) - return __devicelib_ConvertBF16ToFINTEL(a); -#else - union { - uint32_t intStorage; - float floatValue; - }; - intStorage = a << 16; - return floatValue; -#endif - } - -protected: - friend class sycl::vec; - friend class sycl::vec; - friend class sycl::vec; - friend class sycl::vec; - friend class sycl::vec; - friend class sycl::vec; - -public: // Implicit conversion from float to bfloat16 bfloat16(const float &a) { value = from_float(a); } @@ -175,11 +58,9 @@ class bfloat16 { friend bfloat16 operator-(const bfloat16 &lhs) { #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) && \ (__SYCL_CUDA_ARCH__ >= 800) - detail::Bfloat16StorageT res; + Bfloat16StorageT res; asm("neg.bf16 %0, %1;" : "=h"(res) : "h"(lhs.value)); return 
bit_cast(res); -#elif defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) - return bfloat16{-__devicelib_ConvertBF16ToFINTEL(lhs.value)}; #else return bfloat16{-to_float(lhs.value)}; #endif @@ -256,11 +137,144 @@ class bfloat16 { rhs = ValFloat; return I; } + +private: + Bfloat16StorageT value; + + // Explicit conversion functions + static float to_float(const Bfloat16StorageT &a); + static Bfloat16StorageT from_float(const float &a); + + // Friend classes for vector operations + friend class sycl::vec; + friend class sycl::vec; + friend class sycl::vec; + friend class sycl::vec; + friend class sycl::vec; + friend class sycl::vec; }; +// Helper functions for conversions between bfloat16 and float scalar types. +#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +extern "C" __DPCPP_SYCL_EXTERNAL float +__devicelib_ConvertBF16ToFINTEL(const uint16_t &) noexcept; +#endif +inline float bfloat16::to_float(const bfloat16::Bfloat16StorageT &a) { +#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) + return __devicelib_ConvertBF16ToFINTEL(a); +#else + union { + uint32_t intStorage; + float floatValue; + }; + intStorage = a << 16; + return floatValue; +#endif +} + +namespace detail { +inline uint16_t from_float_to_uint16_t(const float &a) { + // We don't call sycl::isnan because we don't want a data type to depend on + // builtins. 
+ if (a != a) + return 0xffc1; + + union { + uint32_t intStorage; + float floatValue; + }; + floatValue = a; + // Do RNE and truncate + uint32_t roundingBias = ((intStorage >> 16) & 0x1) + 0x00007FFF; + return static_cast((intStorage + roundingBias) >> 16); +} +} // namespace detail + +#if defined(__SYCL_DEVICE_ONLY__) +extern "C" __DPCPP_SYCL_EXTERNAL uint16_t +__devicelib_ConvertFToBF16INTEL(const float &) noexcept; +#endif +inline bfloat16::Bfloat16StorageT bfloat16::from_float(const float &a) { +#if defined(__SYCL_DEVICE_ONLY__) +#if defined(__NVPTX__) +#if (__SYCL_CUDA_ARCH__ >= 800) + Bfloat16StorageT res; + asm("cvt.rn.bf16.f32 %0, %1;" : "=h"(res) : "f"(a)); + return res; +#else + return detail::from_float_to_uint16_t(a); +#endif +#elif defined(__AMDGCN__) + return detail::from_float_to_uint16_t(a); +#else + return __devicelib_ConvertFToBF16INTEL(a); +#endif +#endif + return detail::from_float_to_uint16_t(a); +} + namespace detail { +// Conversion functions for bfloat16 + +// Helper functions for vector conversions from bfloat16 to float +#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +extern "C" __DPCPP_SYCL_EXTERNAL void +__devicelib_ConvertBF16ToFINTELVec1(const uint16_t *, float *) noexcept; +extern "C" __DPCPP_SYCL_EXTERNAL void +__devicelib_ConvertBF16ToFINTELVec2(const uint16_t *, float *) noexcept; +extern "C" __DPCPP_SYCL_EXTERNAL void +__devicelib_ConvertBF16ToFINTELVec3(const uint16_t *, float *) noexcept; +extern "C" __DPCPP_SYCL_EXTERNAL void +__devicelib_ConvertBF16ToFINTELVec4(const uint16_t *, float *) noexcept; +extern "C" __DPCPP_SYCL_EXTERNAL void +__devicelib_ConvertBF16ToFINTELVec8(const uint16_t *, float *) noexcept; +extern "C" __DPCPP_SYCL_EXTERNAL void +__devicelib_ConvertBF16ToFINTELVec16(const uint16_t *, float *) noexcept; +#endif + +template void BF16VecToFloatVec(const bfloat16 src[N], float dst[N]) { + static_assert(N == 1 || N == 2 || N == 3 || N == 4 || N == 8 || N == 16, + "Unsupported vector 
size"); +#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) + const uint16_t *src_i16 = sycl::bit_cast(src); + if constexpr (N == 1) + __devicelib_ConvertBF16ToFINTELVec1(src_i16, dst); + else if constexpr (N == 2) + __devicelib_ConvertBF16ToFINTELVec2(src_i16, dst); + else if constexpr (N == 3) + __devicelib_ConvertBF16ToFINTELVec3(src_i16, dst); + else if constexpr (N == 4) + __devicelib_ConvertBF16ToFINTELVec4(src_i16, dst); + else if constexpr (N == 8) + __devicelib_ConvertBF16ToFINTELVec8(src_i16, dst); + else if constexpr (N == 16) + __devicelib_ConvertBF16ToFINTELVec16(src_i16, dst); +#else + for (int i = 0; i < N; ++i) { + dst[i] = (float)src[i]; + } +#endif +} + +// Helper functions for vector conversions from float to bfloat16 +#if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) +extern "C" __DPCPP_SYCL_EXTERNAL void +__devicelib_ConvertFToBF16INTELVec1(const float *, uint16_t *) noexcept; +extern "C" __DPCPP_SYCL_EXTERNAL void +__devicelib_ConvertFToBF16INTELVec2(const float *, uint16_t *) noexcept; +extern "C" __DPCPP_SYCL_EXTERNAL void +__devicelib_ConvertFToBF16INTELVec3(const float *, uint16_t *) noexcept; +extern "C" __DPCPP_SYCL_EXTERNAL void +__devicelib_ConvertFToBF16INTELVec4(const float *, uint16_t *) noexcept; +extern "C" __DPCPP_SYCL_EXTERNAL void +__devicelib_ConvertFToBF16INTELVec8(const float *, uint16_t *) noexcept; +extern "C" __DPCPP_SYCL_EXTERNAL void +__devicelib_ConvertFToBF16INTELVec16(const float *, uint16_t *) noexcept; +#endif template void FloatVecToBF16Vec(float src[N], bfloat16 dst[N]) { + static_assert(N == 1 || N == 2 || N == 3 || N == 4 || N == 8 || N == 16, + "Unsupported vector size"); #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) uint16_t *dst_i16 = sycl::bit_cast(dst); if constexpr (N == 1) @@ -284,31 +298,16 @@ template void FloatVecToBF16Vec(float src[N], bfloat16 dst[N]) { #endif } -// Class to convert different data types to 
Bfloat16 -// with different rounding modes. -class ConvertToBfloat16 { - +// Conversion functions from different data types to Bfloat16 with different +// rounding modes. +namespace ConvertToBfloat16 { // The automatic rounding mode is RTE. enum SYCLRoundingMode { automatic = 0, rte = 1, rtz = 2, rtp = 3, rtn = 4 }; - // Function to get the most significant bit position of a number. - template static size_t get_msb_pos(const Ty &x) { - assert(x != 0); - size_t idx = 0; - Ty mask = ((Ty)1 << (sizeof(Ty) * 8 - 1)); - for (idx = 0; idx < (sizeof(Ty) * 8); ++idx) { - if ((x & mask) == mask) - break; - mask >>= 1; - } - - return (sizeof(Ty) * 8 - 1 - idx); - } - // Helper function to get BF16 from float with different rounding modes. // Reference: // https://github.com/intel/llvm/blob/sycl/libdevice/imf_bf16.hpp#L30 - static bfloat16 + inline bfloat16 getBFloat16FromFloatWithRoundingMode(const float &f, SYCLRoundingMode roundingMode) { @@ -316,55 +315,138 @@ class ConvertToBfloat16 { roundingMode == SYCLRoundingMode::rte) { // Use the default rounding mode. return bfloat16{f}; - } else { - uint32_t u32_val = sycl::bit_cast(f); - uint16_t bf16_sign = static_cast((u32_val >> 31) & 0x1); - uint16_t bf16_exp = static_cast((u32_val >> 23) & 0x7FF); - uint32_t f_mant = u32_val & 0x7F'FFFF; - uint16_t bf16_mant = static_cast(f_mant >> 16); - // +/-infinity and NAN - if (bf16_exp == 0xFF) { - if (!f_mant) - return bit_cast(bf16_sign ? 0xFF80 : 0x7F80); - else - return bit_cast((bf16_sign << 15) | - (bf16_exp << 7) | bf16_mant); - } + } + uint32_t u32_val = sycl::bit_cast(f); + uint16_t bf16_sign = static_cast((u32_val >> 31) & 0x1); + uint16_t bf16_exp = static_cast((u32_val >> 23) & 0x7FF); + uint32_t f_mant = u32_val & 0x7F'FFFF; + uint16_t bf16_mant = static_cast(f_mant >> 16); + // +/-infinity and NAN + if (bf16_exp == 0xFF) { + if (!f_mant) + return bit_cast(bf16_sign ? 
0xFF80 : 0x7F80); + return bit_cast((bf16_sign << 15) | (bf16_exp << 7) | + bf16_mant); + } - // +/-0 - if (!bf16_exp && !f_mant) { - return bit_cast(bf16_sign ? 0x8000 : 0x0); - } + // +/-0 + if (!bf16_exp && !f_mant) { + return bit_cast(bf16_sign ? 0x8000 : 0x0); + } - uint16_t mant_discard = static_cast(f_mant & 0xFFFF); - switch (roundingMode) { - case SYCLRoundingMode::rtn: - if (bf16_sign && mant_discard) - bf16_mant++; - break; - case SYCLRoundingMode::rtz: - break; - case SYCLRoundingMode::rtp: - if (!bf16_sign && mant_discard) - bf16_mant++; - break; + uint16_t mant_discard = static_cast(f_mant & 0xFFFF); + switch (roundingMode) { + case SYCLRoundingMode::rtn: + if (bf16_sign && mant_discard) + bf16_mant++; + break; + case SYCLRoundingMode::rtz: + break; + case SYCLRoundingMode::rtp: + if (!bf16_sign && mant_discard) + bf16_mant++; + break; + + // Should not reach here. Adding these just to suppress the warning. + case SYCLRoundingMode::automatic: + case SYCLRoundingMode::rte: + break; + } - // Should not reach here. Adding these just to suppress the warning. - case SYCLRoundingMode::automatic: - case SYCLRoundingMode::rte: - break; - } + // if overflow happens, bf16_exp will be 0xFF and bf16_mant will be 0, + // infinity will be returned. + if (bf16_mant == 0x80) { + bf16_mant = 0; + bf16_exp++; + } + + return bit_cast((bf16_sign << 15) | (bf16_exp << 7) | + bf16_mant); + } - // if overflow happens, bf16_exp will be 0xFF and bf16_mant will be 0, - // infinity will be returned. + // Helper function to get BF16 from double with RTE rounding modes. 
+ // Reference: + // https://github.com/intel/llvm/blob/sycl/libdevice/imf_bf16.hpp#L79 + inline bfloat16 getBFloat16FromDoubleWithRTE(const double &d) { + + uint64_t u64_val = sycl::bit_cast(d); + int16_t bf16_sign = (u64_val >> 63) & 0x1; + uint16_t fp64_exp = static_cast((u64_val >> 52) & 0x7FF); + uint64_t fp64_mant = (u64_val & 0xF'FFFF'FFFF'FFFF); + uint16_t bf16_mant; + // handling +/-infinity and NAN for double input + if (fp64_exp == 0x7FF) { + if (!fp64_mant) + return bf16_sign ? 0xFF80 : 0x7F80; + + // returns a quiet NaN + return 0x7FC0; + } + + // Subnormal double precision is converted to 0 + if (fp64_exp == 0) + return bf16_sign ? 0x8000 : 0x0; + + fp64_exp -= 1023; + + // handling overflow, convert to +/-infinity + if (static_cast(fp64_exp) > 127) + return bf16_sign ? 0xFF80 : 0x7F80; + + // handling underflow + if (static_cast(fp64_exp) < -133) + return bf16_sign ? 0x8000 : 0x0; + + //-133 <= fp64_exp <= 127, 1.signicand * 2^fp64_exp + // For these numbers, they are NOT subnormal double-precision numbers but + // will turn into subnormal when converting to bfloat16 + uint64_t discard_bits; + if (static_cast(fp64_exp) < -126) { + fp64_mant |= 0x10'0000'0000'0000; + fp64_mant >>= -126 - static_cast(fp64_exp) - 1; + discard_bits = fp64_mant & 0x3FFF'FFFF'FFFF; + bf16_mant = static_cast(fp64_mant >> 46); + if (discard_bits > 0x2000'0000'0000 || + ((discard_bits == 0x2000'0000'0000) && ((bf16_mant & 0x1) == 0x1))) + bf16_mant += 1; + fp64_exp = 0; if (bf16_mant == 0x80) { bf16_mant = 0; - bf16_exp++; + fp64_exp = 1; } + return (bf16_sign << 15) | (fp64_exp << 7) | bf16_mant; + } - return bit_cast((bf16_sign << 15) | (bf16_exp << 7) | - bf16_mant); + // For normal value, discard 45 bits from mantissa + discard_bits = fp64_mant & 0x1FFF'FFFF'FFFF; + bf16_mant = static_cast(fp64_mant >> 45); + if (discard_bits > 0x1000'0000'0000 || + ((discard_bits == 0x1000'0000'0000) && ((bf16_mant & 0x1) == 0x1))) + bf16_mant += 1; + + if (bf16_mant == 0x80) { + if 
(fp64_exp == 127) + return bf16_sign ? 0xFF80 : 0x7F80; + bf16_mant = 0; + fp64_exp++; + } + fp64_exp += 127; + + return (bf16_sign << 15) | (fp64_exp << 7) | bf16_mant; + } + + // Function to get the most significant bit position of a number. + template size_t get_msb_pos(const Ty &x) { + assert(x != 0); + size_t idx = 0; + Ty mask = ((Ty)1 << (sizeof(Ty) * 8 - 1)); + for (idx = 0; idx < (sizeof(Ty) * 8); ++idx) { + if ((x & mask) == mask) + break; + mask >>= 1; } + + return (sizeof(Ty) * 8 - 1 - idx); } // Helper function to get BF16 from unsigned integral data types @@ -372,7 +454,7 @@ class ConvertToBfloat16 { // Reference: // https://github.com/intel/llvm/blob/sycl/libdevice/imf_bf16.hpp#L302 template - static bfloat16 + bfloat16 getBFloat16FromUIntegralWithRoundingMode(T &u, SYCLRoundingMode roundingMode) { @@ -427,7 +509,7 @@ class ConvertToBfloat16 { // Reference: // https://github.com/intel/llvm/blob/sycl/libdevice/imf_bf16.hpp#L353 template - static bfloat16 + bfloat16 getBFloat16FromSIntegralWithRoundingMode(T &i, SYCLRoundingMode roundingMode) { // Get unsigned type corresponding to T. @@ -476,85 +558,8 @@ class ConvertToBfloat16 { return bit_cast(b_sign | (b_exp << 7) | b_mant); } - // Helper function to get BF16 from double with RTE rounding modes. - // Reference: - // https://github.com/intel/llvm/blob/sycl/libdevice/imf_bf16.hpp#L79 - static bfloat16 getBFloat16FromDoubleWithRTE(const double &d) { - - uint64_t u64_val = sycl::bit_cast(d); - int16_t bf16_sign = (u64_val >> 63) & 0x1; - uint16_t fp64_exp = static_cast((u64_val >> 52) & 0x7FF); - uint64_t fp64_mant = (u64_val & 0xF'FFFF'FFFF'FFFF); - uint16_t bf16_mant; - // handling +/-infinity and NAN for double input - if (fp64_exp == 0x7FF) { - if (!fp64_mant) { - return bf16_sign ? 0xFF80 : 0x7F80; - } else { - // returns a quiet NaN - return 0x7FC0; - } - } - - // Subnormal double precision is converted to 0 - if (fp64_exp == 0) { - return bf16_sign ? 
0x8000 : 0x0; - } - - fp64_exp -= 1023; - // handling overflow, convert to +/-infinity - if (static_cast(fp64_exp) > 127) { - return bf16_sign ? 0xFF80 : 0x7F80; - } - - // handling underflow - if (static_cast(fp64_exp) < -133) { - return bf16_sign ? 0x8000 : 0x0; - } - - //-133 <= fp64_exp <= 127, 1.signicand * 2^fp64_exp - // For these numbers, they are NOT subnormal double-precision numbers but - // will turn into subnormal when converting to bfloat16 - uint64_t discard_bits; - if (static_cast(fp64_exp) < -126) { - fp64_mant |= 0x10'0000'0000'0000; - fp64_mant >>= -126 - static_cast(fp64_exp) - 1; - discard_bits = fp64_mant & 0x3FFF'FFFF'FFFF; - bf16_mant = static_cast(fp64_mant >> 46); - if (discard_bits > 0x2000'0000'0000 || - ((discard_bits == 0x2000'0000'0000) && ((bf16_mant & 0x1) == 0x1))) - bf16_mant += 1; - fp64_exp = 0; - if (bf16_mant == 0x80) { - bf16_mant = 0; - fp64_exp = 1; - } - return (bf16_sign << 15) | (fp64_exp << 7) | bf16_mant; - } - - // For normal value, discard 45 bits from mantissa - discard_bits = fp64_mant & 0x1FFF'FFFF'FFFF; - bf16_mant = static_cast(fp64_mant >> 45); - if (discard_bits > 0x1000'0000'0000 || - ((discard_bits == 0x1000'0000'0000) && ((bf16_mant & 0x1) == 0x1))) - bf16_mant += 1; - - if (bf16_mant == 0x80) { - if (fp64_exp != 127) { - bf16_mant = 0; - fp64_exp++; - } else { - return bf16_sign ? 0xFF80 : 0x7F80; - } - } - fp64_exp += 127; - - return (bf16_sign << 15) | (fp64_exp << 7) | bf16_mant; - } - -public: template - static bfloat16 getBfloat16WithRoundingMode(const Ty &a) { + bfloat16 getBfloat16WithRoundingMode(const Ty &a) { if (!a) return bfloat16{0.0f}; @@ -593,7 +598,7 @@ class ConvertToBfloat16 { "Only integral and floating point types are supported."); } } -}; // class ConvertToBfloat16. 
+} // namespace ConvertToBfloat16 } // namespace detail } // namespace ext::oneapi diff --git a/sycl/include/sycl/vector.hpp b/sycl/include/sycl/vector.hpp index 015f345740ea4..e8e2056191d03 100644 --- a/sycl/include/sycl/vector.hpp +++ b/sycl/include/sycl/vector.hpp @@ -274,9 +274,9 @@ class __SYCL_EBO vec bool, /*->*/ std::uint8_t, // sycl::half, /*->*/ sycl::detail::half_impl::StorageT, // sycl::ext::oneapi::bfloat16, - /*->*/ sycl::ext::oneapi::detail::Bfloat16StorageT, // - char, /*->*/ detail::ConvertToOpenCLType_t, // - DataT, /*->*/ DataT // + /*->*/ sycl::ext::oneapi::bfloat16::Bfloat16StorageT, // + char, /*->*/ detail::ConvertToOpenCLType_t, // + DataT, /*->*/ DataT // >::type; public: diff --git a/sycl/test/check_device_code/vector/vector_math_ops.cpp b/sycl/test/check_device_code/vector/vector_math_ops.cpp index 985aac0d084d7..f06fcf61e23bf 100644 --- a/sycl/test/check_device_code/vector/vector_math_ops.cpp +++ b/sycl/test/check_device_code/vector/vector_math_ops.cpp @@ -121,17 +121,17 @@ SYCL_EXTERNAL auto TestAdd(vec a, vec b) { return a + b; } // CHECK-NEXT: [[CMP_I_I:%.*]] = icmp samesign ult i64 [[I_0_I_I]], 3 // CHECK-NEXT: br i1 [[CMP_I_I]], label [[FOR_BODY_I_I]], label [[_ZN4SYCL3_V16DETAILPLINS0_3EXT6ONEAPI8BFLOAT16EEENST9ENABLE_IFIX24IS_OP_AVAILABLE_FOR_TYPEIST4PLUSIVET_EENS0_3VECIS5_LI3EEEE4TYPEERKSB_SF__EXIT:%.*]] // CHECK: for.body.i.i: -// CHECK-NEXT: [[ARRAYIDX_I_I_I_I_I:%.*]] = getelementptr inbounds nuw [4 x %"class.sycl::_V1::ext::oneapi::bfloat16"], ptr addrspace(4) [[A_ASCAST]], i64 0, i64 [[I_0_I_I]] -// CHECK-NEXT: [[ARRAYIDX_I_I_I12_I_I:%.*]] = getelementptr inbounds nuw [4 x %"class.sycl::_V1::ext::oneapi::bfloat16"], ptr addrspace(4) [[B_ASCAST]], i64 0, i64 [[I_0_I_I]] +// CHECK-NEXT: [[ARRAYIDX_I_I_I_I:%.*]] = getelementptr inbounds nuw [4 x %"class.sycl::_V1::ext::oneapi::bfloat16"], ptr addrspace(4) [[A_ASCAST]], i64 0, i64 [[I_0_I_I]] +// CHECK-NEXT: [[ARRAYIDX_I_I12_I_I:%.*]] = getelementptr inbounds nuw [4 x 
%"class.sycl::_V1::ext::oneapi::bfloat16"], ptr addrspace(4) [[B_ASCAST]], i64 0, i64 [[I_0_I_I]] // CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[REF_TMP_I_I_I_I]]), !noalias [[META80:![0-9]+]] -// CHECK-NEXT: [[CALL_I_I_I_I_I_I:%.*]] = call spir_func noundef float @__devicelib_ConvertBF16ToFINTEL(ptr addrspace(4) noundef align 2 dereferenceable(2) dereferenceable_or_null(2) [[ARRAYIDX_I_I_I_I_I]]) #[[ATTR8:[0-9]+]], !noalias [[META83:![0-9]+]] -// CHECK-NEXT: [[CALL_I_I2_I_I_I_I:%.*]] = call spir_func noundef float @__devicelib_ConvertBF16ToFINTEL(ptr addrspace(4) noundef align 2 dereferenceable(2) dereferenceable_or_null(2) [[ARRAYIDX_I_I_I12_I_I]]) #[[ATTR8]], !noalias [[META83]] +// CHECK-NEXT: [[CALL_I_I_I_I_I_I:%.*]] = call spir_func noundef float @__devicelib_ConvertBF16ToFINTEL(ptr addrspace(4) noundef align 2 dereferenceable(2) dereferenceable_or_null(2) [[ARRAYIDX_I_I_I_I]]) #[[ATTR8:[0-9]+]], !noalias [[META83:![0-9]+]] +// CHECK-NEXT: [[CALL_I_I2_I_I_I_I:%.*]] = call spir_func noundef float @__devicelib_ConvertBF16ToFINTEL(ptr addrspace(4) noundef align 2 dereferenceable(2) dereferenceable_or_null(2) [[ARRAYIDX_I_I12_I_I]]) #[[ATTR8]], !noalias [[META83]] // CHECK-NEXT: [[ADD_I_I_I_I:%.*]] = fadd float [[CALL_I_I_I_I_I_I]], [[CALL_I_I2_I_I_I_I]] // CHECK-NEXT: store float [[ADD_I_I_I_I]], ptr [[REF_TMP_I_I_I_I]], align 4, !tbaa [[TBAA86:![0-9]+]], !noalias [[META83]] // CHECK-NEXT: [[CALL_I_I3_I_I_I_I:%.*]] = call spir_func noundef zeroext i16 @__devicelib_ConvertFToBF16INTEL(ptr addrspace(4) noundef align 4 dereferenceable(4) [[REF_TMP_ASCAST_I_I_I_I]]) #[[ATTR8]], !noalias [[META83]] // CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[REF_TMP_I_I_I_I]]), !noalias [[META80]] -// CHECK-NEXT: [[ARRAYIDX_I_I_I14_I_I:%.*]] = getelementptr inbounds [4 x %"class.sycl::_V1::ext::oneapi::bfloat16"], ptr [[RES_I_I]], i64 0, i64 [[I_0_I_I]] -// CHECK-NEXT: store i16 [[CALL_I_I3_I_I_I_I]], ptr [[ARRAYIDX_I_I_I14_I_I]], align 
2, !tbaa [[TBAA88:![0-9]+]], !noalias [[META79]] +// CHECK-NEXT: [[ARRAYIDX_I_I14_I_I:%.*]] = getelementptr inbounds [4 x %"class.sycl::_V1::ext::oneapi::bfloat16"], ptr [[RES_I_I]], i64 0, i64 [[I_0_I_I]] +// CHECK-NEXT: store i16 [[CALL_I_I3_I_I_I_I]], ptr [[ARRAYIDX_I_I14_I_I]], align 2, !tbaa [[TBAA88:![0-9]+]], !noalias [[META79]] // CHECK-NEXT: [[INC_I_I]] = add nuw nsw i64 [[I_0_I_I]], 1 // CHECK-NEXT: br label [[FOR_COND_I_I]], !llvm.loop [[LOOP90:![0-9]+]] // CHECK: _ZN4sycl3_V16detailplINS0_3ext6oneapi8bfloat16EEENSt9enable_ifIX24is_op_available_for_typeISt4plusIvET_EENS0_3vecIS5_Li3EEEE4typeERKSB_SF_.exit: @@ -225,14 +225,14 @@ SYCL_EXTERNAL auto TestGreaterThan(vec a, vec b) { // CHECK-NEXT: [[CMP_I_I:%.*]] = icmp samesign ult i64 [[I_0_I_I]], 4 // CHECK-NEXT: br i1 [[CMP_I_I]], label [[FOR_BODY_I_I]], label [[_ZN4SYCL3_V16DETAILGTINS0_3EXT6ONEAPI8BFLOAT16EEENST9ENABLE_IFIX24IS_OP_AVAILABLE_FOR_TYPEIST7GREATERIVET_EENS0_3VECISLI4EEEE4TYPEERKNSA_IS5_LI4EEESG__EXIT:%.*]] // CHECK: for.body.i.i: -// CHECK-NEXT: [[ARRAYIDX_I_I_I_I_I:%.*]] = getelementptr inbounds nuw [4 x %"class.sycl::_V1::ext::oneapi::bfloat16"], ptr addrspace(4) [[A_ASCAST]], i64 0, i64 [[I_0_I_I]] -// CHECK-NEXT: [[ARRAYIDX_I_I_I14_I_I:%.*]] = getelementptr inbounds nuw [4 x %"class.sycl::_V1::ext::oneapi::bfloat16"], ptr addrspace(4) [[B_ASCAST]], i64 0, i64 [[I_0_I_I]] -// CHECK-NEXT: [[CALL_I_I_I_I_I_I_I:%.*]] = call spir_func noundef float @__devicelib_ConvertBF16ToFINTEL(ptr addrspace(4) noundef align 2 dereferenceable(2) dereferenceable_or_null(2) [[ARRAYIDX_I_I_I_I_I]]) #[[ATTR8]], !noalias [[META127]] -// CHECK-NEXT: [[CALL_I_I2_I_I_I_I_I:%.*]] = call spir_func noundef float @__devicelib_ConvertBF16ToFINTEL(ptr addrspace(4) noundef align 2 dereferenceable(2) dereferenceable_or_null(2) [[ARRAYIDX_I_I_I14_I_I]]) #[[ATTR8]], !noalias [[META127]] +// CHECK-NEXT: [[ARRAYIDX_I_I_I_I:%.*]] = getelementptr inbounds nuw [4 x %"class.sycl::_V1::ext::oneapi::bfloat16"], ptr addrspace(4) 
[[A_ASCAST]], i64 0, i64 [[I_0_I_I]] +// CHECK-NEXT: [[ARRAYIDX_I_I14_I_I:%.*]] = getelementptr inbounds nuw [4 x %"class.sycl::_V1::ext::oneapi::bfloat16"], ptr addrspace(4) [[B_ASCAST]], i64 0, i64 [[I_0_I_I]] +// CHECK-NEXT: [[CALL_I_I_I_I_I_I_I:%.*]] = call spir_func noundef float @__devicelib_ConvertBF16ToFINTEL(ptr addrspace(4) noundef align 2 dereferenceable(2) dereferenceable_or_null(2) [[ARRAYIDX_I_I_I_I]]) #[[ATTR8]], !noalias [[META127]] +// CHECK-NEXT: [[CALL_I_I2_I_I_I_I_I:%.*]] = call spir_func noundef float @__devicelib_ConvertBF16ToFINTEL(ptr addrspace(4) noundef align 2 dereferenceable(2) dereferenceable_or_null(2) [[ARRAYIDX_I_I14_I_I]]) #[[ATTR8]], !noalias [[META127]] // CHECK-NEXT: [[CMP_I_I_I_I_I:%.*]] = fcmp ogt float [[CALL_I_I_I_I_I_I_I]], [[CALL_I_I2_I_I_I_I_I]] // CHECK-NEXT: [[CONV6_I_I:%.*]] = sext i1 [[CMP_I_I_I_I_I]] to i16 -// CHECK-NEXT: [[ARRAYIDX_I_I_I16_I_I:%.*]] = getelementptr inbounds [4 x i16], ptr [[RES_I_I]], i64 0, i64 [[I_0_I_I]] -// CHECK-NEXT: store i16 [[CONV6_I_I]], ptr [[ARRAYIDX_I_I_I16_I_I]], align 2, !tbaa [[TBAA88]], !noalias [[META127]] +// CHECK-NEXT: [[ARRAYIDX_I_I16_I_I:%.*]] = getelementptr inbounds [4 x i16], ptr [[RES_I_I]], i64 0, i64 [[I_0_I_I]] +// CHECK-NEXT: store i16 [[CONV6_I_I]], ptr [[ARRAYIDX_I_I16_I_I]], align 2, !tbaa [[TBAA88]], !noalias [[META127]] // CHECK-NEXT: [[INC_I_I]] = add nuw nsw i64 [[I_0_I_I]], 1 // CHECK-NEXT: br label [[FOR_COND_I_I]], !llvm.loop [[LOOP128:![0-9]+]] // CHECK: _ZN4sycl3_V16detailgtINS0_3ext6oneapi8bfloat16EEENSt9enable_ifIX24is_op_available_for_typeISt7greaterIvET_EENS0_3vecIsLi4EEEE4typeERKNSA_IS5_Li4EEESG_.exit: @@ -341,12 +341,12 @@ SYCL_EXTERNAL auto TestMinus(vec a) { return -a; } // CHECK-NEXT: [[CMP_I_I:%.*]] = icmp samesign ult i64 [[I_0_I_I]], 3 // CHECK-NEXT: br i1 [[CMP_I_I]], label [[FOR_BODY_I_I]], label [[_ZN4SYCL3_V16DETAILNTERKNS0_3VECINS0_3EXT6ONEAPI8BFLOAT16ELI3EEE_EXIT:%.*]] // CHECK: for.body.i.i: -// CHECK-NEXT: [[ARRAYIDX_I_I_I_I_I:%.*]] = 
getelementptr inbounds nuw [4 x %"class.sycl::_V1::ext::oneapi::bfloat16"], ptr addrspace(4) [[A_ASCAST]], i64 0, i64 [[I_0_I_I]] -// CHECK-NEXT: [[CALL_I_I_I_I_I:%.*]] = call spir_func noundef float @__devicelib_ConvertBF16ToFINTEL(ptr addrspace(4) noundef align 2 dereferenceable(2) dereferenceable_or_null(2) [[ARRAYIDX_I_I_I_I_I]]) #[[ATTR8]], !noalias [[META190]] +// CHECK-NEXT: [[ARRAYIDX_I_I_I_I:%.*]] = getelementptr inbounds nuw [4 x %"class.sycl::_V1::ext::oneapi::bfloat16"], ptr addrspace(4) [[A_ASCAST]], i64 0, i64 [[I_0_I_I]] +// CHECK-NEXT: [[CALL_I_I_I_I_I:%.*]] = call spir_func noundef float @__devicelib_ConvertBF16ToFINTEL(ptr addrspace(4) noundef align 2 dereferenceable(2) dereferenceable_or_null(2) [[ARRAYIDX_I_I_I_I]]) #[[ATTR8]], !noalias [[META190]] // CHECK-NEXT: [[TOBOOL_I_I_I:%.*]] = fcmp oeq float [[CALL_I_I_I_I_I]], 0.000000e+00 // CHECK-NEXT: [[CONV2_I_I:%.*]] = sext i1 [[TOBOOL_I_I_I]] to i16 -// CHECK-NEXT: [[ARRAYIDX_I_I_I9_I_I:%.*]] = getelementptr inbounds [4 x i16], ptr [[RES_I_I]], i64 0, i64 [[I_0_I_I]] -// CHECK-NEXT: store i16 [[CONV2_I_I]], ptr [[ARRAYIDX_I_I_I9_I_I]], align 2, !tbaa [[TBAA88]], !noalias [[META190]] +// CHECK-NEXT: [[ARRAYIDX_I_I9_I_I:%.*]] = getelementptr inbounds [4 x i16], ptr [[RES_I_I]], i64 0, i64 [[I_0_I_I]] +// CHECK-NEXT: store i16 [[CONV2_I_I]], ptr [[ARRAYIDX_I_I9_I_I]], align 2, !tbaa [[TBAA88]], !noalias [[META190]] // CHECK-NEXT: [[INC_I_I]] = add nuw nsw i64 [[I_0_I_I]], 1 // CHECK-NEXT: br label [[FOR_COND_I_I]], !llvm.loop [[LOOP191:![0-9]+]] // CHECK: _ZN4sycl3_V16detailntERKNS0_3vecINS0_3ext6oneapi8bfloat16ELi3EEE.exit: @@ -372,15 +372,15 @@ SYCL_EXTERNAL auto TestNegation(vec a) { return !a; } // CHECK-NEXT: [[CMP_I_I:%.*]] = icmp samesign ult i64 [[I_0_I_I]], 16 // CHECK-NEXT: br i1 [[CMP_I_I]], label [[FOR_BODY_I_I]], label [[_ZN4SYCL3_V16DETAILNGERKNS0_3VECINS0_3EXT6ONEAPI8BFLOAT16ELI16EEE_EXIT:%.*]] // CHECK: for.body.i.i: -// CHECK-NEXT: [[ARRAYIDX_I_I_I_I_I:%.*]] = getelementptr inbounds 
nuw [16 x %"class.sycl::_V1::ext::oneapi::bfloat16"], ptr addrspace(4) [[A_ASCAST]], i64 0, i64 [[I_0_I_I]] +// CHECK-NEXT: [[ARRAYIDX_I_I_I_I:%.*]] = getelementptr inbounds nuw [16 x %"class.sycl::_V1::ext::oneapi::bfloat16"], ptr addrspace(4) [[A_ASCAST]], i64 0, i64 [[I_0_I_I]] // CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[REF_TMP_I_I_I_I]]), !noalias [[META199:![0-9]+]] -// CHECK-NEXT: [[CALL_I_I_I_I:%.*]] = call spir_func float @__devicelib_ConvertBF16ToFINTEL(ptr addrspace(4) noundef align 2 dereferenceable(2) [[ARRAYIDX_I_I_I_I_I]]) #[[ATTR8]], !noalias [[META202:![0-9]+]] -// CHECK-NEXT: [[FNEG_I_I_I_I:%.*]] = fneg float [[CALL_I_I_I_I]] +// CHECK-NEXT: [[CALL_I_I_I_I_I:%.*]] = call spir_func noundef float @__devicelib_ConvertBF16ToFINTEL(ptr addrspace(4) noundef align 2 dereferenceable(2) [[ARRAYIDX_I_I_I_I]]) #[[ATTR8]], !noalias [[META202:![0-9]+]] +// CHECK-NEXT: [[FNEG_I_I_I_I:%.*]] = fneg float [[CALL_I_I_I_I_I]] // CHECK-NEXT: store float [[FNEG_I_I_I_I]], ptr [[REF_TMP_I_I_I_I]], align 4, !tbaa [[TBAA86]], !noalias [[META202]] // CHECK-NEXT: [[CALL_I_I_I_I_I_I:%.*]] = call spir_func noundef zeroext i16 @__devicelib_ConvertFToBF16INTEL(ptr addrspace(4) noundef align 4 dereferenceable(4) [[REF_TMP_ASCAST_I_I_I_I]]) #[[ATTR8]], !noalias [[META202]] // CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[REF_TMP_I_I_I_I]]), !noalias [[META199]] -// CHECK-NEXT: [[ARRAYIDX_I_I_I7_I_I:%.*]] = getelementptr inbounds [16 x %"class.sycl::_V1::ext::oneapi::bfloat16"], ptr [[RES_I_I]], i64 0, i64 [[I_0_I_I]] -// CHECK-NEXT: store i16 [[CALL_I_I_I_I_I_I]], ptr [[ARRAYIDX_I_I_I7_I_I]], align 2, !tbaa [[TBAA88]], !noalias [[META196]] +// CHECK-NEXT: [[ARRAYIDX_I_I7_I_I:%.*]] = getelementptr inbounds [16 x %"class.sycl::_V1::ext::oneapi::bfloat16"], ptr [[RES_I_I]], i64 0, i64 [[I_0_I_I]] +// CHECK-NEXT: store i16 [[CALL_I_I_I_I_I_I]], ptr [[ARRAYIDX_I_I7_I_I]], align 2, !tbaa [[TBAA88]], !noalias [[META196]] // CHECK-NEXT: 
[[INC_I_I]] = add nuw nsw i64 [[I_0_I_I]], 1 // CHECK-NEXT: br label [[FOR_COND_I_I]], !llvm.loop [[LOOP205:![0-9]+]] // CHECK: _ZN4sycl3_V16detailngERKNS0_3vecINS0_3ext6oneapi8bfloat16ELi16EEE.exit: diff --git a/sycl/test/check_device_code/vector/vector_math_ops_preview.cpp b/sycl/test/check_device_code/vector/vector_math_ops_preview.cpp index b6adf26170ab7..2d5e2bbeb6f2d 100644 --- a/sycl/test/check_device_code/vector/vector_math_ops_preview.cpp +++ b/sycl/test/check_device_code/vector/vector_math_ops_preview.cpp @@ -374,8 +374,8 @@ SYCL_EXTERNAL auto TestNegation(vec a) { return !a; } // CHECK: for.body.i.i: // CHECK-NEXT: [[ARRAYIDX_I_I_I:%.*]] = getelementptr inbounds nuw [16 x %"class.sycl::_V1::ext::oneapi::bfloat16"], ptr addrspace(4) [[A_ASCAST]], i64 0, i64 [[I_0_I_I]] // CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[REF_TMP_I_I_I_I]]), !noalias [[META199:![0-9]+]] -// CHECK-NEXT: [[CALL_I_I_I_I:%.*]] = call spir_func float @__devicelib_ConvertBF16ToFINTEL(ptr addrspace(4) noundef align 2 dereferenceable(2) [[ARRAYIDX_I_I_I]]) #[[ATTR8]], !noalias [[META202:![0-9]+]] -// CHECK-NEXT: [[FNEG_I_I_I_I:%.*]] = fneg float [[CALL_I_I_I_I]] +// CHECK-NEXT: [[CALL_I_I_I_I_I:%.*]] = call spir_func noundef float @__devicelib_ConvertBF16ToFINTEL(ptr addrspace(4) noundef align 2 dereferenceable(2) [[ARRAYIDX_I_I_I]]) #[[ATTR8]], !noalias [[META202:![0-9]+]] +// CHECK-NEXT: [[FNEG_I_I_I_I:%.*]] = fneg float [[CALL_I_I_I_I_I]] // CHECK-NEXT: store float [[FNEG_I_I_I_I]], ptr [[REF_TMP_I_I_I_I]], align 4, !tbaa [[TBAA86]], !noalias [[META202]] // CHECK-NEXT: [[CALL_I_I_I_I_I_I:%.*]] = call spir_func noundef zeroext i16 @__devicelib_ConvertFToBF16INTEL(ptr addrspace(4) noundef align 4 dereferenceable(4) [[REF_TMP_ASCAST_I_I_I_I]]) #[[ATTR8]], !noalias [[META202]] // CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[REF_TMP_I_I_I_I]]), !noalias [[META199]] From 02d6322ccb06351e98296014ca6abafb52de1094 Mon Sep 17 00:00:00 2001 From: 
Alexey Bader Date: Fri, 28 Feb 2025 20:41:44 -0800 Subject: [PATCH 2/7] clang-format --- sycl/include/sycl/ext/oneapi/bfloat16.hpp | 517 +++++++++++----------- 1 file changed, 257 insertions(+), 260 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/bfloat16.hpp b/sycl/include/sycl/ext/oneapi/bfloat16.hpp index c6f2575c33091..56baaf7b5e7d4 100644 --- a/sycl/include/sycl/ext/oneapi/bfloat16.hpp +++ b/sycl/include/sycl/ext/oneapi/bfloat16.hpp @@ -301,307 +301,304 @@ template void FloatVecToBF16Vec(float src[N], bfloat16 dst[N]) { // Conversion functions from different data types to Bfloat16 with different // rounding modes. namespace ConvertToBfloat16 { - // The automatic rounding mode is RTE. - enum SYCLRoundingMode { automatic = 0, rte = 1, rtz = 2, rtp = 3, rtn = 4 }; - - // Helper function to get BF16 from float with different rounding modes. - // Reference: - // https://github.com/intel/llvm/blob/sycl/libdevice/imf_bf16.hpp#L30 - inline bfloat16 - getBFloat16FromFloatWithRoundingMode(const float &f, - SYCLRoundingMode roundingMode) { - - if (roundingMode == SYCLRoundingMode::automatic || - roundingMode == SYCLRoundingMode::rte) { - // Use the default rounding mode. - return bfloat16{f}; - } - uint32_t u32_val = sycl::bit_cast(f); - uint16_t bf16_sign = static_cast((u32_val >> 31) & 0x1); - uint16_t bf16_exp = static_cast((u32_val >> 23) & 0x7FF); - uint32_t f_mant = u32_val & 0x7F'FFFF; - uint16_t bf16_mant = static_cast(f_mant >> 16); - // +/-infinity and NAN - if (bf16_exp == 0xFF) { - if (!f_mant) - return bit_cast(bf16_sign ? 0xFF80 : 0x7F80); - return bit_cast((bf16_sign << 15) | (bf16_exp << 7) | - bf16_mant); - } - - // +/-0 - if (!bf16_exp && !f_mant) { - return bit_cast(bf16_sign ? 
0x8000 : 0x0); - } - - uint16_t mant_discard = static_cast(f_mant & 0xFFFF); - switch (roundingMode) { - case SYCLRoundingMode::rtn: - if (bf16_sign && mant_discard) - bf16_mant++; - break; - case SYCLRoundingMode::rtz: - break; - case SYCLRoundingMode::rtp: - if (!bf16_sign && mant_discard) - bf16_mant++; - break; - - // Should not reach here. Adding these just to suppress the warning. - case SYCLRoundingMode::automatic: - case SYCLRoundingMode::rte: - break; - } - - // if overflow happens, bf16_exp will be 0xFF and bf16_mant will be 0, - // infinity will be returned. - if (bf16_mant == 0x80) { - bf16_mant = 0; - bf16_exp++; - } - +// The automatic rounding mode is RTE. +enum SYCLRoundingMode { automatic = 0, rte = 1, rtz = 2, rtp = 3, rtn = 4 }; + +// Helper function to get BF16 from float with different rounding modes. +// Reference: +// https://github.com/intel/llvm/blob/sycl/libdevice/imf_bf16.hpp#L30 +inline bfloat16 +getBFloat16FromFloatWithRoundingMode(const float &f, + SYCLRoundingMode roundingMode) { + + if (roundingMode == SYCLRoundingMode::automatic || + roundingMode == SYCLRoundingMode::rte) { + // Use the default rounding mode. + return bfloat16{f}; + } + uint32_t u32_val = sycl::bit_cast(f); + uint16_t bf16_sign = static_cast((u32_val >> 31) & 0x1); + uint16_t bf16_exp = static_cast((u32_val >> 23) & 0x7FF); + uint32_t f_mant = u32_val & 0x7F'FFFF; + uint16_t bf16_mant = static_cast(f_mant >> 16); + // +/-infinity and NAN + if (bf16_exp == 0xFF) { + if (!f_mant) + return bit_cast(bf16_sign ? 0xFF80 : 0x7F80); return bit_cast((bf16_sign << 15) | (bf16_exp << 7) | bf16_mant); } - // Helper function to get BF16 from double with RTE rounding modes. 
- // Reference: - // https://github.com/intel/llvm/blob/sycl/libdevice/imf_bf16.hpp#L79 - inline bfloat16 getBFloat16FromDoubleWithRTE(const double &d) { - - uint64_t u64_val = sycl::bit_cast(d); - int16_t bf16_sign = (u64_val >> 63) & 0x1; - uint16_t fp64_exp = static_cast((u64_val >> 52) & 0x7FF); - uint64_t fp64_mant = (u64_val & 0xF'FFFF'FFFF'FFFF); - uint16_t bf16_mant; - // handling +/-infinity and NAN for double input - if (fp64_exp == 0x7FF) { - if (!fp64_mant) - return bf16_sign ? 0xFF80 : 0x7F80; - - // returns a quiet NaN - return 0x7FC0; - } + // +/-0 + if (!bf16_exp && !f_mant) { + return bit_cast(bf16_sign ? 0x8000 : 0x0); + } - // Subnormal double precision is converted to 0 - if (fp64_exp == 0) - return bf16_sign ? 0x8000 : 0x0; + uint16_t mant_discard = static_cast(f_mant & 0xFFFF); + switch (roundingMode) { + case SYCLRoundingMode::rtn: + if (bf16_sign && mant_discard) + bf16_mant++; + break; + case SYCLRoundingMode::rtz: + break; + case SYCLRoundingMode::rtp: + if (!bf16_sign && mant_discard) + bf16_mant++; + break; + + // Should not reach here. Adding these just to suppress the warning. + case SYCLRoundingMode::automatic: + case SYCLRoundingMode::rte: + break; + } + + // if overflow happens, bf16_exp will be 0xFF and bf16_mant will be 0, + // infinity will be returned. + if (bf16_mant == 0x80) { + bf16_mant = 0; + bf16_exp++; + } - fp64_exp -= 1023; + return bit_cast((bf16_sign << 15) | (bf16_exp << 7) | + bf16_mant); +} - // handling overflow, convert to +/-infinity - if (static_cast(fp64_exp) > 127) +// Helper function to get BF16 from double with RTE rounding modes. 
+// Reference: +// https://github.com/intel/llvm/blob/sycl/libdevice/imf_bf16.hpp#L79 +inline bfloat16 getBFloat16FromDoubleWithRTE(const double &d) { + + uint64_t u64_val = sycl::bit_cast(d); + int16_t bf16_sign = (u64_val >> 63) & 0x1; + uint16_t fp64_exp = static_cast((u64_val >> 52) & 0x7FF); + uint64_t fp64_mant = (u64_val & 0xF'FFFF'FFFF'FFFF); + uint16_t bf16_mant; + // handling +/-infinity and NAN for double input + if (fp64_exp == 0x7FF) { + if (!fp64_mant) return bf16_sign ? 0xFF80 : 0x7F80; - // handling underflow - if (static_cast(fp64_exp) < -133) - return bf16_sign ? 0x8000 : 0x0; - - //-133 <= fp64_exp <= 127, 1.signicand * 2^fp64_exp - // For these numbers, they are NOT subnormal double-precision numbers but - // will turn into subnormal when converting to bfloat16 - uint64_t discard_bits; - if (static_cast(fp64_exp) < -126) { - fp64_mant |= 0x10'0000'0000'0000; - fp64_mant >>= -126 - static_cast(fp64_exp) - 1; - discard_bits = fp64_mant & 0x3FFF'FFFF'FFFF; - bf16_mant = static_cast(fp64_mant >> 46); - if (discard_bits > 0x2000'0000'0000 || - ((discard_bits == 0x2000'0000'0000) && ((bf16_mant & 0x1) == 0x1))) - bf16_mant += 1; - fp64_exp = 0; - if (bf16_mant == 0x80) { - bf16_mant = 0; - fp64_exp = 1; - } - return (bf16_sign << 15) | (fp64_exp << 7) | bf16_mant; - } + // returns a quiet NaN + return 0x7FC0; + } - // For normal value, discard 45 bits from mantissa - discard_bits = fp64_mant & 0x1FFF'FFFF'FFFF; - bf16_mant = static_cast(fp64_mant >> 45); - if (discard_bits > 0x1000'0000'0000 || - ((discard_bits == 0x1000'0000'0000) && ((bf16_mant & 0x1) == 0x1))) + // Subnormal double precision is converted to 0 + if (fp64_exp == 0) + return bf16_sign ? 0x8000 : 0x0; + + fp64_exp -= 1023; + + // handling overflow, convert to +/-infinity + if (static_cast(fp64_exp) > 127) + return bf16_sign ? 0xFF80 : 0x7F80; + + // handling underflow + if (static_cast(fp64_exp) < -133) + return bf16_sign ? 
0x8000 : 0x0; + + //-133 <= fp64_exp <= 127, 1.signicand * 2^fp64_exp + // For these numbers, they are NOT subnormal double-precision numbers but + // will turn into subnormal when converting to bfloat16 + uint64_t discard_bits; + if (static_cast(fp64_exp) < -126) { + fp64_mant |= 0x10'0000'0000'0000; + fp64_mant >>= -126 - static_cast(fp64_exp) - 1; + discard_bits = fp64_mant & 0x3FFF'FFFF'FFFF; + bf16_mant = static_cast(fp64_mant >> 46); + if (discard_bits > 0x2000'0000'0000 || + ((discard_bits == 0x2000'0000'0000) && ((bf16_mant & 0x1) == 0x1))) bf16_mant += 1; - + fp64_exp = 0; if (bf16_mant == 0x80) { - if (fp64_exp == 127) - return bf16_sign ? 0xFF80 : 0x7F80; bf16_mant = 0; - fp64_exp++; + fp64_exp = 1; } - fp64_exp += 127; - return (bf16_sign << 15) | (fp64_exp << 7) | bf16_mant; } - // Function to get the most significant bit position of a number. - template size_t get_msb_pos(const Ty &x) { - assert(x != 0); - size_t idx = 0; - Ty mask = ((Ty)1 << (sizeof(Ty) * 8 - 1)); - for (idx = 0; idx < (sizeof(Ty) * 8); ++idx) { - if ((x & mask) == mask) - break; - mask >>= 1; - } + // For normal value, discard 45 bits from mantissa + discard_bits = fp64_mant & 0x1FFF'FFFF'FFFF; + bf16_mant = static_cast(fp64_mant >> 45); + if (discard_bits > 0x1000'0000'0000 || + ((discard_bits == 0x1000'0000'0000) && ((bf16_mant & 0x1) == 0x1))) + bf16_mant += 1; + + if (bf16_mant == 0x80) { + if (fp64_exp == 127) + return bf16_sign ? 0xFF80 : 0x7F80; + bf16_mant = 0; + fp64_exp++; + } + fp64_exp += 127; + + return (bf16_sign << 15) | (fp64_exp << 7) | bf16_mant; +} - return (sizeof(Ty) * 8 - 1 - idx); +// Function to get the most significant bit position of a number. 
+template size_t get_msb_pos(const Ty &x) { + assert(x != 0); + size_t idx = 0; + Ty mask = ((Ty)1 << (sizeof(Ty) * 8 - 1)); + for (idx = 0; idx < (sizeof(Ty) * 8); ++idx) { + if ((x & mask) == mask) + break; + mask >>= 1; } + return (sizeof(Ty) * 8 - 1 - idx); +} + // Helper function to get BF16 from unsigned integral data types // with different rounding modes. // Reference: // https://github.com/intel/llvm/blob/sycl/libdevice/imf_bf16.hpp#L302 - template - bfloat16 - getBFloat16FromUIntegralWithRoundingMode(T &u, - SYCLRoundingMode roundingMode) { - - size_t msb_pos = get_msb_pos(u); - // return half representation for 1 - if (msb_pos == 0) - return bit_cast(0x3F80); - - T mant = u & ((static_cast(1) << msb_pos) - 1); - // Unsigned integral value can be represented by 1.mant * (2^msb_pos), - // msb_pos is also the bit number of mantissa, 0 < msb_pos < sizeof(Ty) * 8, - // exponent of bfloat16 precision value range is [-126, 127]. - - uint16_t b_exp = msb_pos; - uint16_t b_mant; - - if (msb_pos <= 7) { - // No need to round off if we can losslessly fit the input value in - // mantissa of bfloat16. 
- mant <<= (7 - msb_pos); - b_mant = static_cast(mant); - } else { - b_mant = static_cast(mant >> (msb_pos - 7)); - T mant_discard = mant & ((static_cast(1) << (msb_pos - 7)) - 1); - T mid = static_cast(1) << (msb_pos - 8); - switch (roundingMode) { - case SYCLRoundingMode::automatic: - case SYCLRoundingMode::rte: - if ((mant_discard > mid) || - ((mant_discard == mid) && ((b_mant & 0x1) == 0x1))) - b_mant++; - break; - case SYCLRoundingMode::rtp: - if (mant_discard) - b_mant++; - break; - case SYCLRoundingMode::rtn: - case SYCLRoundingMode::rtz: - break; - } - } - if (b_mant == 0x80) { - b_exp++; - b_mant = 0; +template +bfloat16 +getBFloat16FromUIntegralWithRoundingMode(T &u, SYCLRoundingMode roundingMode) { + + size_t msb_pos = get_msb_pos(u); + // return half representation for 1 + if (msb_pos == 0) + return bit_cast(0x3F80); + + T mant = u & ((static_cast(1) << msb_pos) - 1); + // Unsigned integral value can be represented by 1.mant * (2^msb_pos), + // msb_pos is also the bit number of mantissa, 0 < msb_pos < sizeof(Ty) * 8, + // exponent of bfloat16 precision value range is [-126, 127]. + + uint16_t b_exp = msb_pos; + uint16_t b_mant; + + if (msb_pos <= 7) { + // No need to round off if we can losslessly fit the input value in + // mantissa of bfloat16. 
+ mant <<= (7 - msb_pos); + b_mant = static_cast(mant); + } else { + b_mant = static_cast(mant >> (msb_pos - 7)); + T mant_discard = mant & ((static_cast(1) << (msb_pos - 7)) - 1); + T mid = static_cast(1) << (msb_pos - 8); + switch (roundingMode) { + case SYCLRoundingMode::automatic: + case SYCLRoundingMode::rte: + if ((mant_discard > mid) || + ((mant_discard == mid) && ((b_mant & 0x1) == 0x1))) + b_mant++; + break; + case SYCLRoundingMode::rtp: + if (mant_discard) + b_mant++; + break; + case SYCLRoundingMode::rtn: + case SYCLRoundingMode::rtz: + break; } - - b_exp += 127; - return bit_cast((b_exp << 7) | b_mant); + } + if (b_mant == 0x80) { + b_exp++; + b_mant = 0; } + b_exp += 127; + return bit_cast((b_exp << 7) | b_mant); +} + // Helper function to get BF16 from signed integral data types. // Reference: // https://github.com/intel/llvm/blob/sycl/libdevice/imf_bf16.hpp#L353 - template - bfloat16 - getBFloat16FromSIntegralWithRoundingMode(T &i, - SYCLRoundingMode roundingMode) { - // Get unsigned type corresponding to T. - typedef typename std::make_unsigned_t UTy; - - uint16_t b_sign = (i >= 0) ? 0 : 0x8000; - UTy ui = (i > 0) ? static_cast(i) : static_cast(-i); - size_t msb_pos = get_msb_pos(ui); - if (msb_pos == 0) - return bit_cast(b_sign ? 
0xBF80 : 0x3F80); - UTy mant = ui & ((static_cast(1) << msb_pos) - 1); - - uint16_t b_exp = msb_pos; - uint16_t b_mant; - if (msb_pos <= 7) { - mant <<= (7 - msb_pos); - b_mant = static_cast(mant); - } else { - b_mant = static_cast(mant >> (msb_pos - 7)); - T mant_discard = mant & ((static_cast(1) << (msb_pos - 7)) - 1); - T mid = static_cast(1) << (msb_pos - 8); - switch (roundingMode) { - case SYCLRoundingMode::automatic: - case SYCLRoundingMode::rte: - if ((mant_discard > mid) || - ((mant_discard == mid) && ((b_mant & 0x1) == 0x1))) - b_mant++; - break; - case SYCLRoundingMode::rtp: - if (mant_discard && !b_sign) - b_mant++; - break; - case SYCLRoundingMode::rtn: - if (mant_discard && b_sign) - b_mant++; - case SYCLRoundingMode::rtz: - break; - } +template +bfloat16 +getBFloat16FromSIntegralWithRoundingMode(T &i, SYCLRoundingMode roundingMode) { + // Get unsigned type corresponding to T. + typedef typename std::make_unsigned_t UTy; + + uint16_t b_sign = (i >= 0) ? 0 : 0x8000; + UTy ui = (i > 0) ? static_cast(i) : static_cast(-i); + size_t msb_pos = get_msb_pos(ui); + if (msb_pos == 0) + return bit_cast(b_sign ? 
0xBF80 : 0x3F80); + UTy mant = ui & ((static_cast(1) << msb_pos) - 1); + + uint16_t b_exp = msb_pos; + uint16_t b_mant; + if (msb_pos <= 7) { + mant <<= (7 - msb_pos); + b_mant = static_cast(mant); + } else { + b_mant = static_cast(mant >> (msb_pos - 7)); + T mant_discard = mant & ((static_cast(1) << (msb_pos - 7)) - 1); + T mid = static_cast(1) << (msb_pos - 8); + switch (roundingMode) { + case SYCLRoundingMode::automatic: + case SYCLRoundingMode::rte: + if ((mant_discard > mid) || + ((mant_discard == mid) && ((b_mant & 0x1) == 0x1))) + b_mant++; + break; + case SYCLRoundingMode::rtp: + if (mant_discard && !b_sign) + b_mant++; + break; + case SYCLRoundingMode::rtn: + if (mant_discard && b_sign) + b_mant++; + case SYCLRoundingMode::rtz: + break; } + } - if (b_mant == 0x80) { - b_exp++; - b_mant = 0; - } - b_exp += 127; - return bit_cast(b_sign | (b_exp << 7) | b_mant); + if (b_mant == 0x80) { + b_exp++; + b_mant = 0; } + b_exp += 127; + return bit_cast(b_sign | (b_exp << 7) | b_mant); +} - template - bfloat16 getBfloat16WithRoundingMode(const Ty &a) { +template +bfloat16 getBfloat16WithRoundingMode(const Ty &a) { - if (!a) - return bfloat16{0.0f}; + if (!a) + return bfloat16{0.0f}; - constexpr SYCLRoundingMode roundingMode = static_cast(rm); + constexpr SYCLRoundingMode roundingMode = static_cast(rm); - // Float. - if constexpr (std::is_same_v) { - return getBFloat16FromFloatWithRoundingMode(a, roundingMode); - } - // Double. - else if constexpr (std::is_same_v) { - static_assert( - roundingMode == SYCLRoundingMode::automatic || - roundingMode == SYCLRoundingMode::rte, - "Only automatic/RTE rounding mode is supported for double type."); - return getBFloat16FromDoubleWithRTE(a); - } - // Half - else if constexpr (std::is_same_v) { - // Convert half to float and then convert to bfloat16. - // Conversion of half to float is lossless as the latter - // have a wider dynamic range. 
- return getBFloat16FromFloatWithRoundingMode(static_cast(a), - roundingMode); - } - // Unsigned integral types. - else if constexpr (std::is_integral_v && std::is_unsigned_v) { - return getBFloat16FromUIntegralWithRoundingMode(a, roundingMode); - } - // Signed integral types. - else if constexpr (std::is_integral_v && std::is_signed_v) { - return getBFloat16FromSIntegralWithRoundingMode(a, roundingMode); - } else { - static_assert(std::is_integral_v || std::is_floating_point_v, - "Only integral and floating point types are supported."); - } + // Float. + if constexpr (std::is_same_v) { + return getBFloat16FromFloatWithRoundingMode(a, roundingMode); + } + // Double. + else if constexpr (std::is_same_v) { + static_assert( + roundingMode == SYCLRoundingMode::automatic || + roundingMode == SYCLRoundingMode::rte, + "Only automatic/RTE rounding mode is supported for double type."); + return getBFloat16FromDoubleWithRTE(a); } + // Half + else if constexpr (std::is_same_v) { + // Convert half to float and then convert to bfloat16. + // Conversion of half to float is lossless as the latter + // have a wider dynamic range. + return getBFloat16FromFloatWithRoundingMode(static_cast(a), + roundingMode); + } + // Unsigned integral types. + else if constexpr (std::is_integral_v && std::is_unsigned_v) { + return getBFloat16FromUIntegralWithRoundingMode(a, roundingMode); + } + // Signed integral types. 
+ else if constexpr (std::is_integral_v && std::is_signed_v) { + return getBFloat16FromSIntegralWithRoundingMode(a, roundingMode); + } else { + static_assert(std::is_integral_v || std::is_floating_point_v, + "Only integral and floating point types are supported."); + } +} } // namespace ConvertToBfloat16 } // namespace detail } // namespace ext::oneapi - } // namespace _V1 } // namespace sycl From 4e2a9cf7fe0c521f90043dc5200bd2567b4f9428 Mon Sep 17 00:00:00 2001 From: Alexey Bader Date: Fri, 28 Feb 2025 20:53:10 -0800 Subject: [PATCH 3/7] Add inline to function templates to avoid multiple definitions. --- sycl/include/sycl/ext/oneapi/bfloat16.hpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/bfloat16.hpp b/sycl/include/sycl/ext/oneapi/bfloat16.hpp index 56baaf7b5e7d4..188d2e4d48eb0 100644 --- a/sycl/include/sycl/ext/oneapi/bfloat16.hpp +++ b/sycl/include/sycl/ext/oneapi/bfloat16.hpp @@ -232,7 +232,8 @@ extern "C" __DPCPP_SYCL_EXTERNAL void __devicelib_ConvertBF16ToFINTELVec16(const uint16_t *, float *) noexcept; #endif -template void BF16VecToFloatVec(const bfloat16 src[N], float dst[N]) { +template +inline void BF16VecToFloatVec(const bfloat16 src[N], float dst[N]) { static_assert(N == 1 || N == 2 || N == 3 || N == 4 || N == 8 || N == 16, "Unsupported vector size"); #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) @@ -272,7 +273,7 @@ extern "C" __DPCPP_SYCL_EXTERNAL void __devicelib_ConvertFToBF16INTELVec16(const float *, uint16_t *) noexcept; #endif -template void FloatVecToBF16Vec(float src[N], bfloat16 dst[N]) { +template inline void FloatVecToBF16Vec(float src[N], bfloat16 dst[N]) { static_assert(N == 1 || N == 2 || N == 3 || N == 4 || N == 8 || N == 16, "Unsupported vector size"); #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) @@ -436,7 +437,7 @@ inline bfloat16 getBFloat16FromDoubleWithRTE(const double &d) { } // Function to get the 
most significant bit position of a number. -template size_t get_msb_pos(const Ty &x) { +template inline size_t get_msb_pos(const Ty &x) { assert(x != 0); size_t idx = 0; Ty mask = ((Ty)1 << (sizeof(Ty) * 8 - 1)); @@ -454,7 +455,7 @@ template size_t get_msb_pos(const Ty &x) { // Reference: // https://github.com/intel/llvm/blob/sycl/libdevice/imf_bf16.hpp#L302 template -bfloat16 +inline bfloat16 getBFloat16FromUIntegralWithRoundingMode(T &u, SYCLRoundingMode roundingMode) { size_t msb_pos = get_msb_pos(u); @@ -508,7 +509,7 @@ getBFloat16FromUIntegralWithRoundingMode(T &u, SYCLRoundingMode roundingMode) { // Reference: // https://github.com/intel/llvm/blob/sycl/libdevice/imf_bf16.hpp#L353 template -bfloat16 +inline bfloat16 getBFloat16FromSIntegralWithRoundingMode(T &i, SYCLRoundingMode roundingMode) { // Get unsigned type corresponding to T. typedef typename std::make_unsigned_t UTy; @@ -557,9 +558,8 @@ getBFloat16FromSIntegralWithRoundingMode(T &i, SYCLRoundingMode roundingMode) { } template -bfloat16 getBfloat16WithRoundingMode(const Ty &a) { - - if (!a) +inline bfloat16 getBfloat16WithRoundingMode(const Ty &a) { + if (a == 0) return bfloat16{0.0f}; constexpr SYCLRoundingMode roundingMode = static_cast(rm); From b5ad84c6a31b78bc15c7f550abfae9877e335973 Mon Sep 17 00:00:00 2001 From: Alexey Bader Date: Fri, 28 Feb 2025 20:53:48 -0800 Subject: [PATCH 4/7] Improve static assert message. 
--- sycl/include/sycl/ext/oneapi/bfloat16.hpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/bfloat16.hpp b/sycl/include/sycl/ext/oneapi/bfloat16.hpp index 188d2e4d48eb0..bb3d15a63b191 100644 --- a/sycl/include/sycl/ext/oneapi/bfloat16.hpp +++ b/sycl/include/sycl/ext/oneapi/bfloat16.hpp @@ -235,7 +235,7 @@ __devicelib_ConvertBF16ToFINTELVec16(const uint16_t *, float *) noexcept; template inline void BF16VecToFloatVec(const bfloat16 src[N], float dst[N]) { static_assert(N == 1 || N == 2 || N == 3 || N == 4 || N == 8 || N == 16, - "Unsupported vector size"); + "Unsupported vector size for bfloat16 conversion"); #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) const uint16_t *src_i16 = sycl::bit_cast(src); if constexpr (N == 1) @@ -593,7 +593,8 @@ inline bfloat16 getBfloat16WithRoundingMode(const Ty &a) { return getBFloat16FromSIntegralWithRoundingMode(a, roundingMode); } else { static_assert(std::is_integral_v || std::is_floating_point_v, - "Only integral and floating point types are supported."); + "Only integral and floating-point types are supported for " + "conversion to bfloat16."); } } } // namespace ConvertToBfloat16 From c0fc5255ac16dedeed47f63e64db84e20fac3837 Mon Sep 17 00:00:00 2001 From: Alexey Bader Date: Fri, 28 Feb 2025 20:58:14 -0800 Subject: [PATCH 5/7] Add more comments. --- sycl/include/sycl/ext/oneapi/bfloat16.hpp | 30 ++++++++++++++++------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/bfloat16.hpp b/sycl/include/sycl/ext/oneapi/bfloat16.hpp index bb3d15a63b191..42386500f7f4b 100644 --- a/sycl/include/sycl/ext/oneapi/bfloat16.hpp +++ b/sycl/include/sycl/ext/oneapi/bfloat16.hpp @@ -232,6 +232,10 @@ extern "C" __DPCPP_SYCL_EXTERNAL void __devicelib_ConvertBF16ToFINTELVec16(const uint16_t *, float *) noexcept; #endif +/// \brief Converts a vector of bfloat16 to a vector of floats. +/// \tparam N The size of the vector. 
Supported sizes are 1, 2, 3, 4, 8, and 16. +/// \param src The source vector of bfloat16. +/// \param dst The destination vector of floats. template inline void BF16VecToFloatVec(const bfloat16 src[N], float dst[N]) { static_assert(N == 1 || N == 2 || N == 3 || N == 4 || N == 8 || N == 16, @@ -273,6 +277,10 @@ extern "C" __DPCPP_SYCL_EXTERNAL void __devicelib_ConvertFToBF16INTELVec16(const float *, uint16_t *) noexcept; #endif +/// \brief Converts a vector of floats to a vector of bfloat16. +/// \tparam N The size of the vector. +/// \param src The source vector of floats. +/// \param dst The destination vector of bfloat16. template inline void FloatVecToBF16Vec(float src[N], bfloat16 dst[N]) { static_assert(N == 1 || N == 2 || N == 3 || N == 4 || N == 8 || N == 16, "Unsupported vector size"); @@ -292,8 +300,8 @@ template inline void FloatVecToBF16Vec(float src[N], bfloat16 dst[N]) { __devicelib_ConvertFToBF16INTELVec16(src, dst_i16); #else for (int i = 0; i < N; ++i) { - // No need to cast as bfloat16 has a assignment op overload that takes - // a float. + // No need to cast as bfloat16 has an assignment operator overload that + // takes a float. dst[i] = src[i]; } #endif @@ -450,10 +458,10 @@ template inline size_t get_msb_pos(const Ty &x) { return (sizeof(Ty) * 8 - 1 - idx); } - // Helper function to get BF16 from unsigned integral data types - // with different rounding modes. - // Reference: - // https://github.com/intel/llvm/blob/sycl/libdevice/imf_bf16.hpp#L302 +// Helper function to get BF16 from unsigned integral data types +// with different rounding modes. +// Reference: +// https://github.com/intel/llvm/blob/sycl/libdevice/imf_bf16.hpp#L302 template inline bfloat16 getBFloat16FromUIntegralWithRoundingMode(T &u, SYCLRoundingMode roundingMode) { @@ -505,9 +513,9 @@ getBFloat16FromUIntegralWithRoundingMode(T &u, SYCLRoundingMode roundingMode) { return bit_cast((b_exp << 7) | b_mant); } - // Helper function to get BF16 from signed integral data types. 
- // Reference: - // https://github.com/intel/llvm/blob/sycl/libdevice/imf_bf16.hpp#L353 +// Helper function to get BF16 from signed integral data types. +// Reference: +// https://github.com/intel/llvm/blob/sycl/libdevice/imf_bf16.hpp#L353 template inline bfloat16 getBFloat16FromSIntegralWithRoundingMode(T &i, SYCLRoundingMode roundingMode) { @@ -557,6 +565,10 @@ getBFloat16FromSIntegralWithRoundingMode(T &i, SYCLRoundingMode roundingMode) { return bit_cast(b_sign | (b_exp << 7) | b_mant); } +/// \brief Converts a given value to bfloat16 with a specified rounding mode. +/// \tparam rm The rounding mode to be used for conversion. +/// \param a The input value to be converted. +/// \return The converted bfloat16 value. template inline bfloat16 getBfloat16WithRoundingMode(const Ty &a) { if (a == 0) From 3adc510d5e079fdefe3cb8cdda3a43a220965be2 Mon Sep 17 00:00:00 2001 From: Alexey Bader Date: Sat, 1 Mar 2025 10:18:35 -0800 Subject: [PATCH 6/7] Revert. --- sycl/include/sycl/ext/oneapi/bfloat16.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sycl/include/sycl/ext/oneapi/bfloat16.hpp b/sycl/include/sycl/ext/oneapi/bfloat16.hpp index 42386500f7f4b..0fd80320b1fe2 100644 --- a/sycl/include/sycl/ext/oneapi/bfloat16.hpp +++ b/sycl/include/sycl/ext/oneapi/bfloat16.hpp @@ -19,7 +19,7 @@ namespace sycl { inline namespace _V1 { namespace ext::oneapi { -class bfloat16 final { +class bfloat16 { public: using Bfloat16StorageT = uint16_t; From 28dee03a08b86227fc92252972f426cd06b86944 Mon Sep 17 00:00:00 2001 From: Alexey Bader Date: Mon, 3 Mar 2025 10:28:05 -0800 Subject: [PATCH 7/7] Move code comments. 
--- sycl/include/sycl/ext/oneapi/bfloat16.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/bfloat16.hpp b/sycl/include/sycl/ext/oneapi/bfloat16.hpp index 0fd80320b1fe2..69d8cf3a7f366 100644 --- a/sycl/include/sycl/ext/oneapi/bfloat16.hpp +++ b/sycl/include/sycl/ext/oneapi/bfloat16.hpp @@ -154,6 +154,8 @@ class bfloat16 { friend class sycl::vec; }; +// Conversion functions for bfloat16 + // Helper functions for conversions between bfloat16 and float scalar types. #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) extern "C" __DPCPP_SYCL_EXTERNAL float @@ -214,8 +216,6 @@ inline bfloat16::Bfloat16StorageT bfloat16::from_float(const float &a) { } namespace detail { -// Conversion functions for bfloat16 - // Helper functions for vector conversions from bfloat16 to float #if defined(__SYCL_DEVICE_ONLY__) && (defined(__SPIR__) || defined(__SPIRV__)) extern "C" __DPCPP_SYCL_EXTERNAL void