Skip to content

Commit

Permalink
example of udf build added
Browse files Browse the repository at this point in the history
Co-authored-by: Wey Gu <weyl.gu@gmail.com>
Co-authored-by: Cheng Xuntao <7731943+xtcyclist@users.noreply.github.com>
  • Loading branch information
wey-gu and xtcyclist committed Feb 13, 2023
1 parent 5b8a965 commit 2aa09d6
Show file tree
Hide file tree
Showing 7 changed files with 181 additions and 17 deletions.
6 changes: 6 additions & 0 deletions conf/nebula-graphd.conf.default
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,9 @@

# if use balance data feature, only work if enable_experimental_feature is true
--enable_data_balance=true

# enable udf, written in c++ only for now
--enable_udf=true

# set the directory where the .so files of udf are stored, when enable_udf is true
--udf_path=/home/nebula/dev/nebula/udf/
2 changes: 1 addition & 1 deletion src/common/function/FunctionManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

#include <boost/algorithm/string/replace.hpp>

#include "FunctionUdfManager.h"
#include "common/base/Base.h"
#include "common/datatypes/DataSet.h"
#include "common/datatypes/Edge.h"
Expand All @@ -28,7 +29,6 @@
#include "common/time/TimeUtils.h"
#include "common/time/WallClock.h"
#include "graph/service/GraphFlags.h"
#include "FunctionUdfManager.h"

DEFINE_bool(enable_udf, false, "enable udf");

Expand Down
13 changes: 9 additions & 4 deletions src/common/function/FunctionUdfManager.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
/* Copyright (c) 2020 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/

#include "FunctionUdfManager.h"

#include <dirent.h>
#include <dlfcn.h>

#include <cstring>
#include <iostream>

Expand Down Expand Up @@ -47,7 +53,7 @@ std::vector<std::string> getFilesList(const std::string &path, const char *ftype
}

FunctionUdfManager::create_f *FunctionUdfManager::getGraphFunctionClass(void *func_handle) {
auto *create_func = (create_f *)dlsym(func_handle, "create");
auto *create_func = reinterpret_cast<create_f *>(dlsym(func_handle, "create"));
dlsym_error = dlerror();
if (dlsym_error) {
LOG(ERROR) << "Cannot load symbol create: " << dlsym_error;
Expand All @@ -56,7 +62,7 @@ FunctionUdfManager::create_f *FunctionUdfManager::getGraphFunctionClass(void *fu
}

FunctionUdfManager::destroy_f *FunctionUdfManager::deleteGraphFunctionClass(void *func_handle) {
auto *destroy_func = (destroy_f *)dlsym(func_handle, "destroy");
auto *destroy_func = reinterpret_cast<destroy_f *>(dlsym(func_handle, "destroy"));
dlsym_error = dlerror();
if (dlsym_error) {
LOG(ERROR) << "Cannot load symbol destroy: " << dlsym_error;
Expand Down Expand Up @@ -113,7 +119,6 @@ void FunctionUdfManager::initAndLoadSoFunction() {

destroy_func(gf);
dlclose(func_handle);

} catch (...) {
LOG(ERROR) << "load So library Error: " << soPath;
}
Expand Down Expand Up @@ -196,4 +201,4 @@ void FunctionUdfManager::addSoUdfFunction(
};
}

} // namespace nebula
} // namespace nebula
9 changes: 7 additions & 2 deletions src/common/function/FunctionUdfManager.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
/* Copyright (c) 2020 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/

#ifndef COMMON_FUNCTION_FUNCTIONUDFMANAGER_H_
#define COMMON_FUNCTION_FUNCTIONUDFMANAGER_H_

Expand All @@ -13,7 +18,8 @@ class FunctionUdfManager {
typedef GraphFunction *(create_f)();
typedef void(destroy_f)(GraphFunction *);

static StatusOr<Value::Type> getUdfReturnType(const std::string functionName, const std::vector<Value::Type> &argsType);
static StatusOr<Value::Type> getUdfReturnType(const std::string functionName,
const std::vector<Value::Type> &argsType);

static StatusOr<const FunctionManager::FunctionAttributes> loadUdfFunction(
std::string functionName, size_t arity);
Expand All @@ -28,7 +34,6 @@ class FunctionUdfManager {

void addSoUdfFunction(char *funName, const char *soPath, size_t i, size_t i1, bool b);
void initAndLoadSoFunction();

};

} // namespace nebula
Expand Down
26 changes: 16 additions & 10 deletions src/common/function/GraphFunction.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
/* Copyright (c) 2020 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/

#ifndef COMMON_FUNCTION_GRAPHFUNCTION_H
#define COMMON_FUNCTION_GRAPHFUNCTION_H

#include <vector>

#include "common/datatypes/Value.h"

class GraphFunction;
Expand All @@ -10,23 +16,23 @@ extern "C" GraphFunction *create();
extern "C" void destroy(GraphFunction *function);

class GraphFunction {
public:
virtual ~GraphFunction() = default;
public:
virtual ~GraphFunction() = default;

virtual char *name() = 0;
virtual char *name() = 0;

virtual std::vector<std::vector<nebula::Value::Type>> inputType() = 0;
virtual std::vector<std::vector<nebula::Value::Type>> inputType() = 0;

virtual nebula::Value::Type returnType() = 0;
virtual nebula::Value::Type returnType() = 0;

virtual size_t minArity() = 0;
virtual size_t minArity() = 0;

virtual size_t maxArity() = 0;
virtual size_t maxArity() = 0;

virtual bool isPure() = 0;
virtual bool isPure() = 0;

virtual nebula::Value body(const std::vector<std::reference_wrapper<const nebula::Value>> &args) = 0;
virtual nebula::Value body(
const std::vector<std::reference_wrapper<const nebula::Value>> &args) = 0;
};

#endif // COMMON_FUNCTION_GRAPHFUNCTION_H

91 changes: 91 additions & 0 deletions udf/standard_deviation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/* Copyright (c) 2020 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/

#include "standard_deviation.h"

#include <cmath>
#include <vector>

#include "../src/common/datatypes/List.h"

extern "C" GraphFunction *create() {
return new standard_deviation;
}
extern "C" void destroy(GraphFunction *function) {
delete function;
}

char *standard_deviation::name() {
const char *name = "standard_deviation";
return const_cast<char *>(name);
}

std::vector<std::vector<nebula::Value::Type>> standard_deviation::inputType() {
std::vector<nebula::Value::Type> vtp = {nebula::Value::Type::LIST};
std::vector<std::vector<nebula::Value::Type>> vvtp = {vtp};
return vvtp;
}

nebula::Value::Type standard_deviation::returnType() {
return nebula::Value::Type::FLOAT;
}

size_t standard_deviation::minArity() {
return 1;
}

size_t standard_deviation::maxArity() {
return 1;
}

bool standard_deviation::isPure() {
return true;
}

double caculate_standard_deviation(const std::vector<double> &numbers) {
double sum = 0;
for (double number : numbers) {
sum += number;
}
double average = sum / numbers.size();

double variance = 0;
for (double number : numbers) {
double difference = number - average;
variance += difference * difference;
}
variance /= numbers.size();

return sqrt(variance);
}

nebula::Value standard_deviation::body(
const std::vector<std::reference_wrapper<const nebula::Value>> &args) {
switch (args[0].get().type()) {
case nebula::Value::Type::NULLVALUE: {
return nebula::Value::kNullValue;
}
case nebula::Value::Type::LIST: {
std::vector<double> numbers;
auto list = args[0].get().getList();
auto size = list.size();

for (int i = 0; i < size; i++) {
auto &value = list[i];
if (value.isInt()) {
numbers.push_back(value.getInt());
} else if (value.isFloat()) {
numbers.push_back(value.getFloat());
} else {
return nebula::Value::kNullValue;
}
}
return nebula::Value(caculate_standard_deviation(numbers));
}
default: {
return nebula::Value::kNullValue;
}
}
}
51 changes: 51 additions & 0 deletions udf/standard_deviation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/* Copyright (c) 2020 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/

#ifndef UDF_PROJECT_STANDARD_DEVIATION_H
#define UDF_PROJECT_STANDARD_DEVIATION_H

#include "../src/common/function/GraphFunction.h"

// Example of a UDF function that calculates the standard deviation of a set of numbers.
// > YIELD standard_deviation([1,2,3])
// +-----------------------------+
// | standard_deviation([1,2,3]) |
// +-----------------------------+
// | 0.816496580927726 |
// +-----------------------------+

// > YIELD standard_deviation([1,1,1])
// +-----------------------------+
// | standard_deviation([1,1,1]) |
// +-----------------------------+
// | 0.0 |
// +-----------------------------+

// > GO 1 TO 2 STEPS FROM "player100" OVER follow YIELD properties(edge).degree AS d | yield collect($-.d)
// +--------------------------+
// | collect($-.d) |
// +--------------------------+
// | [95, 95, 95, 90, 95, 90] |
// +--------------------------+


class standard_deviation : public GraphFunction {
public:
char *name() override;

std::vector<std::vector<nebula::Value::Type>> inputType() override;

nebula::Value::Type returnType() override;

size_t minArity() override;

size_t maxArity() override;

bool isPure() override;

nebula::Value body(const std::vector<std::reference_wrapper<const nebula::Value>> &args) override;
};

#endif // UDF_PROJECT_STANDARD_DEVIATION_H

0 comments on commit 2aa09d6

Please sign in to comment.