From 8ebaed1ad8395d0db4c0bbf63926a30819105579 Mon Sep 17 00:00:00 2001 From: PHILO-HE Date: Wed, 29 Mar 2023 11:20:51 +0800 Subject: [PATCH] Support float & double types in pmod function (#157) --- velox/functions/sparksql/Arithmetic.h | 33 +++++++++++++++++++ .../functions/sparksql/RegisterArithmetic.cpp | 3 +- .../sparksql/tests/ArithmeticTest.cpp | 12 +++++++ 3 files changed, 47 insertions(+), 1 deletion(-) diff --git a/velox/functions/sparksql/Arithmetic.h b/velox/functions/sparksql/Arithmetic.h index 4778cd081a94..a108a4ce23dc 100644 --- a/velox/functions/sparksql/Arithmetic.h +++ b/velox/functions/sparksql/Arithmetic.h @@ -24,6 +24,39 @@ namespace facebook::velox::functions::sparksql { +template +struct PModIntFunction { + template + FOLLY_ALWAYS_INLINE bool call(TInput& result, const TInput a, const TInput n) +#if defined(__has_feature) +#if __has_feature(__address_sanitizer__) + __attribute__((__no_sanitize__("signed-integer-overflow"))) +#endif +#endif + { + if (UNLIKELY(n == 0)) { + return false; + } + TInput r = a % n; + result = (r > 0) ? r : (r + n) % n; + return true; + } +}; + +template +struct PModFloatFunction { + template + FOLLY_ALWAYS_INLINE bool + call(TInput& result, const TInput a, const TInput n) { + if (UNLIKELY(n == (TInput)0)) { + return false; + } + TInput r = fmod(a, n); + result = (r > 0) ? r : fmod(r + n, n); + return true; + } +}; + template struct RemainderFunction { template diff --git a/velox/functions/sparksql/RegisterArithmetic.cpp b/velox/functions/sparksql/RegisterArithmetic.cpp index 4f72fdf59c30..a17c66cc79ec 100644 --- a/velox/functions/sparksql/RegisterArithmetic.cpp +++ b/velox/functions/sparksql/RegisterArithmetic.cpp @@ -32,7 +32,8 @@ void registerArithmeticFunctions(const std::string& prefix) { // Math functions. registerUnaryNumeric({prefix + "abs"}); registerFunction({prefix + "exp"}); - registerBinaryIntegral({prefix + "pmod"}); + registerBinaryIntegral({prefix + "pmod"}); + registerBinaryFloatingPoint({prefix + "pmod"}); registerFunction({prefix + "power"}); registerUnaryNumeric({prefix + "round"}); registerFunction({prefix + "round"}); diff --git a/velox/functions/sparksql/tests/ArithmeticTest.cpp b/velox/functions/sparksql/tests/ArithmeticTest.cpp index 80564aa4d112..3832f41c48c1 100644 --- a/velox/functions/sparksql/tests/ArithmeticTest.cpp +++ b/velox/functions/sparksql/tests/ArithmeticTest.cpp @@ -71,6 +71,18 @@ TEST_F(PmodTest, int64) { EXPECT_EQ(0, pmod(INT64_MIN, -1)); } +TEST_F(PmodTest, float) { + EXPECT_FLOAT_EQ(0.2, pmod(0.5, 0.3).value()); + EXPECT_FLOAT_EQ(0.9, pmod(-1.1, 2).value()); + EXPECT_EQ(std::nullopt, pmod(2.14159, 0.0)); +} + +TEST_F(PmodTest, double) { + EXPECT_DOUBLE_EQ(0.2, pmod(0.5, 0.3).value()); + EXPECT_DOUBLE_EQ(0.9, pmod(-1.1, 2).value()); + EXPECT_EQ(std::nullopt, pmod(2.14159, 0.0)); +} + class RemainderTest : public SparkFunctionBaseTest { protected: template