diff --git a/CMakeLists.txt b/CMakeLists.txt index 9af53c49..c7fc9de6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -79,6 +79,8 @@ elseif(UNIX) target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE CurlClient) endif() +target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE vendor/onnxruntime/include) + target_sources( ${CMAKE_PROJECT_NAME} PRIVATE src/plugin-main.c diff --git a/cmake/FetchOnnxruntime.cmake b/cmake/FetchOnnxruntime.cmake index bf9d25cc..f9231f61 100644 --- a/cmake/FetchOnnxruntime.cmake +++ b/cmake/FetchOnnxruntime.cmake @@ -18,14 +18,14 @@ else() endif() endif() -set(Onnxruntime_VERSION "1.15.1") +set(Onnxruntime_VERSION "1.16.0") if(OS_MACOS) if(USE_PREDEFINED_ONNXRUNTIME) FetchContent_Declare( Onnxruntime URL "https://github.com/microsoft/onnxruntime/releases/download/v${Onnxruntime_VERSION}/onnxruntime-osx-universal2-${Onnxruntime_VERSION}.tgz" - URL_HASH MD5=6637c31d2dae5ad8bffb4da38ae4d89e) + URL_HASH MD5=6f6adb0ad879999e089ea017ab457382) else() FetchContent_Declare( Onnxruntime @@ -49,8 +49,8 @@ elseif(OS_WINDOWS) if(USE_PREDEFINED_ONNXRUNTIME) FetchContent_Declare( Onnxruntime - URL "https://github.com/umireon/onnxruntime-static-win/releases/download/v${Onnxruntime_VERSION}-1/onnxruntime-static-win.zip" - URL_HASH MD5=1be38fe8abe304085b9eb58272fbc941) + URL "https://github.com/umireon/onnxruntime-static-win/releases/download/v${Onnxruntime_VERSION}-1/onnxruntime-windows-Release.zip" + URL_HASH MD5=bcd67774a7dbdc3e6dedcbeb22b8d516) else() FetchContent_Declare( Onnxruntime @@ -94,7 +94,7 @@ elseif(OS_LINUX) FetchContent_Declare( Onnxruntime URL "https://github.com/microsoft/onnxruntime/releases/download/v${Onnxruntime_VERSION}/onnxruntime-linux-aarch64-${Onnxruntime_VERSION}.tgz" - URL_HASH MD5=ebec0b185c9bec94fde884a97b144c04) + URL_HASH MD5=c2112b66c99f9b85d4d71e5c61712034) else() FetchContent_Declare( Onnxruntime @@ -109,7 +109,7 @@ elseif(OS_LINUX) FetchContent_Declare( Onnxruntime URL "https://github.com/microsoft/onnxruntime/releases/download/v${Onnxruntime_VERSION}/onnxruntime-linux-x64-gpu-${Onnxruntime_VERSION}.tgz" - URL_HASH MD5=8d2f5ee9f449bdecb10a45715fe74c53) + URL_HASH MD5=21b3eb42e0bee9e7b51cbf508352f188) else() FetchContent_Declare( Onnxruntime diff --git a/vendor/onnxruntime/include/onnxruntime_float16.h b/vendor/onnxruntime/include/onnxruntime_float16.h new file mode 100644 index 00000000..0b066a9c --- /dev/null +++ b/vendor/onnxruntime/include/onnxruntime_float16.h @@ -0,0 +1,540 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include + +namespace onnxruntime_float16 { + +namespace detail { + +enum class endian { +#if defined(_WIN32) + little = 0, + big = 1, + native = little, +#elif defined(__GNUC__) || defined(__clang__) + little = __ORDER_LITTLE_ENDIAN__, + big = __ORDER_BIG_ENDIAN__, + native = __BYTE_ORDER__, +#else +#error onnxruntime_float16::detail::endian is not implemented in this environment. +#endif +}; + +static_assert( + endian::native == endian::little || endian::native == endian::big, + "Only little-endian or big-endian native byte orders are supported."); + +} // namespace detail + +/// +/// Shared implementation between public and internal classes. CRTP pattern. +/// +template +struct Float16Impl { + protected: + /// + /// Converts from float to uint16_t float16 representation + /// + /// + /// + constexpr static uint16_t ToUint16Impl(float v) noexcept; + + /// + /// Converts float16 to float + /// + /// float representation of float16 value + float ToFloatImpl() const noexcept; + + /// + /// Creates an instance that represents absolute value. + /// + /// Absolute value + uint16_t AbsImpl() const noexcept { + return static_cast(val & ~kSignMask); + } + + /// + /// Creates a new instance with the sign flipped. + /// + /// Flipped sign instance + uint16_t NegateImpl() const noexcept { + return IsNaN() ? val : static_cast(val ^ kSignMask); + } + + public: + // uint16_t special values + static constexpr uint16_t kSignMask = 0x8000U; + static constexpr uint16_t kBiasedExponentMask = 0x7C00U; + static constexpr uint16_t kPositiveInfinityBits = 0x7C00U; + static constexpr uint16_t kNegativeInfinityBits = 0xFC00U; + static constexpr uint16_t kPositiveQNaNBits = 0x7E00U; + static constexpr uint16_t kNegativeQNaNBits = 0xFE00U; + static constexpr uint16_t kEpsilonBits = 0x4170U; + static constexpr uint16_t kMinValueBits = 0xFBFFU; // Minimum normal number + static constexpr uint16_t kMaxValueBits = 0x7BFFU; // Largest normal number + static constexpr uint16_t kOneBits = 0x3C00U; + static constexpr uint16_t kMinusOneBits = 0xBC00U; + + uint16_t val{0}; + + Float16Impl() = default; + + /// + /// Checks if the value is negative + /// + /// true if negative + bool IsNegative() const noexcept { + return static_cast(val) < 0; + } + + /// + /// Tests if the value is NaN + /// + /// true if NaN + bool IsNaN() const noexcept { + return AbsImpl() > kPositiveInfinityBits; + } + + /// + /// Tests if the value is finite + /// + /// true if finite + bool IsFinite() const noexcept { + return AbsImpl() < kPositiveInfinityBits; + } + + /// + /// Tests if the value represents positive infinity. + /// + /// true if positive infinity + bool IsPositiveInfinity() const noexcept { + return val == kPositiveInfinityBits; + } + + /// + /// Tests if the value represents negative infinity + /// + /// true if negative infinity + bool IsNegativeInfinity() const noexcept { + return val == kNegativeInfinityBits; + } + + /// + /// Tests if the value is either positive or negative infinity. + /// + /// True if absolute value is infinity + bool IsInfinity() const noexcept { + return AbsImpl() == kPositiveInfinityBits; + } + + /// + /// Tests if the value is NaN or zero. Useful for comparisons. + /// + /// True if NaN or zero. + bool IsNaNOrZero() const noexcept { + auto abs = AbsImpl(); + return (abs == 0 || abs > kPositiveInfinityBits); + } + + /// + /// Tests if the value is normal (not zero, subnormal, infinite, or NaN). + /// + /// True if so + bool IsNormal() const noexcept { + auto abs = AbsImpl(); + return (abs < kPositiveInfinityBits) // is finite + && (abs != 0) // is not zero + && ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent) + } + + /// + /// Tests if the value is subnormal (denormal). + /// + /// True if so + bool IsSubnormal() const noexcept { + auto abs = AbsImpl(); + return (abs < kPositiveInfinityBits) // is finite + && (abs != 0) // is not zero + && ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent) + } + + /// + /// Creates an instance that represents absolute value. + /// + /// Absolute value + Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); } + + /// + /// Creates a new instance with the sign flipped. + /// + /// Flipped sign instance + Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); } + + /// + /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check + /// for two values by or'ing the private bits together and stripping the sign. They are both zero, + /// and therefore equivalent, if the resulting value is still zero. + /// + /// first value + /// second value + /// True if both arguments represent zero + static bool AreZero(const Float16Impl& lhs, const Float16Impl& rhs) noexcept { + return static_cast((lhs.val | rhs.val) & ~kSignMask) == 0; + } + + bool operator==(const Float16Impl& rhs) const noexcept { + if (IsNaN() || rhs.IsNaN()) { + // IEEE defines that NaN is not equal to anything, including itself. + return false; + } + return val == rhs.val; + } + + bool operator!=(const Float16Impl& rhs) const noexcept { return !(*this == rhs); } + + bool operator<(const Float16Impl& rhs) const noexcept { + if (IsNaN() || rhs.IsNaN()) { + // IEEE defines that NaN is unordered with respect to everything, including itself. + return false; + } + + const bool left_is_negative = IsNegative(); + if (left_is_negative != rhs.IsNegative()) { + // When the signs of left and right differ, we know that left is less than right if it is + // the negative value. The exception to this is if both values are zero, in which case IEEE + // says they should be equal, even if the signs differ. + return left_is_negative && !AreZero(*this, rhs); + } + return (val != rhs.val) && ((val < rhs.val) ^ left_is_negative); + } +}; + +// The following Float16_t conversions are based on the code from +// Eigen library. + +// The conversion routines are Copyright (c) Fabian Giesen, 2016. +// The original license follows: +// +// Copyright (c) Fabian Giesen, 2016 +// All rights reserved. +// Redistribution and use in source and binary forms, with or without +// modification, are permitted. +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +namespace detail { +union float32_bits { + unsigned int u; + float f; +}; +} // namespace detail + +template +inline constexpr uint16_t Float16Impl::ToUint16Impl(float v) noexcept { + detail::float32_bits f{}; + f.f = v; + + constexpr detail::float32_bits f32infty = {255 << 23}; + constexpr detail::float32_bits f16max = {(127 + 16) << 23}; + constexpr detail::float32_bits denorm_magic = {((127 - 15) + (23 - 10) + 1) << 23}; + constexpr unsigned int sign_mask = 0x80000000u; + uint16_t val = static_cast(0x0u); + + unsigned int sign = f.u & sign_mask; + f.u ^= sign; + + // NOTE all the integer compares in this function can be safely + // compiled into signed compares since all operands are below + // 0x80000000. Important if you want fast straight SSE2 code + // (since there's no unsigned PCMPGTD). + + if (f.u >= f16max.u) { // result is Inf or NaN (all exponent bits set) + val = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf + } else { // (De)normalized number or zero + if (f.u < (113 << 23)) { // resulting FP16 is subnormal or zero + // use a magic value to align our 10 mantissa bits at the bottom of + // the float. as long as FP addition is round-to-nearest-even this + // just works. + f.f += denorm_magic.f; + + // and one integer subtract of the bias later, we have our final float! + val = static_cast(f.u - denorm_magic.u); + } else { + unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd + + // update exponent, rounding bias part 1 + // Equivalent to `f.u += ((unsigned int)(15 - 127) << 23) + 0xfff`, but + // without arithmetic overflow. + f.u += 0xc8000fffU; + // rounding bias part 2 + f.u += mant_odd; + // take the bits! + val = static_cast(f.u >> 13); + } + } + + val |= static_cast(sign >> 16); + return val; +} + +template +inline float Float16Impl::ToFloatImpl() const noexcept { + constexpr detail::float32_bits magic = {113 << 23}; + constexpr unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift + detail::float32_bits o{}; + + o.u = (val & 0x7fff) << 13; // exponent/mantissa bits + unsigned int exp = shifted_exp & o.u; // just the exponent + o.u += (127 - 15) << 23; // exponent adjust + + // handle exponent special cases + if (exp == shifted_exp) { // Inf/NaN? + o.u += (128 - 16) << 23; // extra exp adjust + } else if (exp == 0) { // Zero/Denormal? + o.u += 1 << 23; // extra exp adjust + o.f -= magic.f; // re-normalize + } + + // Attempt to workaround the Internal Compiler Error on ARM64 + // for bitwise | operator, including std::bitset +#if (defined _MSC_VER) && (defined _M_ARM || defined _M_ARM64 || defined _M_ARM64EC) + if (IsNegative()) { + return -o.f; + } +#else + // original code: + o.u |= (val & 0x8000U) << 16U; // sign bit +#endif + return o.f; +} + +/// Shared implementation between public and internal classes. CRTP pattern. +template +struct BFloat16Impl { + protected: + /// + /// Converts from float to uint16_t float16 representation + /// + /// + /// + static uint16_t ToUint16Impl(float v) noexcept; + + /// + /// Converts bfloat16 to float + /// + /// float representation of bfloat16 value + float ToFloatImpl() const noexcept; + + /// + /// Creates an instance that represents absolute value. + /// + /// Absolute value + uint16_t AbsImpl() const noexcept { + return static_cast(val & ~kSignMask); + } + + /// + /// Creates a new instance with the sign flipped. + /// + /// Flipped sign instance + uint16_t NegateImpl() const noexcept { + return IsNaN() ? val : static_cast(val ^ kSignMask); + } + + public: + // uint16_t special values + static constexpr uint16_t kSignMask = 0x8000U; + static constexpr uint16_t kBiasedExponentMask = 0x7F80U; + static constexpr uint16_t kPositiveInfinityBits = 0x7F80U; + static constexpr uint16_t kNegativeInfinityBits = 0xFF80U; + static constexpr uint16_t kPositiveQNaNBits = 0x7FC1U; + static constexpr uint16_t kNegativeQNaNBits = 0xFFC1U; + static constexpr uint16_t kSignaling_NaNBits = 0x7F80U; + static constexpr uint16_t kEpsilonBits = 0x0080U; + static constexpr uint16_t kMinValueBits = 0xFF7FU; + static constexpr uint16_t kMaxValueBits = 0x7F7FU; + static constexpr uint16_t kRoundToNearest = 0x7FFFU; + static constexpr uint16_t kOneBits = 0x3F80U; + static constexpr uint16_t kMinusOneBits = 0xBF80U; + + uint16_t val{0}; + + BFloat16Impl() = default; + + /// + /// Checks if the value is negative + /// + /// true if negative + bool IsNegative() const noexcept { + return static_cast(val) < 0; + } + + /// + /// Tests if the value is NaN + /// + /// true if NaN + bool IsNaN() const noexcept { + return AbsImpl() > kPositiveInfinityBits; + } + + /// + /// Tests if the value is finite + /// + /// true if finite + bool IsFinite() const noexcept { + return AbsImpl() < kPositiveInfinityBits; + } + + /// + /// Tests if the value represents positive infinity. + /// + /// true if positive infinity + bool IsPositiveInfinity() const noexcept { + return val == kPositiveInfinityBits; + } + + /// + /// Tests if the value represents negative infinity + /// + /// true if negative infinity + bool IsNegativeInfinity() const noexcept { + return val == kNegativeInfinityBits; + } + + /// + /// Tests if the value is either positive or negative infinity. + /// + /// True if absolute value is infinity + bool IsInfinity() const noexcept { + return AbsImpl() == kPositiveInfinityBits; + } + + /// + /// Tests if the value is NaN or zero. Useful for comparisons. + /// + /// True if NaN or zero. + bool IsNaNOrZero() const noexcept { + auto abs = AbsImpl(); + return (abs == 0 || abs > kPositiveInfinityBits); + } + + /// + /// Tests if the value is normal (not zero, subnormal, infinite, or NaN). + /// + /// True if so + bool IsNormal() const noexcept { + auto abs = AbsImpl(); + return (abs < kPositiveInfinityBits) // is finite + && (abs != 0) // is not zero + && ((abs & kBiasedExponentMask) != 0); // is not subnormal (has a non-zero exponent) + } + + /// + /// Tests if the value is subnormal (denormal). + /// + /// True if so + bool IsSubnormal() const noexcept { + auto abs = AbsImpl(); + return (abs < kPositiveInfinityBits) // is finite + && (abs != 0) // is not zero + && ((abs & kBiasedExponentMask) == 0); // is subnormal (has a zero exponent) + } + + /// + /// Creates an instance that represents absolute value. + /// + /// Absolute value + Derived Abs() const noexcept { return Derived::FromBits(AbsImpl()); } + + /// + /// Creates a new instance with the sign flipped. + /// + /// Flipped sign instance + Derived Negate() const noexcept { return Derived::FromBits(NegateImpl()); } + + /// + /// IEEE defines that positive and negative zero are equal, this gives us a quick equality check + /// for two values by or'ing the private bits together and stripping the sign. They are both zero, + /// and therefore equivalent, if the resulting value is still zero. + /// + /// first value + /// second value + /// True if both arguments represent zero + static bool AreZero(const BFloat16Impl& lhs, const BFloat16Impl& rhs) noexcept { + // IEEE defines that positive and negative zero are equal, this gives us a quick equality check + // for two values by or'ing the private bits together and stripping the sign. They are both zero, + // and therefore equivalent, if the resulting value is still zero. + return static_cast((lhs.val | rhs.val) & ~kSignMask) == 0; + } +}; + +template +inline uint16_t BFloat16Impl::ToUint16Impl(float v) noexcept { + uint16_t result; + if (std::isnan(v)) { + result = kPositiveQNaNBits; + } else { + auto get_msb_half = [](float fl) { + uint16_t result; +#ifdef __cpp_if_constexpr + if constexpr (detail::endian::native == detail::endian::little) { +#else + if (detail::endian::native == detail::endian::little) { +#endif + std::memcpy(&result, reinterpret_cast(&fl) + sizeof(uint16_t), sizeof(uint16_t)); + } else { + std::memcpy(&result, &fl, sizeof(uint16_t)); + } + return result; + }; + + uint16_t upper_bits = get_msb_half(v); + union { + uint32_t U32; + float F32; + }; + F32 = v; + U32 += (upper_bits & 1) + kRoundToNearest; + result = get_msb_half(F32); + } + return result; +} + +template +inline float BFloat16Impl::ToFloatImpl() const noexcept { + if (IsNaN()) { + return std::numeric_limits::quiet_NaN(); + } + float result; + char* const first = reinterpret_cast(&result); + char* const second = first + sizeof(uint16_t); +#ifdef __cpp_if_constexpr + if constexpr (detail::endian::native == detail::endian::little) { +#else + if (detail::endian::native == detail::endian::little) { +#endif + std::memset(first, 0, sizeof(uint16_t)); + std::memcpy(second, &val, sizeof(uint16_t)); + } else { + std::memcpy(first, &val, sizeof(uint16_t)); + std::memset(second, 0, sizeof(uint16_t)); + } + return result; +} + +} // namespace onnxruntime_float16