From 0dec30d6a63f9b2eeb3939a71b6c475f31cf0db6 Mon Sep 17 00:00:00 2001 From: beetrees Date: Sat, 13 Jul 2024 07:50:08 +0100 Subject: [PATCH] [APFloat] Fix `IEEEFloat::addOrSubtractSignificand` and `IEEEFloat::normalize` --- llvm/include/llvm/ADT/APFloat.h | 2 + llvm/lib/Support/APFloat.cpp | 50 +++++--- llvm/unittests/ADT/APFloatTest.cpp | 179 +++++++++++++++++++++++++++++ 3 files changed, 216 insertions(+), 15 deletions(-) diff --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h index 4ca928bf4f49e3..a561164ed13be2 100644 --- a/llvm/include/llvm/ADT/APFloat.h +++ b/llvm/include/llvm/ADT/APFloat.h @@ -750,6 +750,8 @@ class IEEEFloat final { /// Sign bit of the number. unsigned int sign : 1; + + friend class IEEEFloatUnitTestHelper; }; hash_code hash_value(const IEEEFloat &Arg); diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp index 81e297c3ab033e..494c9f1049cdca 100644 --- a/llvm/lib/Support/APFloat.cpp +++ b/llvm/lib/Support/APFloat.cpp @@ -1682,7 +1682,8 @@ APFloat::opStatus IEEEFloat::normalize(roundingMode rounding_mode, /* Before rounding normalize the exponent of fcNormal numbers. */ omsb = significandMSB() + 1; - if (omsb) { + // Only skip this `if` if the value is exactly zero. + if (omsb || lost_fraction != lfExactlyZero) { /* OMSB is numbered from 1. We want to place it in the integer bit numbered PRECISION if possible, with a compensating change in the exponent. */ @@ -1864,7 +1865,7 @@ APFloat::opStatus IEEEFloat::addOrSubtractSpecials(const IEEEFloat &rhs, /* Add or subtract two normal numbers. */ lostFraction IEEEFloat::addOrSubtractSignificand(const IEEEFloat &rhs, bool subtract) { - integerPart carry; + integerPart carry = 0; lostFraction lost_fraction; int bits; @@ -1882,11 +1883,13 @@ lostFraction IEEEFloat::addOrSubtractSignificand(const IEEEFloat &rhs, "This floating point format does not support signed values"); IEEEFloat temp_rhs(rhs); + bool lost_fraction_is_from_rhs = false; if (bits == 0) lost_fraction = lfExactlyZero; else if (bits > 0) { lost_fraction = temp_rhs.shiftSignificandRight(bits - 1); + lost_fraction_is_from_rhs = true; shiftSignificandLeft(1); } else { lost_fraction = shiftSignificandRight(-bits - 1); @@ -1894,23 +1897,40 @@ lostFraction IEEEFloat::addOrSubtractSignificand(const IEEEFloat &rhs, } // Should we reverse the subtraction. - if (compareAbsoluteValue(temp_rhs) == cmpLessThan) { - carry = temp_rhs.subtractSignificand - (*this, lost_fraction != lfExactlyZero); + cmpResult cmp_result = compareAbsoluteValue(temp_rhs); + if (cmp_result == cmpLessThan) { + bool borrow = + lost_fraction != lfExactlyZero && !lost_fraction_is_from_rhs; + if (borrow) { + // The lost fraction is being subtracted, borrow from the significand + // and invert `lost_fraction`. + if (lost_fraction == lfLessThanHalf) + lost_fraction = lfMoreThanHalf; + else if (lost_fraction == lfMoreThanHalf) + lost_fraction = lfLessThanHalf; + } + carry = temp_rhs.subtractSignificand(*this, borrow); copySignificand(temp_rhs); sign = !sign; - } else { - carry = subtractSignificand - (temp_rhs, lost_fraction != lfExactlyZero); + } else if (cmp_result == cmpGreaterThan) { + bool borrow = lost_fraction != lfExactlyZero && lost_fraction_is_from_rhs; + if (borrow) { + // The lost fraction is being subtracted, borrow from the significand + // and invert `lost_fraction`. + if (lost_fraction == lfLessThanHalf) + lost_fraction = lfMoreThanHalf; + else if (lost_fraction == lfMoreThanHalf) + lost_fraction = lfLessThanHalf; + } + carry = subtractSignificand(temp_rhs, borrow); + } else { // cmpEqual + zeroSignificand(); + if (lost_fraction != lfExactlyZero && lost_fraction_is_from_rhs) { + // rhs is slightly larger due to the lost fraction, flip the sign. + sign = !sign; + } } - /* Invert the lost fraction - it was on the RHS and - subtracted. */ - if (lost_fraction == lfLessThanHalf) - lost_fraction = lfMoreThanHalf; - else if (lost_fraction == lfMoreThanHalf) - lost_fraction = lfLessThanHalf; - /* The code above is intended to ensure that no borrow is necessary. */ assert(!carry); diff --git a/llvm/unittests/ADT/APFloatTest.cpp b/llvm/unittests/ADT/APFloatTest.cpp index f291c814886d35..0023b87ef82e1b 100644 --- a/llvm/unittests/ADT/APFloatTest.cpp +++ b/llvm/unittests/ADT/APFloatTest.cpp @@ -47,6 +47,37 @@ static std::string convertToString(double d, unsigned Prec, unsigned Pad, return std::string(Buffer.data(), Buffer.size()); } +namespace llvm { +namespace detail { +class IEEEFloatUnitTestHelper { +public: + static void runTest(bool subtract, bool lhsSign, + APFloat::ExponentType lhsExponent, + APFloat::integerPart lhsSignificand, bool rhsSign, + APFloat::ExponentType rhsExponent, + APFloat::integerPart rhsSignificand, bool expectedSign, + APFloat::ExponentType expectedExponent, + APFloat::integerPart expectedSignificand, + lostFraction expectedLoss) { + // `addOrSubtractSignificand` only uses the sign, exponent and significand + IEEEFloat lhs(1.0); + lhs.sign = lhsSign; + lhs.exponent = lhsExponent; + lhs.significand.part = lhsSignificand; + IEEEFloat rhs(1.0); + rhs.sign = rhsSign; + rhs.exponent = rhsExponent; + rhs.significand.part = rhsSignificand; + lostFraction resultLoss = lhs.addOrSubtractSignificand(rhs, subtract); + EXPECT_EQ(resultLoss, expectedLoss); + EXPECT_EQ(lhs.sign, expectedSign); + EXPECT_EQ(lhs.exponent, expectedExponent); + EXPECT_EQ(lhs.significand.part, expectedSignificand); + } +}; +} // namespace detail +} // namespace llvm + namespace { TEST(APFloatTest, isSignaling) { @@ -560,6 +591,104 @@ TEST(APFloatTest, FMA) { EXPECT_EQ(-8.85242279E-41f, f1.convertToFloat()); } + // The `addOrSubtractSignificand` can be considered to have 9 possible cases + // when subtracting: all combinations of {cmpLessThan, cmpGreaterThan, + // cmpEqual} and {no loss, loss from lhs, loss from rhs}. Test each reachable + // case here. + + // Regression test for failing the `assert(!carry)` in + // `addOrSubtractSignificand` and normalizing the exponent even when the + // significand is zero if there is a lost fraction. + // This tests cmpEqual, loss from lhs + { + APFloat f1(-1.4728589E-38f); + APFloat f2(3.7105144E-6f); + APFloat f3(5.5E-44f); + f1.fusedMultiplyAdd(f2, f3, APFloat::rmNearestTiesToEven); + EXPECT_EQ(-0.0f, f1.convertToFloat()); + } + + // Test cmpGreaterThan, no loss + { + APFloat f1(2.0f); + APFloat f2(2.0f); + APFloat f3(-3.5f); + f1.fusedMultiplyAdd(f2, f3, APFloat::rmNearestTiesToEven); + EXPECT_EQ(0.5f, f1.convertToFloat()); + } + + // Test cmpLessThan, no loss + { + APFloat f1(2.0f); + APFloat f2(2.0f); + APFloat f3(-4.5f); + f1.fusedMultiplyAdd(f2, f3, APFloat::rmNearestTiesToEven); + EXPECT_EQ(-0.5f, f1.convertToFloat()); + } + + // Test cmpEqual, no loss + { + APFloat f1(2.0f); + APFloat f2(2.0f); + APFloat f3(-4.0f); + f1.fusedMultiplyAdd(f2, f3, APFloat::rmNearestTiesToEven); + EXPECT_EQ(0.0f, f1.convertToFloat()); + } + + // Test cmpLessThan, loss from lhs + { + APFloat f1(2.0000002f); + APFloat f2(2.0000002f); + APFloat f3(-32.0f); + f1.fusedMultiplyAdd(f2, f3, APFloat::rmNearestTiesToEven); + EXPECT_EQ(-27.999998f, f1.convertToFloat()); + } + + // Test cmpGreaterThan, loss from rhs + { + APFloat f1(1e10f); + APFloat f2(1e10f); + APFloat f3(-2.0000002f); + f1.fusedMultiplyAdd(f2, f3, APFloat::rmNearestTiesToEven); + EXPECT_EQ(1e20f, f1.convertToFloat()); + } + + // Test cmpGreaterThan, loss from lhs + { + APFloat f1(1e-36f); + APFloat f2(0.0019531252f); + APFloat f3(-1e-45f); + f1.fusedMultiplyAdd(f2, f3, APFloat::rmNearestTiesToEven); + EXPECT_EQ(1.953124e-39f, f1.convertToFloat()); + } + + // {cmpEqual, cmpLessThan} with loss from rhs can't occur for the usage in + // `fusedMultiplyAdd` as `multiplySignificand` normalises the MSB of lhs to + // one bit below the top. + + // Test cases from #104984 + { + APFloat f1(0.24999998f); + APFloat f2(2.3509885e-38f); + APFloat f3(-1e-45f); + f1.fusedMultiplyAdd(f2, f3, APFloat::rmNearestTiesToEven); + EXPECT_EQ(5.87747e-39f, f1.convertToFloat()); + } + { + APFloat f1(4.4501477170144023e-308); + APFloat f2(0.24999999999999997); + APFloat f3(-8.475904604373977e-309); + f1.fusedMultiplyAdd(f2, f3, APFloat::rmNearestTiesToEven); + EXPECT_EQ(2.64946468816203e-309, f1.convertToDouble()); + } + { + APFloat f1(APFloat::IEEEhalf(), APInt(16, 0x8fffu)); + APFloat f2(APFloat::IEEEhalf(), APInt(16, 0x2bffu)); + APFloat f3(APFloat::IEEEhalf(), APInt(16, 0x0172u)); + f1.fusedMultiplyAdd(f2, f3, APFloat::rmNearestTiesToEven); + EXPECT_EQ(0x808eu, f1.bitcastToAPInt().getZExtValue()); + } + // Test using only a single instance of APFloat. { APFloat F(1.5); @@ -8089,4 +8218,54 @@ TEST(APFloatTest, Float4E2M1FNToFloat) { EXPECT_TRUE(SmallestDenorm.isDenormal()); EXPECT_EQ(0x0.8p0, SmallestDenorm.convertToFloat()); } + +TEST(APFloatTest, AddOrSubtractSignificand) { + typedef detail::IEEEFloatUnitTestHelper Helper; + // Test cases are all combinations of: + // {equal exponents, LHS larger exponent, RHS larger exponent} + // {equal significands, LHS larger significand, RHS larger significand} + // {no loss, loss} + + // Equal exponents (loss cannot occur as their is no shifting) + Helper::runTest(true, false, 1, 0x10, false, 1, 0x5, false, 1, 0xb, + lfExactlyZero); + Helper::runTest(false, false, -2, 0x20, true, -2, 0x20, false, -2, 0, + lfExactlyZero); + Helper::runTest(false, true, 3, 0x20, false, 3, 0x30, false, 3, 0x10, + lfExactlyZero); + + // LHS larger exponent + // LHS significand greater after shitfing + Helper::runTest(true, false, 7, 0x100, false, 3, 0x100, false, 6, 0x1e0, + lfExactlyZero); + Helper::runTest(true, false, 7, 0x100, false, 3, 0x101, false, 6, 0x1df, + lfMoreThanHalf); + // Significands equal after shitfing + Helper::runTest(true, false, 7, 0x100, false, 3, 0x1000, false, 6, 0, + lfExactlyZero); + Helper::runTest(true, false, 7, 0x100, false, 3, 0x1001, true, 6, 0, + lfLessThanHalf); + // RHS significand greater after shitfing + Helper::runTest(true, false, 7, 0x100, false, 3, 0x10000, true, 6, 0x1e00, + lfExactlyZero); + Helper::runTest(true, false, 7, 0x100, false, 3, 0x10001, true, 6, 0x1e00, + lfLessThanHalf); + + // RHS larger exponent + // RHS significand greater after shitfing + Helper::runTest(true, false, 3, 0x100, false, 7, 0x100, true, 6, 0x1e0, + lfExactlyZero); + Helper::runTest(true, false, 3, 0x101, false, 7, 0x100, true, 6, 0x1df, + lfMoreThanHalf); + // Significands equal after shitfing + Helper::runTest(true, false, 3, 0x1000, false, 7, 0x100, false, 6, 0, + lfExactlyZero); + Helper::runTest(true, false, 3, 0x1001, false, 7, 0x100, false, 6, 0, + lfLessThanHalf); + // LHS significand greater after shitfing + Helper::runTest(true, false, 3, 0x10000, false, 7, 0x100, false, 6, 0x1e00, + lfExactlyZero); + Helper::runTest(true, false, 3, 0x10001, false, 7, 0x100, false, 6, 0x1e00, + lfLessThanHalf); +} } // namespace