From 2aa09d6238b4ce9e9d938954e30b56f6568c1466 Mon Sep 17 00:00:00 2001 From: Wey Gu Date: Mon, 13 Feb 2023 04:08:23 +0000 Subject: [PATCH] example of udf build added Co-authored-by: Wey Gu Co-authored-by: Cheng Xuntao <7731943+xtcyclist@users.noreply.github.com> --- conf/nebula-graphd.conf.default | 6 ++ src/common/function/FunctionManager.cpp | 2 +- src/common/function/FunctionUdfManager.cpp | 13 +++- src/common/function/FunctionUdfManager.h | 9 ++- src/common/function/GraphFunction.h | 26 ++++--- udf/standard_deviation.cpp | 91 ++++++++++++++++++++++ udf/standard_deviation.h | 51 ++++++++++++ 7 files changed, 181 insertions(+), 17 deletions(-) create mode 100644 udf/standard_deviation.cpp create mode 100644 udf/standard_deviation.h diff --git a/conf/nebula-graphd.conf.default b/conf/nebula-graphd.conf.default index 919f94d7f5c..57dc957bc52 100644 --- a/conf/nebula-graphd.conf.default +++ b/conf/nebula-graphd.conf.default @@ -96,3 +96,9 @@ # if use balance data feature, only work if enable_experimental_feature is true --enable_data_balance=true + +# enable udf, written in c++ only for now +--enable_udf=true + +# set the directory where the .so files of udf are stored, when enable_udf is true +--udf_path=/home/nebula/dev/nebula/udf/ \ No newline at end of file diff --git a/src/common/function/FunctionManager.cpp b/src/common/function/FunctionManager.cpp index 4831db89a94..edddeac6dcc 100644 --- a/src/common/function/FunctionManager.cpp +++ b/src/common/function/FunctionManager.cpp @@ -9,6 +9,7 @@ #include +#include "FunctionUdfManager.h" #include "common/base/Base.h" #include "common/datatypes/DataSet.h" #include "common/datatypes/Edge.h" @@ -28,7 +29,6 @@ #include "common/time/TimeUtils.h" #include "common/time/WallClock.h" #include "graph/service/GraphFlags.h" -#include "FunctionUdfManager.h" DEFINE_bool(enable_udf, false, "enable udf"); diff --git a/src/common/function/FunctionUdfManager.cpp b/src/common/function/FunctionUdfManager.cpp index 6176761f8fe..e2ad870d3ee 100644 --- a/src/common/function/FunctionUdfManager.cpp +++ b/src/common/function/FunctionUdfManager.cpp @@ -1,7 +1,13 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + #include "FunctionUdfManager.h" #include #include + #include #include @@ -47,7 +53,7 @@ std::vector getFilesList(const std::string &path, const char *ftype } FunctionUdfManager::create_f *FunctionUdfManager::getGraphFunctionClass(void *func_handle) { - auto *create_func = (create_f *)dlsym(func_handle, "create"); + auto *create_func = reinterpret_cast(dlsym(func_handle, "create")); dlsym_error = dlerror(); if (dlsym_error) { LOG(ERROR) << "Cannot load symbol create: " << dlsym_error; @@ -56,7 +62,7 @@ FunctionUdfManager::create_f *FunctionUdfManager::getGraphFunctionClass(void *fu } FunctionUdfManager::destroy_f *FunctionUdfManager::deleteGraphFunctionClass(void *func_handle) { - auto *destroy_func = (destroy_f *)dlsym(func_handle, "destroy"); + auto *destroy_func = reinterpret_cast(dlsym(func_handle, "destroy")); dlsym_error = dlerror(); if (dlsym_error) { LOG(ERROR) << "Cannot load symbol destroy: " << dlsym_error; @@ -113,7 +119,6 @@ void FunctionUdfManager::initAndLoadSoFunction() { destroy_func(gf); dlclose(func_handle); - } catch (...) { LOG(ERROR) << "load So library Error: " << soPath; } @@ -196,4 +201,4 @@ void FunctionUdfManager::addSoUdfFunction( }; } -} // namespace nebula \ No newline at end of file +} // namespace nebula diff --git a/src/common/function/FunctionUdfManager.h b/src/common/function/FunctionUdfManager.h index 8b0228345cd..2cd1c8e6f2f 100644 --- a/src/common/function/FunctionUdfManager.h +++ b/src/common/function/FunctionUdfManager.h @@ -1,3 +1,8 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + #ifndef COMMON_FUNCTION_FUNCTIONUDFMANAGER_H_ #define COMMON_FUNCTION_FUNCTIONUDFMANAGER_H_ @@ -13,7 +18,8 @@ class FunctionUdfManager { typedef GraphFunction *(create_f)(); typedef void(destroy_f)(GraphFunction *); - static StatusOr getUdfReturnType(const std::string functionName, const std::vector &argsType); + static StatusOr getUdfReturnType(const std::string functionName, + const std::vector &argsType); static StatusOr loadUdfFunction( std::string functionName, size_t arity); @@ -28,7 +34,6 @@ class FunctionUdfManager { void addSoUdfFunction(char *funName, const char *soPath, size_t i, size_t i1, bool b); void initAndLoadSoFunction(); - }; } // namespace nebula diff --git a/src/common/function/GraphFunction.h b/src/common/function/GraphFunction.h index 97f2406c74a..4848b160846 100644 --- a/src/common/function/GraphFunction.h +++ b/src/common/function/GraphFunction.h @@ -1,7 +1,13 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + #ifndef COMMON_FUNCTION_GRAPHFUNCTION_H #define COMMON_FUNCTION_GRAPHFUNCTION_H #include + #include "common/datatypes/Value.h" class GraphFunction; @@ -10,23 +16,23 @@ extern "C" GraphFunction *create(); extern "C" void destroy(GraphFunction *function); class GraphFunction { -public: - virtual ~GraphFunction() = default; + public: + virtual ~GraphFunction() = default; - virtual char *name() = 0; + virtual char *name() = 0; - virtual std::vector> inputType() = 0; + virtual std::vector> inputType() = 0; - virtual nebula::Value::Type returnType() = 0; + virtual nebula::Value::Type returnType() = 0; - virtual size_t minArity() = 0; + virtual size_t minArity() = 0; - virtual size_t maxArity() = 0; + virtual size_t maxArity() = 0; - virtual bool isPure() = 0; + virtual bool isPure() = 0; - virtual nebula::Value body(const std::vector> &args) = 0; + virtual nebula::Value body( + const std::vector> &args) = 0; }; #endif // COMMON_FUNCTION_GRAPHFUNCTION_H - diff --git a/udf/standard_deviation.cpp b/udf/standard_deviation.cpp new file mode 100644 index 00000000000..7be5e157249 --- /dev/null +++ b/udf/standard_deviation.cpp @@ -0,0 +1,91 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + +#include "standard_deviation.h" + +#include +#include + +#include "../src/common/datatypes/List.h" + +extern "C" GraphFunction *create() { + return new standard_deviation; +} +extern "C" void destroy(GraphFunction *function) { + delete function; +} + +char *standard_deviation::name() { + const char *name = "standard_deviation"; + return const_cast(name); +} + +std::vector> standard_deviation::inputType() { + std::vector vtp = {nebula::Value::Type::LIST}; + std::vector> vvtp = {vtp}; + return vvtp; +} + +nebula::Value::Type standard_deviation::returnType() { + return nebula::Value::Type::FLOAT; +} + +size_t standard_deviation::minArity() { + return 1; +} + +size_t standard_deviation::maxArity() { + return 1; +} + +bool standard_deviation::isPure() { + return true; +} + +double caculate_standard_deviation(const std::vector &numbers) { + double sum = 0; + for (double number : numbers) { + sum += number; + } + double average = sum / numbers.size(); + + double variance = 0; + for (double number : numbers) { + double difference = number - average; + variance += difference * difference; + } + variance /= numbers.size(); + + return sqrt(variance); +} + +nebula::Value standard_deviation::body( + const std::vector> &args) { + switch (args[0].get().type()) { + case nebula::Value::Type::NULLVALUE: { + return nebula::Value::kNullValue; + } + case nebula::Value::Type::LIST: { + std::vector numbers; + auto list = args[0].get().getList(); + auto size = list.size(); + + for (int i = 0; i < size; i++) { + auto &value = list[i]; + if (value.isInt()) { + numbers.push_back(value.getInt()); + } else if (value.isFloat()) { + numbers.push_back(value.getFloat()); + } else { + return nebula::Value::kNullValue; + } + } + return nebula::Value(caculate_standard_deviation(numbers)); + } + default: { + return nebula::Value::kNullValue; + } + } +} diff --git a/udf/standard_deviation.h b/udf/standard_deviation.h new file mode 100644 index 00000000000..5a40607fb0e --- /dev/null +++ b/udf/standard_deviation.h @@ -0,0 +1,51 @@ +/* Copyright (c) 2020 vesoft inc. All rights reserved. + * + * This source code is licensed under Apache 2.0 License. + */ + +#ifndef UDF_PROJECT_STANDARD_DEVIATION_H +#define UDF_PROJECT_STANDARD_DEVIATION_H + +#include "../src/common/function/GraphFunction.h" + +// Example of a UDF function that calculates the standard deviation of a set of numbers. +// > YIELD standard_deviation([1,2,3]) +// +-----------------------------+ +// | standard_deviation([1,2,3]) | +// +-----------------------------+ +// | 0.816496580927726 | +// +-----------------------------+ + +// > YIELD standard_deviation([1,1,1]) +// +-----------------------------+ +// | standard_deviation([1,1,1]) | +// +-----------------------------+ +// | 0.0 | +// +-----------------------------+ + +// > GO 1 TO 2 STEPS FROM "player100" OVER follow YIELD properties(edge).degree AS d | yield collect($-.d) +// +--------------------------+ +// | collect($-.d) | +// +--------------------------+ +// | [95, 95, 95, 90, 95, 90] | +// +--------------------------+ + + +class standard_deviation : public GraphFunction { + public: + char *name() override; + + std::vector> inputType() override; + + nebula::Value::Type returnType() override; + + size_t minArity() override; + + size_t maxArity() override; + + bool isPure() override; + + nebula::Value body(const std::vector> &args) override; +}; + +#endif // UDF_PROJECT_STANDARD_DEVIATION_H