diff --git a/cpp/src/arrow/array/statistics.h b/cpp/src/arrow/array/statistics.h index 5380debe3b6..6accd48af77 100644 --- a/cpp/src/arrow/array/statistics.h +++ b/cpp/src/arrow/array/statistics.h @@ -22,6 +22,7 @@ #include #include +#include "arrow/compare.h" #include "arrow/type.h" #include "arrow/util/visibility.h" @@ -127,11 +128,17 @@ struct ARROW_EXPORT ArrayStatistics { /// \brief Whether the maximum value is exact or not bool is_max_exact = false; - /// \brief Check two statistics for equality - bool Equals(const ArrayStatistics& other) const { - return null_count == other.null_count && distinct_count == other.distinct_count && - min == other.min && is_min_exact == other.is_min_exact && max == other.max && - is_max_exact == other.is_max_exact; + /// \brief Check two \ref arrow::ArrayStatistics for equality + /// + /// \param other The \ref arrow::ArrayStatistics instance to compare against. + /// + /// \param equal_options Options used to compare double values for equality. + /// + /// \return True if the two \ref arrow::ArrayStatistics instances are equal; otherwise, + /// false. + bool Equals(const ArrayStatistics& other, + const EqualOptions& equal_options = EqualOptions::Defaults()) const { + return ArrayStatisticsEquals(*this, other, equal_options); } /// \brief Check two statistics for equality diff --git a/cpp/src/arrow/array/statistics_test.cc b/cpp/src/arrow/array/statistics_test.cc index cf15a5d3829..95199a9683b 100644 --- a/cpp/src/arrow/array/statistics_test.cc +++ b/cpp/src/arrow/array/statistics_test.cc @@ -15,13 +15,17 @@ // specific language governing permissions and limitations // under the License. +#include +#include + #include #include "arrow/array/statistics.h" +#include "arrow/compare.h" namespace arrow { -TEST(ArrayStatisticsTest, TestNullCount) { +TEST(TestArrayStatistics, NullCount) { ArrayStatistics statistics; ASSERT_FALSE(statistics.null_count.has_value()); statistics.null_count = 29; @@ -29,7 +33,7 @@ TEST(ArrayStatisticsTest, TestNullCount) { ASSERT_EQ(29, statistics.null_count.value()); } -TEST(ArrayStatisticsTest, TestDistinctCount) { +TEST(TestArrayStatistics, DistinctCount) { ArrayStatistics statistics; ASSERT_FALSE(statistics.distinct_count.has_value()); statistics.distinct_count = 29; @@ -37,7 +41,7 @@ TEST(ArrayStatisticsTest, TestDistinctCount) { ASSERT_EQ(29, statistics.distinct_count.value()); } -TEST(ArrayStatisticsTest, TestMin) { +TEST(TestArrayStatistics, Min) { ArrayStatistics statistics; ASSERT_FALSE(statistics.min.has_value()); ASSERT_FALSE(statistics.is_min_exact); @@ -49,7 +53,7 @@ TEST(ArrayStatisticsTest, TestMin) { ASSERT_TRUE(statistics.is_min_exact); } -TEST(ArrayStatisticsTest, TestMax) { +TEST(TestArrayStatistics, Max) { ArrayStatistics statistics; ASSERT_FALSE(statistics.max.has_value()); ASSERT_FALSE(statistics.is_max_exact); @@ -61,7 +65,7 @@ TEST(ArrayStatisticsTest, TestMax) { ASSERT_FALSE(statistics.is_max_exact); } -TEST(ArrayStatisticsTest, TestEquality) { +TEST(TestArrayStatistics, EqualityNonDoulbeValue) { ArrayStatistics statistics1; ArrayStatistics statistics2; @@ -96,6 +100,56 @@ TEST(ArrayStatisticsTest, TestEquality) { ASSERT_NE(statistics1, statistics2); statistics2.is_max_exact = true; ASSERT_EQ(statistics1, statistics2); + + // Test different ArrayStatistics::ValueType + statistics1.max = static_cast(29); + statistics1.max = static_cast(29); + ASSERT_NE(statistics1, statistics2); +} + +class TestArrayStatisticsEqualityDoubleValue : public ::testing::Test { + protected: + ArrayStatistics statistics1_; + ArrayStatistics statistics2_; + EqualOptions options_ = EqualOptions::Defaults(); +}; + +TEST_F(TestArrayStatisticsEqualityDoubleValue, ExactValue) { + statistics2_.min = 29.0; + statistics1_.min = 29.0; + ASSERT_EQ(statistics1_, statistics2_); + statistics2_.min = 30.0; + ASSERT_NE(statistics1_, statistics2_); +} + +TEST_F(TestArrayStatisticsEqualityDoubleValue, SignedZero) { + statistics1_.min = +0.0; + statistics2_.min = -0.0; + ASSERT_TRUE(statistics1_.Equals(statistics2_, options_.signed_zeros_equal(true))); + ASSERT_FALSE(statistics1_.Equals(statistics2_, options_.signed_zeros_equal(false))); +} + +TEST_F(TestArrayStatisticsEqualityDoubleValue, Infinity) { + auto infinity = std::numeric_limits::infinity(); + statistics1_.min = infinity; + statistics2_.min = infinity; + ASSERT_EQ(statistics1_, statistics2_); + statistics1_.min = -infinity; + ASSERT_NE(statistics1_, statistics2_); +} + +TEST_F(TestArrayStatisticsEqualityDoubleValue, NaN) { + statistics1_.min = std::numeric_limits::quiet_NaN(); + statistics2_.min = std::numeric_limits::quiet_NaN(); + ASSERT_TRUE(statistics1_.Equals(statistics2_, options_.nans_equal(true))); + ASSERT_FALSE(statistics1_.Equals(statistics2_, options_.nans_equal(false))); +} + +TEST_F(TestArrayStatisticsEqualityDoubleValue, ApproximateEquals) { + statistics1_.max = 0.5001f; + statistics2_.max = 0.5; + ASSERT_FALSE(statistics1_.Equals(statistics2_, options_.atol(1e-3).use_atol(false))); + ASSERT_TRUE(statistics1_.Equals(statistics2_, options_.atol(1e-3))); } } // namespace arrow diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc index 3b64a8fd09f..2460afbf87c 100644 --- a/cpp/src/arrow/compare.cc +++ b/cpp/src/arrow/compare.cc @@ -24,13 +24,16 @@ #include #include #include +#include #include #include #include +#include #include #include "arrow/array.h" #include "arrow/array/diff.h" +#include "arrow/array/statistics.h" #include "arrow/buffer.h" #include "arrow/scalar.h" #include "arrow/sparse_tensor.h" @@ -1523,4 +1526,55 @@ bool TypeEquals(const DataType& left, const DataType& right, bool check_metadata } } +namespace { + +bool DoubleEquals(const double& left, const double& right, const EqualOptions& options) { + bool result; + auto visitor = [&](auto&& compare_func) { result = compare_func(left, right); }; + VisitFloatingEquality(options, options.use_atol(), std::move(visitor)); + return result; +} + +bool ArrayStatisticsValueTypeEquals( + const std::optional& left, + const std::optional& right, const EqualOptions& options) { + if (!left.has_value() || !right.has_value()) { + return left.has_value() == right.has_value(); + } else if (left->index() != right->index()) { + return false; + } else { + auto EqualsVisitor = [&](const auto& v1, const auto& v2) { + using type_1 = std::decay_t; + using type_2 = std::decay_t; + if constexpr (std::conjunction_v, + std::is_same>) { + return DoubleEquals(v1, v2, options); + } else if constexpr (std::is_same_v) { + return v1 == v2; + } + // It is unreachable + DCHECK(false); + return false; + }; + return std::visit(EqualsVisitor, left.value(), right.value()); + } +} + +bool ArrayStatisticsEqualsImpl(const ArrayStatistics& left, const ArrayStatistics& right, + const EqualOptions& equal_options) { + return left.null_count == right.null_count && + left.distinct_count == right.distinct_count && + left.is_min_exact == right.is_min_exact && + left.is_max_exact == right.is_max_exact && + ArrayStatisticsValueTypeEquals(left.min, right.min, equal_options) && + ArrayStatisticsValueTypeEquals(left.max, right.max, equal_options); +} + +} // namespace + +bool ArrayStatisticsEquals(const ArrayStatistics& left, const ArrayStatistics& right, + const EqualOptions& options) { + return ArrayStatisticsEqualsImpl(left, right, options); +} + } // namespace arrow diff --git a/cpp/src/arrow/compare.h b/cpp/src/arrow/compare.h index 6b365c59913..ec7dc8bda18 100644 --- a/cpp/src/arrow/compare.h +++ b/cpp/src/arrow/compare.h @@ -27,6 +27,7 @@ namespace arrow { +struct ArrayStatistics; class Array; class DataType; class Tensor; @@ -58,7 +59,18 @@ class EqualOptions { return res; } + /// Whether the "atol" property is used in the comparison. + bool use_atol() const { return use_atol_; } + + /// Return a new EqualOptions object with the "use_atol" property changed. + EqualOptions use_atol(bool v) const { + auto res = EqualOptions(*this); + res.use_atol_ = v; + return res; + } + /// The absolute tolerance for approximate comparisons of floating-point values. + /// Note that this option is ignored if "use_atol" is set to false. double atol() const { return atol_; } /// Return a new EqualOptions object with the "atol" property changed. @@ -87,6 +99,7 @@ class EqualOptions { double atol_ = kDefaultAbsoluteTolerance; bool nans_equal_ = false; bool signed_zeros_equal_ = true; + bool use_atol_ = true; std::ostream* diff_sink_ = NULLPTR; }; @@ -135,6 +148,16 @@ ARROW_EXPORT bool SparseTensorEquals(const SparseTensor& left, const SparseTenso ARROW_EXPORT bool TypeEquals(const DataType& left, const DataType& right, bool check_metadata = true); +/// \brief Check two \ref arrow::ArrayStatistics for equality +/// \param[in] left an \ref arrow::ArrayStatistics +/// \param[in] right an \ref arrow::ArrayStatistics +/// \param[in] options Options used to compare double values for equality. +/// \return True if the two \ref arrow::ArrayStatistics instances are equal; otherwise, +/// false. +ARROW_EXPORT bool ArrayStatisticsEquals( + const ArrayStatistics& left, const ArrayStatistics& right, + const EqualOptions& options = EqualOptions::Defaults()); + /// Returns true if scalars are equal /// \param[in] left a Scalar /// \param[in] right a Scalar