diff --git a/src/planner/ngql/GoPlanner.cpp b/src/planner/ngql/GoPlanner.cpp index 2276f606f..1b3a63eef 100644 --- a/src/planner/ngql/GoPlanner.cpp +++ b/src/planner/ngql/GoPlanner.cpp @@ -224,14 +224,17 @@ PlanNode* GoPlanner::trackStartVid(PlanNode* left, PlanNode* right) { * | * Dep */ -PlanNode* GoPlanner::buildJoinDstPlan(PlanNode* dep, PlanNode* left) { +PlanNode* GoPlanner::buildJoinDstPlan(PlanNode* left) { auto qctx = goCtx_->qctx; auto* pool = qctx->objPool(); + auto start = StartNode::make(qctx); + start->setOutputVar(left->outputVar()); + start->setInputVar(left->outputVar()); // dst is the last column, columnName is "JOIN_DST_VID" auto* dstExpr = ColumnExpression::make(pool, LAST_COL_INDEX); auto* getVertex = GetVertices::make(qctx, - dep, + start, goCtx_->space.id, dstExpr, buildVertexProps(goCtx_->exprProps.dstTagProps()), @@ -258,7 +261,7 @@ PlanNode* GoPlanner::buildJoinDstPlan(PlanNode* dep, PlanNode* left) { VLOG(1) << join->outputVar() << " hasKey: " << hashKey->toString() << " probeKey: " << probeKey->toString(); - std::vector colNames = dep->colNames(); + std::vector colNames = left->colNames(); colNames.insert(colNames.end(), project->colNames().begin(), project->colNames().end()); join->setColNames(std::move(colNames)); @@ -280,6 +283,7 @@ PlanNode* GoPlanner::buildJoinInputPlan(PlanNode* left) { auto right = StartNode::make(qctx); right->setOutputVar(probeName); + right->setInputVar(probeName); right->setColNames(varPtr->colNames); auto* join = InnerJoin::make(qctx, {left, ExecutionContext::kLatestVersion}, @@ -328,12 +332,13 @@ PlanNode* GoPlanner::buildLastStepJoinPlan(PlanNode* gn, PlanNode* join) { } auto* dep = extractSrcEdgePropsFromGN(gn, gn->outputVar()); - dep = goCtx_->joinDst ? buildJoinDstPlan(dep, dep) : dep; + dep = goCtx_->joinDst ? buildJoinDstPlan(dep) : dep; PlanNode* left = nullptr; if (goCtx_->joinInput && join != nullptr) { left = StartNode::make(goCtx_->qctx); left->setOutputVar(join->outputVar()); + left->setInputVar(join->outputVar()); left->setColNames(join->colNames()); } dep = goCtx_->joinInput ? lastStepJoinInput(left, dep) : dep; @@ -374,7 +379,7 @@ PlanNode* GoPlanner::buildOneStepJoinPlan(PlanNode* gn) { } auto* dep = extractSrcEdgePropsFromGN(gn, gn->outputVar()); - dep = goCtx_->joinDst ? buildJoinDstPlan(dep, dep) : dep; + dep = goCtx_->joinDst ? buildJoinDstPlan(dep) : dep; dep = goCtx_->joinInput ? buildJoinInputPlan(dep) : dep; return dep; @@ -463,6 +468,7 @@ SubPlan GoPlanner::mToNStepsPlan(SubPlan& startVidPlan) { auto left = StartNode::make(qctx); left->setOutputVar(joinLeft->outputVar()); + left->setInputVar(joinLeft->outputVar()); left->setColNames(joinLeft->colNames()); trackVid = trackStartVid(left, joinRight); loopBody = trackVid; @@ -471,7 +477,7 @@ SubPlan GoPlanner::mToNStepsPlan(SubPlan& startVidPlan) { if (joinInput || joinDst) { loopBody = extractSrcEdgePropsFromGN(loopBody, gn->outputVar()); - loopBody = joinDst ? buildJoinDstPlan(loopBody, loopBody) : loopBody; + loopBody = joinDst ? buildJoinDstPlan(loopBody) : loopBody; loopBody = joinInput ? lastStepJoinInput(trackVid, loopBody) : loopBody; loopBody = joinInput ? buildJoinInputPlan(loopBody) : loopBody; } diff --git a/src/planner/ngql/GoPlanner.h b/src/planner/ngql/GoPlanner.h index 81dc2f8ec..47aaaba02 100644 --- a/src/planner/ngql/GoPlanner.h +++ b/src/planner/ngql/GoPlanner.h @@ -57,7 +57,7 @@ class GoPlanner final : public Planner { PlanNode* trackStartVid(PlanNode* left, PlanNode* right); - PlanNode* buildJoinDstPlan(PlanNode* dep, PlanNode* left); + PlanNode* buildJoinDstPlan(PlanNode* left); PlanNode* buildJoinInputPlan(PlanNode* dep); diff --git a/src/planner/plan/Logic.h b/src/planner/plan/Logic.h index 5dfb5e4d7..9b6860676 100644 --- a/src/planner/plan/Logic.h +++ b/src/planner/plan/Logic.h @@ -22,7 +22,9 @@ class StartNode final : public PlanNode { private: explicit StartNode(QueryContext* qctx) - : PlanNode(qctx, Kind::kStart) {} + : PlanNode(qctx, Kind::kStart) { + inputVars_.emplace_back(nullptr); + } void cloneMembers(const StartNode&); }; diff --git a/src/scheduler/AsyncMsgNotifyBasedScheduler.cpp b/src/scheduler/AsyncMsgNotifyBasedScheduler.cpp index 46514317b..fa71abb3b 100644 --- a/src/scheduler/AsyncMsgNotifyBasedScheduler.cpp +++ b/src/scheduler/AsyncMsgNotifyBasedScheduler.cpp @@ -44,17 +44,44 @@ folly::Future AsyncMsgNotifyBasedScheduler::doSchedule(Executor* root) c queue2.push(exe); std::vector receivers; - for (auto* dep : exe->depends()) { - auto notVisited = visited.emplace(dep).second; - if (notVisited) { - queue.push(dep); + if (exe->node()->kind() == PlanNode::Kind::kStart) { + // if the leaf node bypass a var, we should check the implicit dependencies. + auto nodeOutputVar = exe->node()->outputVar(); + const auto& writtenBy = qctx_->symTable()->getVar(nodeOutputVar)->writtenBy; + auto refCount = qctx_->symTable()->getVar(nodeOutputVar)->userCount.load(); + VLOG(1) << "var: " << nodeOutputVar + << "refCount: " << refCount + << "writtenBy: " << writtenBy.size() + << " if Exist this node: " + << (writtenBy.find(const_cast(exe->node())) != writtenBy.end()); + if (writtenBy.size() == 2 && + writtenBy.find(const_cast(exe->node())) != writtenBy.end()) { + for (auto& node : writtenBy) { + if (exe->node() == node) { + continue; + } + VLOG(1) << "register notifier to: " << node->id(); + Notifier p; + receivers.emplace_back(p.getFuture()); + auto& notifiers = notifierMap[node->id()]; + notifiers.emplace_back(std::move(p)); + } + } + } else { + for (auto* dep : exe->depends()) { + auto notVisited = visited.emplace(dep).second; + if (notVisited) { + queue.push(dep); + } + Notifier p; + receivers.emplace_back(p.getFuture()); + auto& notifiers = notifierMap[dep->id()]; + notifiers.emplace_back(std::move(p)); } - Notifier p; - receivers.emplace_back(p.getFuture()); - auto& notifiers = notifierMap[dep->id()]; - notifiers.emplace_back(std::move(p)); } - receiverMap.emplace(exe->id(), std::move(receivers)); + auto& receiversHist = receiverMap[exe->id()]; + receiversHist.insert(receiversHist.end(), std::make_move_iterator(receivers.begin()), + std::make_move_iterator(receivers.end())); } while (!queue2.empty()) { @@ -93,7 +120,10 @@ void AsyncMsgNotifyBasedScheduler::scheduleExecutor( break; } default: { - if (exe->depends().empty()) { + VLOG(1) << "node: " << exe->node()->kind() + << "exe: " << exe->node()->outputVar() + << " receivers: " << receivers.size(); + if (exe->depends().empty() && receivers.empty()) { runLeafExecutor(exe, runner, std::move(notifiers)); } else { runExecutor(std::move(receivers), exe, runner, std::move(notifiers)); diff --git a/tests/tck/features/optimizer/PushFilterDownLeftJoinRule.feature b/tests/tck/features/optimizer/PushFilterDownLeftJoinRule.feature index 356278699..9426c522a 100644 --- a/tests/tck/features/optimizer/PushFilterDownLeftJoinRule.feature +++ b/tests/tck/features/optimizer/PushFilterDownLeftJoinRule.feature @@ -2,6 +2,7 @@ # # This source code is licensed under Apache 2.0 License, # attached with Common Clause Condition 1.0, found in the LICENSES directory. +@push_down_join Feature: Push Filter down LeftJoin rule Background: @@ -61,10 +62,9 @@ Feature: Push Filter down LeftJoin rule | 7 | Project | 10 | | | 10 | Filter | 9 | | | 9 | LeftJoin | 12,4 | | - | 12 | Project | 11 | | - | 11 | Filter | 1 | | - | 1 | GetNeighbors | 0 | | + | 12 | Project | 13 | | + | 13 | GetNeighbors | 0 | | | 0 | Start | | | | 4 | Project | 3 | | | 3 | GetVertices | 2 | | - | 2 | Project | 1 | | + | 2 | Project | 13 | |