From 36be088c0e84f7f58e5c20e9a0babd56f036ab5a Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 15 Aug 2025 08:55:31 -0400 Subject: [PATCH] [FFI] Make JSON Parser/Write fastmath safe This PR adds fallbacks for nan and inf detection/creation under fastmath mode. --- ffi/src/ffi/extra/json_parser.cc | 33 ++++++++++++++++++++-- ffi/src/ffi/extra/json_writer.cc | 35 +++++++++++++++++++++-- ffi/tests/cpp/extra/test_json_parser.cc | 37 +++++++++++++++++++++++-- 3 files changed, 97 insertions(+), 8 deletions(-) diff --git a/ffi/src/ffi/extra/json_parser.cc b/ffi/src/ffi/extra/json_parser.cc index dd3fae351d21..3107e4ddf1ad 100644 --- a/ffi/src/ffi/extra/json_parser.cc +++ b/ffi/src/ffi/extra/json_parser.cc @@ -167,7 +167,7 @@ class JSONParserContext { ++cur_; if (cur_ != end_ && *cur_ == 'I') { if (this->MatchLiteral("Infinity", 8)) { - *out = -std::numeric_limits::infinity(); + *out = FastMathSafeNegInf(); return true; } else { this->SetCurrentPosForBetterErrorMsg(start_pos); @@ -177,7 +177,7 @@ class JSONParserContext { } } else if (*cur_ == 'I') { if (this->MatchLiteral("Infinity", 8)) { - *out = std::numeric_limits::infinity(); + *out = FastMathSafePosInf(); return true; } else { this->SetCurrentPosForBetterErrorMsg(start_pos); @@ -186,7 +186,7 @@ class JSONParserContext { } } else if (*cur_ == 'N') { if (this->MatchLiteral("NaN", 3)) { - *out = std::numeric_limits::quiet_NaN(); + *out = FastMathSafeNaN(); return true; } else { this->SetCurrentPosForBetterErrorMsg(start_pos); @@ -296,6 +296,33 @@ class JSONParserContext { void SetErrorExpectingComma() { error_msg_ = GetSyntaxErrorContext("Expecting \',\' delimiter"); } private: + static double FastMathSafePosInf() { +#ifdef __FAST_MATH__ + const uint64_t inf_bits = 0x7FF0000000000000ULL; + return *reinterpret_cast(&inf_bits); +#else + return std::numeric_limits::infinity(); +#endif + } + + static double FastMathSafeNegInf() { +#ifdef __FAST_MATH__ + const uint64_t inf_bits = 0xFFF0000000000000ULL; + return *reinterpret_cast(&inf_bits); +#else + return -std::numeric_limits::infinity(); +#endif + } + + static double FastMathSafeNaN() { +#ifdef __FAST_MATH__ + const uint64_t nan_bits = 0x7FF8000000000000ULL; + return *reinterpret_cast(&nan_bits); +#else + return std::numeric_limits::quiet_NaN(); +#endif + } + // Full string parsing with escape and unicode handling bool NextStringWithFullHandling(Any* out, const char* start_pos) { // copy over the prefix that was already parsed diff --git a/ffi/src/ffi/extra/json_writer.cc b/ffi/src/ffi/extra/json_writer.cc index 94ba5e4a5a12..81d321d9a754 100644 --- a/ffi/src/ffi/extra/json_writer.cc +++ b/ffi/src/ffi/extra/json_writer.cc @@ -60,6 +60,37 @@ class JSONWriter { private: explicit JSONWriter(int indent) : indent_(indent), out_iter_(result_) {} + static bool FastMathSafeIsNaN(double x) { +#ifdef __FAST_MATH__ + // Bit-level NaN detection (IEEE 754 double) + // IEEE 754 standard: https://en.wikipedia.org/wiki/IEEE_754 + // NaN is encoded as all 1s in the exponent and non-zero in the mantissa + static_assert(sizeof(double) == sizeof(uint64_t), "Unexpected double size"); + uint64_t bits = *reinterpret_cast(&x); + uint64_t exponent = (bits >> 52) & 0x7FF; + uint64_t mantissa = bits & 0xFFFFFFFFFFFFFull; + return (exponent == 0x7FF) && (mantissa != 0); +#else + // Safe to use std::isnan when fast-math is off + return std::isnan(x); +#endif + } + + static bool FastMathSafeIsInf(double x) { +#ifdef __FAST_MATH__ + // IEEE 754 standard: https://en.wikipedia.org/wiki/IEEE_754 + // Inf is encoded as all 1s in the exponent and zero in the mantissa + static_assert(sizeof(double) == sizeof(uint64_t), "Unexpected double size"); + uint64_t bits = *reinterpret_cast(&x); + uint64_t exponent = (bits >> 52) & 0x7FF; + uint64_t mantissa = bits & 0xFFFFFFFFFFFFFull; + // inf is encoded as all 1s in the exponent and zero in the mantissa + return (exponent == 0x7FF) && (mantissa == 0); +#else + return std::isinf(x); +#endif + } + void WriteValue(const json::Value& value) { switch (value.type_index()) { case TypeIndex::kTVMFFINone: { @@ -120,9 +151,9 @@ class JSONWriter { // largest possible string representation of a double is around 24 chars plus // one null terminator keep 32 to be safe char buffer[32]; - if (std::isnan(value)) { + if (FastMathSafeIsNaN(value)) { WriteLiteral("NaN", 3); - } else if (std::isinf(value)) { + } else if (FastMathSafeIsInf(value)) { if (value < 0) { WriteLiteral("-Infinity", 9); } else { diff --git a/ffi/tests/cpp/extra/test_json_parser.cc b/ffi/tests/cpp/extra/test_json_parser.cc index c0332e6f8f20..a1cc2800094f 100644 --- a/ffi/tests/cpp/extra/test_json_parser.cc +++ b/ffi/tests/cpp/extra/test_json_parser.cc @@ -28,6 +28,37 @@ namespace { using namespace tvm::ffi; +inline bool FastMathSafeIsNaN(double x) { +#ifdef __FAST_MATH__ + // Bit-level NaN detection (IEEE 754 double) + // IEEE 754 standard: https://en.wikipedia.org/wiki/IEEE_754 + // NaN is encoded as all 1s in the exponent and non-zero in the mantissa + static_assert(sizeof(double) == sizeof(uint64_t), "Unexpected double size"); + uint64_t bits = *reinterpret_cast(&x); + uint64_t exponent = (bits >> 52) & 0x7FF; + uint64_t mantissa = bits & 0xFFFFFFFFFFFFFull; + return (exponent == 0x7FF) && (mantissa != 0); +#else + // Safe to use std::isnan when fast-math is off + return std::isnan(x); +#endif +} + +inline bool FastMathSafeIsInf(double x) { +#ifdef __FAST_MATH__ + // IEEE 754 standard: https://en.wikipedia.org/wiki/IEEE_754 + // Inf is encoded as all 1s in the exponent and zero in the mantissa + static_assert(sizeof(double) == sizeof(uint64_t), "Unexpected double size"); + uint64_t bits = *reinterpret_cast(&x); + uint64_t exponent = (bits >> 52) & 0x7FF; + uint64_t mantissa = bits & 0xFFFFFFFFFFFFFull; + // inf is encoded as all 1s in the exponent and zero in the mantissa + return (exponent == 0x7FF) && (mantissa == 0); +#else + return std::isinf(x); +#endif +} + TEST(JSONParser, BoolNull) { // boolean value EXPECT_EQ(json::Parse("true").cast(), true); @@ -61,11 +92,11 @@ TEST(JSONParser, Number) { // parsing scientific notation EXPECT_EQ(json::Parse("1.456e12").cast(), 1.456e12); // NaN - EXPECT_EQ(std::isnan(json::Parse("NaN").cast()), true); + EXPECT_EQ(FastMathSafeIsNaN(json::Parse("NaN").cast()), true); // Infinity - EXPECT_EQ(std::isinf(json::Parse("Infinity").cast()), true); + EXPECT_EQ(FastMathSafeIsInf(json::Parse("Infinity").cast()), true); // -Infinity - EXPECT_EQ(std::isinf(-json::Parse("-Infinity").cast()), true); + EXPECT_EQ(FastMathSafeIsInf(-json::Parse("-Infinity").cast()), true); // Test zero variants EXPECT_EQ(json::Parse("0").cast(), 0);