From 5483f59bb6206da7d1f352597c2d87a6a84fef27 Mon Sep 17 00:00:00 2001
From: "sagarm@fb.com"
Date: Tue, 10 Aug 2021 14:31:12 -0700
Subject: [PATCH 1/2] Implement Spark's variadic HASH() function

Summary: Spark uses a variant of Murmur3; I've copied the hash implementation
directly from the Spark codebase, modified of course for C++.

The Spark algorithm for combining hashes is simple: initialize the hash to 42,
and then use the current hash as the seed for the next hash. NULLs are
skipped.

This implementation processes one column at a time.

Differential Revision: D29743513

fbshipit-source-id: c55c300ff824e6fc5ff15515393a02116b7dba92
---
 velox/functions/sparksql/CMakeLists.txt       |   1 +
 velox/functions/sparksql/Hash.cpp             | 183 ++++++++++++++++++
 velox/functions/sparksql/Hash.h               |  38 ++++
 velox/functions/sparksql/Register.cpp         |   7 +-
 velox/functions/sparksql/tests/CMakeLists.txt |   2 +-
 velox/functions/sparksql/tests/HashTest.cpp   | 122 ++++++++++++
 6 files changed, 349 insertions(+), 4 deletions(-)
 create mode 100644 velox/functions/sparksql/Hash.cpp
 create mode 100644 velox/functions/sparksql/Hash.h
 create mode 100644 velox/functions/sparksql/tests/HashTest.cpp

diff --git a/velox/functions/sparksql/CMakeLists.txt b/velox/functions/sparksql/CMakeLists.txt
index 651a0b7e34f5..6bdb3fd4899e 100644
--- a/velox/functions/sparksql/CMakeLists.txt
+++ b/velox/functions/sparksql/CMakeLists.txt
@@ -11,6 +11,7 @@
 # limitations under the License.
 add_library(
   velox_functions_spark OBJECT
+  Hash.cpp
   LeastGreatest.cpp
   RegexFunctions.cpp
   Register.cpp
diff --git a/velox/functions/sparksql/Hash.cpp b/velox/functions/sparksql/Hash.cpp
new file mode 100644
index 000000000000..d9e01e4e7241
--- /dev/null
+++ b/velox/functions/sparksql/Hash.cpp
@@ -0,0 +1,183 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "velox/functions/sparksql/Hash.h"
+
+#include <stdint.h>
+#include <x86intrin.h>
+
+#include <folly/CPortability.h>
+
+#include "velox/vector/FlatVector.h"
+
+namespace facebook::velox::functions::sparksql {
+namespace {
+
+// Derived from src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java.
+//
+// Spark's Murmur3 seems slightly different from the original from Austin
+// Appleby: in particular the fmix function's first line is different. The
+// original can be found here:
+// https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp
+//
+// Signed integer types have been remapped to unsigned types (as in the
+// original) to avoid undefined signed integer overflow and sign extension.
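+//
+// HashFunction::apply() below combines the per-type helpers using Spark's
+// chaining scheme, roughly (a sketch; "hashOne" stands in for whichever
+// helper matches the column type):
+//
+//   uint32_t hash = 42;                  // initial seed
+//   for (each non-null value v in the row)
+//     hash = hashOne(v, /*seed=*/hash);  // previous hash seeds the next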
+
+uint32_t mixK1(uint32_t k1) {
+  k1 *= 0xcc9e2d51;
+  k1 = _rotl(k1, 15);
+  k1 *= 0x1b873593;
+  return k1;
+}
+
+uint32_t mixH1(uint32_t h1, uint32_t k1) {
+  h1 ^= k1;
+  h1 = _rotl(h1, 13);
+  h1 = h1 * 5 + 0xe6546b64;
+  return h1;
+}
+
+// Finalization mix - force all bits of a hash block to avalanche
+uint32_t fmix(uint32_t h1, uint32_t length) {
+  h1 ^= length;
+  h1 ^= h1 >> 16;
+  h1 *= 0x85ebca6b;
+  h1 ^= h1 >> 13;
+  h1 *= 0xc2b2ae35;
+  h1 ^= h1 >> 16;
+  return h1;
+}
+
+uint32_t hashInt32(int32_t input, uint32_t seed) {
+  uint32_t k1 = mixK1(input);
+  uint32_t h1 = mixH1(seed, k1);
+  return fmix(h1, 4);
+}
+
+uint32_t hashInt64(uint64_t input, uint32_t seed) {
+  uint32_t low = input;
+  uint32_t high = input >> 32;
+
+  uint32_t k1 = mixK1(low);
+  uint32_t h1 = mixH1(seed, k1);
+
+  k1 = mixK1(high);
+  h1 = mixH1(h1, k1);
+
+  return fmix(h1, 8);
+}
+
+// Spark also has a hashUnsafeBytes2 function, but it was not used at the time
+// of implementation.
+uint32_t hashBytes(const StringView& input, uint32_t seed) {
+  const char* i = input.data();
+  const char* const end = input.data() + input.size();
+  uint32_t h1 = seed;
+  for (; i <= end - 4; i += 4) {
+    h1 = mixH1(h1, mixK1(*reinterpret_cast<const uint32_t*>(i)));
+  }
+  for (; i != end; ++i) {
+    h1 = mixH1(h1, mixK1(*i));
+  }
+  return fmix(h1, input.size());
+}
+
+// Floating point numbers are hashed as if they are integers, with
+// -0f defined to have the same output as +0f.
+uint32_t hashFloat(float input, uint32_t seed) {
+  return hashInt32(
+      input == -0.f ? 0 : *reinterpret_cast<const int32_t*>(&input), seed);
+}
+
+uint32_t hashDouble(double input, uint32_t seed) {
+  return hashInt64(
+      input == -0. ? 0 : *reinterpret_cast<const int64_t*>(&input), seed);
+}
+
+class HashFunction final : public exec::VectorFunction {
+  bool isDefaultNullBehavior() const final {
+    return false;
+  }
+
+  void apply(
+      const SelectivityVector& rows,
+      std::vector<VectorPtr>& args, // Not using const ref so we can reuse args
+      exec::Expr* caller,
+      exec::EvalCtx* context,
+      VectorPtr* resultRef) const final {
+    constexpr int32_t kSeed = 42;
+
+    BaseVector::ensureWritable(rows, INTEGER(), context->pool(), resultRef);
+
+    FlatVector<int32_t>& result = *(*resultRef)->as<FlatVector<int32_t>>();
+    rows.applyToSelected([&](int row) { result.set(row, kSeed); });
+
+    std::optional<exec::LocalSelectivityVector> cached_local_selected;
+
+    for (auto& arg : args) {
+      exec::LocalDecodedVector decoded(context, *arg, rows);
+      const SelectivityVector* selected = &rows;
+      if (arg->mayHaveNulls()) {
+        if (!cached_local_selected) {
+          cached_local_selected.emplace(context, rows.end());
+        }
+        *cached_local_selected->get() = rows;
+        cached_local_selected->get()->deselectNulls(
+            arg->flatRawNulls(rows), rows.begin(), rows.end());
+        selected = cached_local_selected->get();
+      }
+      switch (arg->type()->kind()) {
+#define CASE(typeEnum, hashFn, inputType)                                      \
+  case TypeKind::typeEnum:                                                     \
+    selected->applyToSelected([&](int row) {                                   \
+      result.set(                                                              \
+          row, hashFn(decoded->valueAt<inputType>(row), result.valueAt(row))); \
+    });                                                                        \
+    break;
+        // Derived from InterpretedHashFunction.hash:
+        // https://github.com/apache/spark/blob/382b66e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala#L532
+        CASE(BOOLEAN, hashInt32, bool);
+        CASE(TINYINT, hashInt32, int8_t);
+        CASE(SMALLINT, hashInt32, int16_t);
+        CASE(INTEGER, hashInt32, int32_t);
+        CASE(BIGINT, hashInt64, int64_t);
+        CASE(VARCHAR, hashBytes, StringView);
+        CASE(VARBINARY, hashBytes, StringView);
+        CASE(REAL, hashFloat, float);
+        CASE(DOUBLE, hashDouble, double);
+#undef CASE
+        default:
+          VELOX_NYI(
+              "Unsupported type for HASH(): {}", arg->type()->toString());
+      }
+    }
+  }
+};
+
+} // namespace
+
+std::vector<std::shared_ptr<exec::FunctionSignature>> hashSignatures() {
+  return {exec::FunctionSignatureBuilder()
+              .returnType("integer")
+              .argumentType("any")
+              .variableArity()
+              .build()};
+}
+
+std::shared_ptr<exec::VectorFunction> makeHash(
+    const std::string& name,
+    const std::vector<exec::VectorFunctionArg>& inputArgs) {
+  static const auto kHashFunction = std::make_shared<HashFunction>();
+  return kHashFunction;
+}
+
+} // namespace facebook::velox::functions::sparksql
diff --git a/velox/functions/sparksql/Hash.h b/velox/functions/sparksql/Hash.h
new file mode 100644
index 000000000000..52e6f7d0b1fa
--- /dev/null
+++ b/velox/functions/sparksql/Hash.h
@@ -0,0 +1,38 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "velox/expression/VectorFunction.h"
+
+namespace facebook::velox::functions::sparksql {
+
+// Supported types:
+// - Bools
+// - Integer types (byte, short, int, long)
+// - String, Binary
+// - Float, Double
+//
+// Unsupported:
+// - Decimal
+// - Datetime
+// - Structs, Arrays: hash the elements in order
+// - Maps: iterate over map, hashing key then value. Since map ordering is
+//   unspecified, hashing logically equivalent maps may result in
+//   different hash values.
+
+std::vector<std::shared_ptr<exec::FunctionSignature>> hashSignatures();
+
+std::shared_ptr<exec::VectorFunction> makeHash(
+    const std::string& name,
+    const std::vector<exec::VectorFunctionArg>& inputArgs);
+
+} // namespace facebook::velox::functions::sparksql
diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp
index 1ac3d04491d5..c8cb0ccadb87 100644
--- a/velox/functions/sparksql/Register.cpp
+++ b/velox/functions/sparksql/Register.cpp
@@ -14,12 +14,12 @@
 #include "velox/functions/sparksql/Register.h"
 #include "velox/functions/common/DateTimeFunctions.h"
-#include "velox/functions/common/Hash.h"
 #include "velox/functions/common/JsonExtractScalar.h"
 #include "velox/functions/common/Rand.h"
 #include "velox/functions/common/StringFunctions.h"
 #include "velox/functions/lib/Re2Functions.h"
 #include "velox/functions/lib/RegistrationHelpers.h"
+#include "velox/functions/sparksql/Hash.h"
 #include "velox/functions/sparksql/LeastGreatest.h"
 #include "velox/functions/sparksql/RegexFunctions.h"
 #include "velox/functions/sparksql/RegisterArithmetic.h"
@@ -56,8 +56,6 @@ namespace sparksql {
 void registerFunctions(const std::string& prefix) {
   registerFunction({"rand"});
 
-  registerUnaryScalar({"hash"});
-
   registerFunction(
       {prefix + "get_json_object"});
 
@@ -82,6 +80,7 @@ void registerFunctions(const std::string& prefix) {
   registerFunction, Varchar, Varchar>({"md5"});
   VELOX_REGISTER_VECTOR_FUNCTION(udf_subscript, prefix + "subscript");
   VELOX_REGISTER_VECTOR_FUNCTION(udf_regexp_split, prefix + "split");
+
   exec::registerStatefulVectorFunction(
       "regexp_extract", re2ExtractSignatures(), makeRegexExtract);
   exec::registerStatefulVectorFunction(
@@ -96,6 +95,8 @@ void registerFunctions(const std::string& prefix) {
   exec::registerStatefulVectorFunction(
       prefix + "least", leastSignatures(), makeLeast);
   exec::registerStatefulVectorFunction(
       prefix + "greatest", greatestSignatures(), makeGreatest);
+  exec::registerStatefulVectorFunction(
+      prefix + "hash", hashSignatures(), makeHash);
   // These vector functions are only accessible via the
   // VELOX_REGISTER_VECTOR_FUNCTION macro, which must be invoked in the same
   // namespace as the function definition.
diff --git a/velox/functions/sparksql/tests/CMakeLists.txt b/velox/functions/sparksql/tests/CMakeLists.txt
index e00071c93125..d6f58387571d 100644
--- a/velox/functions/sparksql/tests/CMakeLists.txt
+++ b/velox/functions/sparksql/tests/CMakeLists.txt
@@ -12,7 +12,7 @@
 add_executable(
   velox_functions_spark_test
-  ArithmeticTest.cpp LeastGreatestTest.cpp RegexFunctionsTest.cpp
+  ArithmeticTest.cpp HashTest.cpp LeastGreatestTest.cpp RegexFunctionsTest.cpp
   SplitFunctionsTest.cpp SubscriptTest.cpp)
 add_test(velox_functions_spark_test velox_functions_spark_test)
diff --git a/velox/functions/sparksql/tests/HashTest.cpp b/velox/functions/sparksql/tests/HashTest.cpp
new file mode 100644
index 000000000000..c443579d2be0
--- /dev/null
+++ b/velox/functions/sparksql/tests/HashTest.cpp
@@ -0,0 +1,122 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h"
+
+#include <stdint.h>
+
+#include "velox/functions/sparksql/Hash.h"
+
+namespace facebook::velox::functions::sparksql::test {
+namespace {
+
+class HashTest : public SparkFunctionBaseTest {
+ public:
+  static void SetUpTestCase() {
+    exec::registerStatefulVectorFunction("hash", hashSignatures(), makeHash);
+  }
+
+ protected:
+  template <typename T>
+  std::optional<int32_t> hash(std::optional<T> arg) {
+    return evaluateOnce<int32_t>("hash(c0)", arg);
+  }
+};
+
+TEST_F(HashTest, String) {
+  EXPECT_EQ(hash<std::string>("Spark"), 228093765);
+  EXPECT_EQ(hash<std::string>(""), 142593372);
+  EXPECT_EQ(hash<std::string>("abcdefghijklmnopqrstuvwxyz"), -1990933474);
+  // String that has a length that is a multiple of four.
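+  // (This exercises hashBytes() with no trailing bytes: the 4-byte word loop
+  // consumes the entire input, so the byte-at-a-time tail loop is skipped.)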
+  EXPECT_EQ(hash<std::string>("12345678"), 2036199019);
+  EXPECT_EQ(hash<std::string>(std::nullopt), 42);
+}
+
+TEST_F(HashTest, Int64) {
+  EXPECT_EQ(hash<int64_t>(0xcafecafedeadbeef), -256235155);
+  EXPECT_EQ(hash<int64_t>(0xdeadbeefcafecafe), 673261790);
+  EXPECT_EQ(hash<int64_t>(INT64_MAX), -1604625029);
+  EXPECT_EQ(hash<int64_t>(INT64_MIN), -853646085);
+  EXPECT_EQ(hash<int64_t>(1), -1712319331);
+  EXPECT_EQ(hash<int64_t>(0), -1670924195);
+  EXPECT_EQ(hash<int64_t>(-1), -939490007);
+  EXPECT_EQ(hash<int64_t>(std::nullopt), 42);
+}
+
+TEST_F(HashTest, Int32) {
+  EXPECT_EQ(hash<int32_t>(0xdeadbeef), 141248195);
+  EXPECT_EQ(hash<int32_t>(0xcafecafe), 638354558);
+  EXPECT_EQ(hash<int32_t>(1), -559580957);
+  EXPECT_EQ(hash<int32_t>(0), 933211791);
+  EXPECT_EQ(hash<int32_t>(-1), -1604776387);
+  EXPECT_EQ(hash<int32_t>(std::nullopt), 42);
+}
+
+TEST_F(HashTest, Int16) {
+  EXPECT_EQ(hash<int16_t>(1), -559580957);
+  EXPECT_EQ(hash<int16_t>(0), 933211791);
+  EXPECT_EQ(hash<int16_t>(-1), -1604776387);
+  EXPECT_EQ(hash<int16_t>(std::nullopt), 42);
+}
+
+TEST_F(HashTest, Int8) {
+  EXPECT_EQ(hash<int8_t>(1), -559580957);
+  EXPECT_EQ(hash<int8_t>(0), 933211791);
+  EXPECT_EQ(hash<int8_t>(-1), -1604776387);
+  EXPECT_EQ(hash<int8_t>(std::nullopt), 42);
+}
+
+TEST_F(HashTest, Bool) {
+  EXPECT_EQ(hash<bool>(false), 933211791);
+  EXPECT_EQ(hash<bool>(true), -559580957);
+  EXPECT_EQ(hash<bool>(std::nullopt), 42);
+}
+
+TEST_F(HashTest, StringInt32) {
+  auto hash = [&](std::optional<std::string> a, std::optional<int32_t> b) {
+    return evaluateOnce<int32_t>("hash(c0, c1)", a, b);
+  };
+
+  EXPECT_EQ(hash(std::nullopt, std::nullopt), 42);
+  EXPECT_EQ(hash("", std::nullopt), 142593372);
+  EXPECT_EQ(hash(std::nullopt, 0), 933211791);
+  EXPECT_EQ(hash("", 0), 1143746540);
+}
+
+TEST_F(HashTest, Double) {
+  using limits = std::numeric_limits<double>;
+
+  EXPECT_EQ(hash<double>(std::nullopt), 42);
+  EXPECT_EQ(hash<double>(-0.0), -1670924195);
+  EXPECT_EQ(hash<double>(0), -1670924195);
+  EXPECT_EQ(hash<double>(1), -460888942);
+  EXPECT_EQ(hash<double>(limits::quiet_NaN()), -1281358385);
+  EXPECT_EQ(hash<double>(limits::infinity()), 833680482);
+  EXPECT_EQ(hash<double>(-limits::infinity()), 461104036);
+}
+
+TEST_F(HashTest, Float) {
+  using limits = std::numeric_limits<float>;
+
+  EXPECT_EQ(hash<float>(std::nullopt), 42);
+  EXPECT_EQ(hash<float>(-0.0f), 933211791);
+  EXPECT_EQ(hash<float>(0), 933211791);
+  EXPECT_EQ(hash<float>(1), -466301895);
+  EXPECT_EQ(hash<float>(limits::quiet_NaN()), -349261430);
+  EXPECT_EQ(hash<float>(limits::infinity()), 2026854605);
+  EXPECT_EQ(hash<float>(-limits::infinity()), 427440766);
+}
+
+} // namespace
+} // namespace facebook::velox::functions::sparksql::test

From cbd96cda4ef41e7033ec885020ed2d7d12a15809 Mon Sep 17 00:00:00 2001
From: Sagar Mittal
Date: Tue, 10 Aug 2021 14:31:24 -0700
Subject: [PATCH 2/2] Don't depend on "common" registrations in sparksql
 directory

Summary: Function registration is moved from FunctionBaseTest's constructor to
SetUpTestCase(), which runs once per process.

Unittests were updated to rely on the SparkSQL registration functions, rather
than do their own registration.
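For illustration, a suite now registers everything it needs once per test
process rather than in its constructor, roughly like this (a sketch; it
mirrors the SparkFunctionBaseTest change below):

  class SparkFunctionBaseTest : public FunctionBaseTest {
   protected:
    static void SetUpTestCase() {        // runs once for the whole suite
      exec::test::registerTypeResolver();
      sparksql::registerFunctions("");   // Spark functions only
    }
  };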
Reviewed By: pedroerp

Differential Revision: D30061756

fbshipit-source-id: f984329c8d93098ba9146d357b878848435942c0
---
 velox/functions/common/tests/FunctionBaseTest.cpp      | 2 +-
 velox/functions/common/tests/FunctionBaseTest.h        | 2 +-
 velox/functions/lib/tests/Re2FunctionsTest.cpp         | 2 ++
 velox/functions/sparksql/tests/HashTest.cpp            | 5 -----
 velox/functions/sparksql/tests/RegexFunctionsTest.cpp  | 7 -------
 velox/functions/sparksql/tests/SparkFunctionBaseTest.h | 7 +++++--
 6 files changed, 9 insertions(+), 16 deletions(-)

diff --git a/velox/functions/common/tests/FunctionBaseTest.cpp b/velox/functions/common/tests/FunctionBaseTest.cpp
index d72b3a758d85..b2ae40a53710 100644
--- a/velox/functions/common/tests/FunctionBaseTest.cpp
+++ b/velox/functions/common/tests/FunctionBaseTest.cpp
@@ -17,7 +17,7 @@
 #include "velox/functions/common/VectorFunctions.h"
 
 namespace facebook::velox::functions::test {
-FunctionBaseTest::FunctionBaseTest() {
+void FunctionBaseTest::SetUpTestCase() {
   exec::test::registerTypeResolver();
   functions::registerFunctions();
   functions::registerVectorFunctions();
diff --git a/velox/functions/common/tests/FunctionBaseTest.h b/velox/functions/common/tests/FunctionBaseTest.h
index 44f7e1db1ad2..7a343c66b1ef 100644
--- a/velox/functions/common/tests/FunctionBaseTest.h
+++ b/velox/functions/common/tests/FunctionBaseTest.h
@@ -24,7 +24,7 @@ namespace facebook::velox::functions::test {
 
 class FunctionBaseTest : public testing::Test {
  protected:
-  FunctionBaseTest();
+  static void SetUpTestCase();
 
   template <typename T>
   using EvalType = typename CppToType<T>::NativeType;
diff --git a/velox/functions/lib/tests/Re2FunctionsTest.cpp b/velox/functions/lib/tests/Re2FunctionsTest.cpp
index 05f9195a7627..915cdabe408f 100644
--- a/velox/functions/lib/tests/Re2FunctionsTest.cpp
+++ b/velox/functions/lib/tests/Re2FunctionsTest.cpp
@@ -18,6 +18,7 @@
 #include 
 #include 
 
+#include "velox/exec/tests/utils/FunctionUtils.h"
 #include "velox/functions/common/tests/FunctionBaseTest.h"
 
 namespace facebook::velox::functions {
@@ -26,6 +27,7 @@ namespace {
 class Re2FunctionsTest : public test::FunctionBaseTest {
  public:
   static void SetUpTestCase() {
+    exec::test::registerTypeResolver();
     exec::registerStatefulVectorFunction(
         "re2_match", re2MatchSignatures(), makeRe2Match);
     exec::registerStatefulVectorFunction(
diff --git a/velox/functions/sparksql/tests/HashTest.cpp b/velox/functions/sparksql/tests/HashTest.cpp
index c443579d2be0..75d339505000 100644
--- a/velox/functions/sparksql/tests/HashTest.cpp
+++ b/velox/functions/sparksql/tests/HashTest.cpp
@@ -22,11 +22,6 @@ namespace facebook::velox::functions::sparksql::test {
 namespace {
 
 class HashTest : public SparkFunctionBaseTest {
- public:
-  static void SetUpTestCase() {
-    exec::registerStatefulVectorFunction("hash", hashSignatures(), makeHash);
-  }
-
  protected:
   template <typename T>
   std::optional<int32_t> hash(std::optional<T> arg) {
diff --git a/velox/functions/sparksql/tests/RegexFunctionsTest.cpp b/velox/functions/sparksql/tests/RegexFunctionsTest.cpp
index 0f1470d69a16..d3a31679863a 100644
--- a/velox/functions/sparksql/tests/RegexFunctionsTest.cpp
+++ b/velox/functions/sparksql/tests/RegexFunctionsTest.cpp
@@ -27,13 +27,6 @@ namespace {
 
 class RegexFunctionsTest : public test::SparkFunctionBaseTest {
  public:
-  static void SetUpTestCase() {
-    exec::registerStatefulVectorFunction(
-        "rlike", re2MatchSignatures(), makeRLike);
-    exec::registerStatefulVectorFunction(
-        "regexp_extract", re2ExtractSignatures(), makeRegexExtract);
-  }
-
   std::optional<bool> rlike(
       std::optional<std::string> str,
       std::string pattern) {
diff --git a/velox/functions/sparksql/tests/SparkFunctionBaseTest.h b/velox/functions/sparksql/tests/SparkFunctionBaseTest.h
index 43f3ad84cd65..753777507e88 100644
--- a/velox/functions/sparksql/tests/SparkFunctionBaseTest.h
+++ b/velox/functions/sparksql/tests/SparkFunctionBaseTest.h
@@ -13,6 +13,7 @@
  */
 #pragma once
 
+#include "velox/exec/tests/utils/FunctionUtils.h"
 #include "velox/functions/common/tests/FunctionBaseTest.h"
 #include "velox/functions/sparksql/Register.h"
 
@@ -22,8 +23,10 @@ using facebook::velox::functions::test::FunctionBaseTest;
 
 class SparkFunctionBaseTest : public FunctionBaseTest {
  protected:
-  // Ensure Spark functions are registered.
-  SparkFunctionBaseTest() {
+  // Ensure Spark functions are registered; don't register the "common"
+  // (CoreSQL) functions.
+  static void SetUpTestCase() {
+    exec::test::registerTypeResolver();
     sparksql::registerFunctions("");
   }
 };