Skip to content

Commit

Permalink
Support float & double types in pmod function (facebookincubator#157)
Browse files Browse the repository at this point in the history
  • Loading branch information
PHILO-HE authored and zhejiangxiaomai committed Apr 20, 2023
1 parent 1430cb9 commit 8ebaed1
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 1 deletion.
33 changes: 33 additions & 0 deletions velox/functions/sparksql/Arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,39 @@

namespace facebook::velox::functions::sparksql {

template <typename T>
struct PModIntFunction {
template <typename TInput>
FOLLY_ALWAYS_INLINE bool call(TInput& result, const TInput a, const TInput n)
#if defined(__has_feature)
#if __has_feature(__address_sanitizer__)
__attribute__((__no_sanitize__("signed-integer-overflow")))
#endif
#endif
{
if (UNLIKELY(n == 0)) {
return false;
}
TInput r = a % n;
result = (r > 0) ? r : (r + n) % n;
return true;
}
};

template <typename T>
struct PModFloatFunction {
template <typename TInput>
FOLLY_ALWAYS_INLINE bool
call(TInput& result, const TInput a, const TInput n) {
if (UNLIKELY(n == (TInput)0)) {
return false;
}
TInput r = fmod(a, n);
result = (r > 0) ? r : fmod(r + n, n);
return true;
}
};

template <typename T>
struct RemainderFunction {
template <typename TInput>
Expand Down
3 changes: 2 additions & 1 deletion velox/functions/sparksql/RegisterArithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ void registerArithmeticFunctions(const std::string& prefix) {
// Math functions.
registerUnaryNumeric<AbsFunction>({prefix + "abs"});
registerFunction<ExpFunction, double, double>({prefix + "exp"});
registerBinaryIntegral<PModFunction>({prefix + "pmod"});
registerBinaryIntegral<PModIntFunction>({prefix + "pmod"});
registerBinaryFloatingPoint<PModFloatFunction>({prefix + "pmod"});
registerFunction<PowerFunction, double, double, double>({prefix + "power"});
registerUnaryNumeric<RoundFunction>({prefix + "round"});
registerFunction<RoundFunction, int8_t, int8_t, int32_t>({prefix + "round"});
Expand Down
12 changes: 12 additions & 0 deletions velox/functions/sparksql/tests/ArithmeticTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,18 @@ TEST_F(PmodTest, int64) {
EXPECT_EQ(0, pmod<int64_t>(INT64_MIN, -1));
}

TEST_F(PmodTest, float) {
EXPECT_FLOAT_EQ(0.2, pmod<float>(0.5, 0.3).value());
EXPECT_FLOAT_EQ(0.9, pmod<float>(-1.1, 2).value());
EXPECT_EQ(std::nullopt, pmod<float>(2.14159, 0.0));
}

TEST_F(PmodTest, double) {
EXPECT_DOUBLE_EQ(0.2, pmod<double>(0.5, 0.3).value());
EXPECT_DOUBLE_EQ(0.9, pmod<double>(-1.1, 2).value());
EXPECT_EQ(std::nullopt, pmod<double>(2.14159, 0.0));
}

class RemainderTest : public SparkFunctionBaseTest {
protected:
template <typename T>
Expand Down

0 comments on commit 8ebaed1

Please sign in to comment.