Skip to content

Commit

Permalink
support cypher parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
czpmango committed Sep 28, 2021
1 parent 56af984 commit c8dbfc8
Show file tree
Hide file tree
Showing 40 changed files with 318 additions and 10 deletions.
1 change: 1 addition & 0 deletions src/common/expression/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ nebula_add_library(
PredicateExpression.cpp
ListComprehensionExpression.cpp
ReduceExpression.cpp
ParameterExpression.cpp
)

nebula_add_subdirectory(test)
3 changes: 3 additions & 0 deletions src/common/expression/ExprVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "common/expression/LabelExpression.h"
#include "common/expression/ListComprehensionExpression.h"
#include "common/expression/LogicalExpression.h"
#include "common/expression/ParameterExpression.h"
#include "common/expression/PathBuildExpression.h"
#include "common/expression/PredicateExpression.h"
#include "common/expression/PropertyExpression.h"
Expand Down Expand Up @@ -89,6 +90,8 @@ class ExprVisitor {
virtual void visit(ReduceExpression *expr) = 0;
// subscript range expression
virtual void visit(SubscriptRangeExpression *expr) = 0;
// parameter expression
virtual void visit(ParameterExpression *expr) = 0;
};

} // namespace nebula
Expand Down
9 changes: 9 additions & 0 deletions src/common/expression/Expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "common/expression/LabelExpression.h"
#include "common/expression/ListComprehensionExpression.h"
#include "common/expression/LogicalExpression.h"
#include "common/expression/ParameterExpression.h"
#include "common/expression/PathBuildExpression.h"
#include "common/expression/PredicateExpression.h"
#include "common/expression/PropertyExpression.h"
Expand Down Expand Up @@ -497,6 +498,11 @@ Expression* Expression::decode(ObjectPool* pool, Expression::Decoder& decoder) {
exp->resetFrom(decoder);
return exp;
}
case Expression::Kind::kParam: {
exp = ParameterExpression::make(pool);
exp->resetFrom(decoder);
return exp;
}
case Expression::Kind::kTSPrefix:
case Expression::Kind::kTSWildcard:
case Expression::Kind::kTSRegexp:
Expand Down Expand Up @@ -719,6 +725,9 @@ std::ostream& operator<<(std::ostream& os, Expression::Kind kind) {
case Expression::Kind::kReduce:
os << "Reduce";
break;
case Expression::Kind::kParam:
os << "Parameter";
break;
}
return os;
}
Expand Down
1 change: 1 addition & 0 deletions src/common/expression/Expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ class Expression {
kIsNotEmpty,

kSubscriptRange,
kParam,
};

Expression(ObjectPool* pool, Kind kind);
Expand Down
30 changes: 30 additions & 0 deletions src/common/expression/ParameterExpression.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/* Copyright (c) 2021 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License,
* attached with Common Clause Condition 1.0, found in the LICENSES directory.
*/

#include "common/expression/ParameterExpression.h"

#include "common/expression/ExprVisitor.h"

namespace nebula {

const Value& ParameterExpression::eval(ExpressionContext& ectx) { return ectx.getVar(name_); }

std::string ParameterExpression::toString() const { return name_; }

bool ParameterExpression::operator==(const Expression& rhs) const {
return kind_ == rhs.kind() && name_ == rhs.toString();
}

void ParameterExpression::writeTo(Encoder& encoder) const {
encoder << kind_;
encoder << name_;
}

void ParameterExpression::resetFrom(Decoder& decoder) { name_ = decoder.readStr(); }

void ParameterExpression::accept(ExprVisitor* visitor) { visitor->visit(this); }

} // namespace nebula
47 changes: 47 additions & 0 deletions src/common/expression/ParameterExpression.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/* Copyright (c) 2021 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License,
* attached with Common Clause Condition 1.0, found in the LICENSES directory.
*/

#pragma once

#include "common/expression/Expression.h"

// The ParameterExpression use for parameterized statement
namespace nebula {

class ParameterExpression : public Expression {
public:
ParameterExpression& operator=(const ParameterExpression& rhs) = delete;
ParameterExpression& operator=(ParameterExpression&&) = delete;

static ParameterExpression* make(ObjectPool* pool, const std::string& name = "") {
return pool->add(new ParameterExpression(pool, name));
}

bool operator==(const Expression& rhs) const override;

const Value& eval(ExpressionContext& ctx) override;

const std::string& name() const { return name_; }

std::string toString() const override;

void accept(ExprVisitor* visitor) override;

Expression* clone() const override { return ParameterExpression::make(pool_, name()); }

protected:
explicit ParameterExpression(ObjectPool* pool, const std::string& name = "")
: Expression(pool, Kind::kParam), name_(name) {}

void writeTo(Encoder& encoder) const override;
void resetFrom(Decoder& decoder) override;

protected:
std::string name_;
Value result_;
};

} // namespace nebula
18 changes: 18 additions & 0 deletions src/common/expression/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,24 @@ nebula_add_test(
${THRIFT_LIBRARIES}
)

nebula_add_test(
NAME param_expression_test
SOURCES ParameterExpressionTest.cpp
OBJECTS
$<TARGET_OBJECTS:base_obj>
$<TARGET_OBJECTS:expression_obj>
$<TARGET_OBJECTS:datatypes_obj>
$<TARGET_OBJECTS:expr_ctx_mock_obj>
$<TARGET_OBJECTS:function_manager_obj>
$<TARGET_OBJECTS:agg_function_manager_obj>
$<TARGET_OBJECTS:time_obj>
$<TARGET_OBJECTS:time_utils_obj>
$<TARGET_OBJECTS:fs_obj>
LIBRARIES
gtest
${THRIFT_LIBRARIES}
)

nebula_add_test(
NAME list_comprehension_expression_test
SOURCES ListComprehensionExpressionTest.cpp
Expand Down
6 changes: 6 additions & 0 deletions src/common/expression/test/EncodeDecodeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,12 @@ TEST(ExpressionEncodeDecode, LabelExpression) {
ASSERT_EQ(*origin, *decoded);
}

TEST(ExpressionEncodeDecode, ParameterExpression) {
auto origin = ParameterExpression::make(&pool, "name");
auto decoded = Expression::decode(&pool, Expression::encode(*origin));
ASSERT_EQ(*origin, *decoded);
}

TEST(ExpressionEncodeDecode, CaseExpression) {
{
// CASE 23 WHEN 24 THEN 1 END
Expand Down
2 changes: 2 additions & 0 deletions src/common/expression/test/ExpressionContextMock.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ std::unordered_map<std::string, Value> ExpressionContextMock::vals_ = {
{"path_edge2", Value(Edge("2", "3", 1, "edge", 0, {}))},
{"path_v2", Value(Vertex("3", {}))},
{"path_edge3", Value(Edge("3", "4", 1, "edge", 0, {}))},
{"param1", Value(1)},
{"param2", Value(List(std::vector<Value>{1, 2, 3, 4, 5, 6, 7, 8}))},
};

Value ExpressionContextMock::getColumn(int32_t index) const {
Expand Down
31 changes: 31 additions & 0 deletions src/common/expression/test/ParameterExpressionTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/* Copyright (c) 2021 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License,
* attached with Common Clause Condition 1.0, found in the LICENSES directory.
*/
#include "common/expression/test/TestBase.h"

namespace nebula {

class ParameterExpressionTest : public ExpressionTest {};

TEST_F(ParameterExpressionTest, ParamExprToString) {
auto expr = ParameterExpression::make(&pool, "$param1");
ASSERT_EQ("$param1", expr->toString());
}

TEST_F(ParameterExpressionTest, ParamEvaluate) {
auto expr = ParameterExpression::make(&pool, "param1");
auto value = Expression::eval(expr, gExpCtxt);
ASSERT_TRUE(value.isInt());
ASSERT_EQ(1, value.getInt());
}
} // namespace nebula

int main(int argc, char **argv) {
testing::InitGoogleTest(&argc, argv);
folly::init(&argc, &argv, true);
google::SetStderrLogging(google::INFO);

return RUN_ALL_TESTS();
}
1 change: 1 addition & 0 deletions src/common/expression/test/TestBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "common/expression/LabelExpression.h"
#include "common/expression/ListComprehensionExpression.h"
#include "common/expression/LogicalExpression.h"
#include "common/expression/ParameterExpression.h"
#include "common/expression/PathBuildExpression.h"
#include "common/expression/PredicateExpression.h"
#include "common/expression/PropertyExpression.h"
Expand Down
6 changes: 6 additions & 0 deletions src/graph/context/QueryContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ void QueryContext::init() {
objPool_ = std::make_unique<ObjectPool>();
ep_ = std::make_unique<ExecutionPlan>();
ectx_ = std::make_unique<ExecutionContext>();
// copy parameterMap into ExecutionContext
if (rctx_) {
for (auto item : rctx_->parameterMap()) {
ectx_->setValue(std::move(item.first), std::move(item.second));
}
}
idGen_ = std::make_unique<IdGenerator>(0);
symTable_ = std::make_unique<SymbolTable>(objPool_.get());
vctx_ = std::make_unique<ValidateContext>(std::make_unique<AnonVarGenerator>(symTable_.get()));
Expand Down
2 changes: 2 additions & 0 deletions src/graph/context/QueryContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ class QueryContext {

bool isKilled() const { return killed_.load(); }

bool existParameter(const std::string& param) const { return ectx_->exist(param); }

private:
void init();

Expand Down
2 changes: 1 addition & 1 deletion src/graph/context/QueryExpressionContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class QueryExpressionContext final : public ExpressionContext {

void setVar(const std::string&, Value val) override;

QueryExpressionContext& operator()(Iterator* iter) {
QueryExpressionContext& operator()(Iterator* iter = nullptr) {
iter_ = iter;
return *this;
}
Expand Down
23 changes: 21 additions & 2 deletions src/graph/service/GraphService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,14 @@ void GraphService::signout(int64_t sessionId) {

folly::Future<ExecutionResponse> GraphService::future_execute(int64_t sessionId,
const std::string& query) {
std::unordered_map<std::string, Value> params;
return future_executeWithParameter(sessionId, query, std::move(params)).get();
}

folly::Future<ExecutionResponse> GraphService::future_executeWithParameter(
int64_t sessionId,
const std::string& query,
const std::unordered_map<std::string, Value>& parameterMap) {
auto ctx = std::make_unique<RequestContext<ExecutionResponse>>();
ctx->setQuery(query);
ctx->setRunner(getThreadManager());
Expand All @@ -129,7 +137,7 @@ folly::Future<ExecutionResponse> GraphService::future_execute(int64_t sessionId,
ctx->finish();
return future;
}
auto cb = [this, sessionId, ctx = std::move(ctx)](
auto cb = [this, sessionId, ctx = std::move(ctx), parameterMap = std::move(parameterMap)](
StatusOr<std::shared_ptr<ClientSession>> ret) mutable {
if (!ret.ok()) {
LOG(ERROR) << "Get session for sessionId: " << sessionId << " failed: " << ret.status();
Expand All @@ -147,6 +155,7 @@ folly::Future<ExecutionResponse> GraphService::future_execute(int64_t sessionId,
return ctx->finish();
}
ctx->setSession(std::move(sessionPtr));
ctx->setParameterMap(parameterMap);
queryEngine_->execute(std::move(ctx));
};
sessionManager_->findSession(sessionId, getThreadManager()).thenValue(std::move(cb));
Expand All @@ -155,7 +164,17 @@ folly::Future<ExecutionResponse> GraphService::future_execute(int64_t sessionId,

folly::Future<std::string> GraphService::future_executeJson(int64_t sessionId,
const std::string& query) {
auto rawResp = future_execute(sessionId, query).get();
std::unordered_map<std::string, Value> params;
auto rawResp = future_executeWithParameter(sessionId, query, std::move(params)).get();
auto respJsonObj = rawResp.toJson();
return folly::toJson(respJsonObj);
}

folly::Future<std::string> GraphService::future_executeJsonWithParameter(
int64_t sessionId,
const std::string& query,
const std::unordered_map<std::string, Value>& parameterMap) {
auto rawResp = future_executeWithParameter(sessionId, query, parameterMap).get();
auto respJsonObj = rawResp.toJson();
return folly::toJson(respJsonObj);
}
Expand Down
10 changes: 10 additions & 0 deletions src/graph/service/GraphService.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,19 @@ class GraphService final : public cpp2::GraphServiceSvIf {
folly::Future<ExecutionResponse> future_execute(int64_t sessionId,
const std::string& stmt) override;

folly::Future<ExecutionResponse> future_executeWithParameter(
int64_t sessionId,
const std::string& stmt,
const std::unordered_map<std::string, Value>& parameterMap) override;

folly::Future<std::string> future_executeJson(int64_t sessionId,
const std::string& stmt) override;

folly::Future<std::string> future_executeJsonWithParameter(
int64_t sessionId,
const std::string& stmt,
const std::unordered_map<std::string, Value>& parameterMap) override;

private:
bool auth(const std::string& username, const std::string& password);

Expand Down
7 changes: 7 additions & 0 deletions src/graph/service/RequestContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ class RequestContext final : public cpp::NonCopyable, public cpp::NonMovable {

GraphSessionManager* sessionMgr() const { return sessionMgr_; }

void setParameterMap(std::unordered_map<std::string, Value> parameterMap) {
parameterMap_ = std::move(parameterMap);
}

const std::unordered_map<std::string, Value>& parameterMap() const { return parameterMap_; }

private:
time::Duration duration_;
std::string query_;
Expand All @@ -76,6 +82,7 @@ class RequestContext final : public cpp::NonCopyable, public cpp::NonMovable {
std::shared_ptr<ClientSession> session_;
folly::Executor* runner_{nullptr};
GraphSessionManager* sessionMgr_{nullptr};
std::unordered_map<std::string, Value> parameterMap_;
};

} // namespace graph
Expand Down
1 change: 1 addition & 0 deletions src/graph/visitor/DeducePropsVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class DeducePropsVisitor : public ExprVisitorImpl {
void visit(VertexExpression* expr) override;
void visit(EdgeExpression* expr) override;
void visit(ColumnExpression* expr) override;
void visit(ParameterExpression*) override {}

void visitEdgePropExpr(PropertyExpression* expr);
void reportError(const Expression* expr);
Expand Down
5 changes: 5 additions & 0 deletions src/graph/visitor/DeduceTypeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,11 @@ void DeduceTypeVisitor::visit(ConstantExpression *expr) {
type_ = expr->eval(ctx(nullptr)).type();
}

void DeduceTypeVisitor::visit(ParameterExpression *expr) {
QueryExpressionContext ctx(qctx_->ectx());
type_ = expr->eval(ctx()).type();
}

void DeduceTypeVisitor::visit(UnaryExpression *expr) {
expr->operand()->accept(this);
if (!ok()) return;
Expand Down
2 changes: 2 additions & 0 deletions src/graph/visitor/DeduceTypeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ class DeduceTypeVisitor final : public ExprVisitor {
void visit(ReduceExpression *expr) override;
// subscript range
void visit(SubscriptRangeExpression *expr) override;
// parameter expression
void visit(ParameterExpression *expr) override;

void visitVertexPropertyExpr(PropertyExpression *expr);

Expand Down
5 changes: 5 additions & 0 deletions src/graph/visitor/EvaluableExprVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ class EvaluableExprVisitor : public ExprVisitorImpl {

void visit(VersionedVariableExpression *) override { isEvaluable_ = false; }

void visit(ParameterExpression *) override {
// TODO: ParameterExpression is evaluable but not foldable (czp)
isEvaluable_ = false;
}

void visit(TagPropertyExpression *) override { isEvaluable_ = false; }

void visit(EdgePropertyExpression *) override { isEvaluable_ = false; }
Expand Down
Loading

0 comments on commit c8dbfc8

Please sign in to comment.