Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't depend on "common" registrations in sparksql directory #18

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion velox/functions/common/tests/FunctionBaseTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include "velox/functions/common/VectorFunctions.h"

namespace facebook::velox::functions::test {
FunctionBaseTest::FunctionBaseTest() {
void FunctionBaseTest::SetUpTestCase() {
exec::test::registerTypeResolver();
functions::registerFunctions();
functions::registerVectorFunctions();
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/common/tests/FunctionBaseTest.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace facebook::velox::functions::test {

class FunctionBaseTest : public testing::Test {
protected:
FunctionBaseTest();
static void SetUpTestCase();

template <typename T>
using EvalType = typename CppToType<T>::NativeType;
Expand Down
2 changes: 2 additions & 0 deletions velox/functions/lib/tests/Re2FunctionsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <functional>
#include <optional>

#include "velox/exec/tests/utils/FunctionUtils.h"
#include "velox/functions/common/tests/FunctionBaseTest.h"

namespace facebook::velox::functions {
Expand All @@ -26,6 +27,7 @@ namespace {
class Re2FunctionsTest : public test::FunctionBaseTest {
public:
static void SetUpTestCase() {
exec::test::registerTypeResolver();
exec::registerStatefulVectorFunction(
"re2_match", re2MatchSignatures(), makeRe2Match);
exec::registerStatefulVectorFunction(
Expand Down
1 change: 1 addition & 0 deletions velox/functions/sparksql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# limitations under the License.
add_library(
velox_functions_spark OBJECT
Hash.cpp
LeastGreatest.cpp
RegexFunctions.cpp
Register.cpp
Expand Down
183 changes: 183 additions & 0 deletions velox/functions/sparksql/Hash.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/functions/sparksql/Hash.h"

#include <stdint.h>
#include <x86intrin.h>

#include <folly/CPortability.h>

#include "velox/vector/FlatVector.h"

namespace facebook::velox::functions::sparksql {
namespace {

// Derived from src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java.
//
// Spark's Murmur3 seems slightly different from the original from Austin
// Appleby: in particular the fmix function's first line is different. The
// original can be found here:
// https://github.com/aappleby/smhasher/blob/master/src/MurmurHash3.cpp
//
// Signed integer types have been remapped to unsigned types (as in the
// original) to avoid undefined signed integer overflow and sign extension.

uint32_t mixK1(uint32_t k1) {
k1 *= 0xcc9e2d51;
k1 = _rotl(k1, 15);
k1 *= 0x1b873593;
return k1;
}

uint32_t mixH1(uint32_t h1, uint32_t k1) {
h1 ^= k1;
h1 = _rotl(h1, 13);
h1 = h1 * 5 + 0xe6546b64;
return h1;
}

// Finalization mix - force all bits of a hash block to avalanche
uint32_t fmix(uint32_t h1, uint32_t length) {
h1 ^= length;
h1 ^= h1 >> 16;
h1 *= 0x85ebca6b;
h1 ^= h1 >> 13;
h1 *= 0xc2b2ae35;
h1 ^= h1 >> 16;
return h1;
}

uint32_t hashInt32(int32_t input, uint32_t seed) {
uint32_t k1 = mixK1(input);
uint32_t h1 = mixH1(seed, k1);
return fmix(h1, 4);
}

uint32_t hashInt64(uint64_t input, uint32_t seed) {
uint32_t low = input;
uint32_t high = input >> 32;

uint32_t k1 = mixK1(low);
uint32_t h1 = mixH1(seed, k1);

k1 = mixK1(high);
h1 = mixH1(h1, k1);

return fmix(h1, 8);
}

// Spark also has an hashUnsafeBytes2 function, but it was not used at the time
// of implementation.
uint32_t hashBytes(const StringView& input, uint32_t seed) {
const char* i = input.data();
const char* const end = input.data() + input.size();
uint32_t h1 = seed;
for (; i <= end - 4; i += 4) {
h1 = mixH1(h1, mixK1(*reinterpret_cast<const uint32_t*>(i)));
}
for (; i != end; ++i) {
h1 = mixH1(h1, mixK1(*i));
}
return fmix(h1, input.size());
}

// Floating point numbers are hashed as if they are integers, with
// -0f defined to have the same output as +0f.
uint32_t hashFloat(float input, uint32_t seed) {
return hashInt32(
input == -0.f ? 0 : *reinterpret_cast<uint32_t*>(&input), seed);
}

uint32_t hashDouble(double input, uint32_t seed) {
return hashInt64(
input == -0. ? 0 : *reinterpret_cast<uint64_t*>(&input), seed);
}

class HashFunction final : public exec::VectorFunction {
bool isDefaultNullBehavior() const final {
return false;
}

void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args, // Not using const ref so we can reuse args
exec::Expr* caller,
exec::EvalCtx* context,
VectorPtr* resultRef) const final {
constexpr int32_t kSeed = 42;

BaseVector::ensureWritable(rows, INTEGER(), context->pool(), resultRef);

FlatVector<int32_t>& result = *(*resultRef)->as<FlatVector<int32_t>>();
rows.applyToSelected([&](int row) { result.set(row, kSeed); });

std::optional<exec::LocalSelectivityVector> cached_local_selected;

for (auto& arg : args) {
exec::LocalDecodedVector decoded(context, *arg, rows);
const SelectivityVector* selected = &rows;
if (arg->mayHaveNulls()) {
if (!cached_local_selected) {
cached_local_selected.emplace(context, rows.end());
}
*cached_local_selected->get() = rows;
cached_local_selected->get()->deselectNulls(
arg->flatRawNulls(rows), rows.begin(), rows.end());
selected = cached_local_selected->get();
}
switch (arg->type()->kind()) {
#define CASE(typeEnum, hashFn, inputType) \
case TypeKind::typeEnum: \
selected->applyToSelected([&](int row) { \
result.set( \
row, hashFn(decoded->valueAt<inputType>(row), result.valueAt(row))); \
}); \
break;
// Derived from InterpretedHashFunction.hash:
// https://github.com/apache/spark/blob/382b66e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala#L532
CASE(BOOLEAN, hashInt32, bool);
CASE(TINYINT, hashInt32, int8_t);
CASE(SMALLINT, hashInt32, int16_t);
CASE(INTEGER, hashInt32, int32_t);
CASE(BIGINT, hashInt64, int64_t);
CASE(VARCHAR, hashBytes, StringView);
CASE(VARBINARY, hashBytes, StringView);
CASE(REAL, hashFloat, float);
CASE(DOUBLE, hashDouble, double);
#undef CASE
default:
VELOX_NYI("Unsupported type for HASH(): {}", arg->type()->toString());
}
}
}
};

} // namespace

std::vector<std::shared_ptr<exec::FunctionSignature>> hashSignatures() {
return {exec::FunctionSignatureBuilder()
.returnType("integer")
.argumentType("any")
.variableArity()
.build()};
}

std::shared_ptr<exec::VectorFunction> makeHash(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs) {
static const auto kHashFunction = std::make_shared<HashFunction>();
return kHashFunction;
}

} // namespace facebook::velox::functions::sparksql
38 changes: 38 additions & 0 deletions velox/functions/sparksql/Hash.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/expression/VectorFunction.h"

namespace facebook::velox::functions::sparksql {

// Supported types:
// - Bools
// - Integer types (byte, short, int, long)
// - String, Binary
// - Float, Double
//
// Unsupported:
// - Decimal
// - Datetime
// - Structs, Arrays: hash the elements in order
// - Maps: iterate over map, hashing key then value. Since map ordering is
// unspecified, hashing logically equivalent maps may result in
// different hash values.

std::vector<std::shared_ptr<exec::FunctionSignature>> hashSignatures();

std::shared_ptr<exec::VectorFunction> makeHash(
const std::string& name,
const std::vector<exec::VectorFunctionArg>& inputArgs);

} // namespace facebook::velox::functions::sparksql
7 changes: 4 additions & 3 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
#include "velox/functions/sparksql/Register.h"

#include "velox/functions/common/DateTimeFunctions.h"
#include "velox/functions/common/Hash.h"
#include "velox/functions/common/JsonExtractScalar.h"
#include "velox/functions/common/Rand.h"
#include "velox/functions/common/StringFunctions.h"
#include "velox/functions/lib/Re2Functions.h"
#include "velox/functions/lib/RegistrationHelpers.h"
#include "velox/functions/sparksql/Hash.h"
#include "velox/functions/sparksql/LeastGreatest.h"
#include "velox/functions/sparksql/RegexFunctions.h"
#include "velox/functions/sparksql/RegisterArithmetic.h"
Expand Down Expand Up @@ -56,8 +56,6 @@ namespace sparksql {
void registerFunctions(const std::string& prefix) {
registerFunction<udf_rand, double>({"rand"});

registerUnaryScalar<udf_hash, int64_t>({"hash"});

registerFunction<udf_json_extract_scalar, Varchar, Varchar, Varchar>(
{prefix + "get_json_object"});

Expand All @@ -82,6 +80,7 @@ void registerFunctions(const std::string& prefix) {
registerFunction<udf_md5_radix<Varchar, Varchar>, Varchar, Varchar>({"md5"});
VELOX_REGISTER_VECTOR_FUNCTION(udf_subscript, prefix + "subscript");
VELOX_REGISTER_VECTOR_FUNCTION(udf_regexp_split, prefix + "split");

exec::registerStatefulVectorFunction(
"regexp_extract", re2ExtractSignatures(), makeRegexExtract);
exec::registerStatefulVectorFunction(
Expand All @@ -96,6 +95,8 @@ void registerFunctions(const std::string& prefix) {
prefix + "least", leastSignatures(), makeLeast);
exec::registerStatefulVectorFunction(
prefix + "greatest", greatestSignatures(), makeGreatest);
exec::registerStatefulVectorFunction(
prefix + "hash", hashSignatures(), makeHash);
// These vector functions are only accessible via the
// VELOX_REGISTER_VECTOR_FUNCTION macro, which must be invoked in the same
// namespace as the function definition.
Expand Down
2 changes: 1 addition & 1 deletion velox/functions/sparksql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

add_executable(
velox_functions_spark_test
ArithmeticTest.cpp LeastGreatestTest.cpp RegexFunctionsTest.cpp
ArithmeticTest.cpp HashTest.cpp LeastGreatestTest.cpp RegexFunctionsTest.cpp
SplitFunctionsTest.cpp SubscriptTest.cpp)

add_test(velox_functions_spark_test velox_functions_spark_test)
Expand Down
Loading