Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
jievince committed Dec 6, 2022
1 parent 7070bbc commit 8629c57
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 95 deletions.
95 changes: 52 additions & 43 deletions src/graph/optimizer/rule/OptimizeLeftJoinPredicateRule.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
/* Copyright (c) 2022 vesoft inc. All rights reserved.
*
* This source code is licensed under Apache 2.0 License.
*/
// Copyright (c) 2022 vesoft inc. All rights reserved.
//
// This source code is licensed under Apache 2.0 License.

#include "graph/optimizer/rule/OptimizeLeftJoinPredicateRule.h"

Expand All @@ -26,7 +25,7 @@ OptimizeLeftJoinPredicateRule::OptimizeLeftJoinPredicateRule() {

const Pattern& OptimizeLeftJoinPredicateRule::pattern() const {
static Pattern pattern = Pattern::create(
PlanNode::Kind::kBiLeftJoin,
PlanNode::Kind::kHashLeftJoin,
{Pattern::create(PlanNode::Kind::kUnknown),
Pattern::create(PlanNode::Kind::kProject,
{Pattern::create(PlanNode::Kind::kAppendVertices,
Expand All @@ -38,45 +37,35 @@ StatusOr<OptRule::TransformResult> OptimizeLeftJoinPredicateRule::transform(
OptContext* octx, const MatchedResult& matched) const {
auto* leftJoinGroupNode = matched.node;
auto* leftJoinGroup = leftJoinGroupNode->group();
auto* leftJoin = static_cast<graph::BiLeftJoin*>(leftJoinGroupNode->node());
auto* leftJoin = static_cast<graph::HashLeftJoin*>(leftJoinGroupNode->node());

auto* projectGroupNode = matched.dependencies[1].node;
auto* projectGroup = projectGroupNode->group();
UNUSED(projectGroup);

auto* project = static_cast<graph::Project*>(projectGroupNode->node());

auto* appendVerticesGroup = matched.dependencies[1].dependencies[0].node->group();
UNUSED(appendVerticesGroup);
auto* appendVerticesGroupNode = matched.dependencies[1].dependencies[0].node;
auto appendVertices =
static_cast<graph::AppendVertices*>(matched.dependencies[1].dependencies[0].node->node());

auto* traverseGroup = matched.dependencies[1].dependencies[0].node->group();
UNUSED(traverseGroup);
auto traverse = static_cast<graph::Traverse*>(
matched.dependencies[1].dependencies[0].dependencies[0].node->node());

auto& avColNames = appendVertices->colNames();
DCHECK_GE(avColNames.size(), 1);
auto& avNodeAlias = avColNames.back();
auto& avNodeAlias = appendVertices->nodeAlias();

auto& tvColNames = traverse->colNames();
DCHECK_GE(tvColNames.size(), 1);
auto& tvEdgeAlias = traverse->colNames().back();
auto& tvEdgeAlias = traverse->edgeAlias();

auto& hashKeys = leftJoin->hashKeys();
auto& probeKeys = leftJoin->probeKeys();

// Use visitor to collect all function `id` in the hashKeys

std::vector<size_t> hashKeyIdx;
bool found = false;
size_t hashKeyIdx;
for (size_t i = 0; i < hashKeys.size(); ++i) {
auto* key = hashKeys[i];
if (key->kind() != Expression::Kind::kFunctionCall) {
auto* hashKey = hashKeys[i];
if (hashKey->kind() != Expression::Kind::kFunctionCall) {
continue;
}
auto* func = static_cast<FunctionCallExpression*>(key);
if (func->name() != "id" || func->name() != "_joinkey") {
auto* func = static_cast<FunctionCallExpression*>(hashKey);
if (func->name() != "id" && func->name() != "_joinkey") {
continue;
}
auto& args = func->args()->args();
Expand All @@ -87,66 +76,86 @@ StatusOr<OptRule::TransformResult> OptimizeLeftJoinPredicateRule::transform(
}
auto& alias = static_cast<InputPropertyExpression*>(arg)->prop();
if (alias != avNodeAlias) continue;
// FIXME(jie): Must check if probe keys contain the same key
hashKeyIdx.emplace_back(i);
// Must check if probe keys contain the same key
if (*probeKeys[i] != *hashKey) {
return TransformResult::noTransform();
}
if (found) {
return TransformResult::noTransform();
}
hashKeyIdx = i;
found = true;
}
if (hashKeyIdx.size() != 1) {
if (!found) {
return TransformResult::noTransform();
}

std::vector<size_t> prjIdx;
found = false;
size_t prjIdx;
for (size_t i = 0; i < project->columns()->size(); ++i) {
const auto* col = project->columns()->columns()[i];
if (col->expr()->kind() != Expression::Kind::kInputProperty) {
continue;
}
auto* inputProp = static_cast<InputPropertyExpression*>(col->expr());
if (inputProp->prop() != avNodeAlias) continue;
prjIdx.push_back(i);
if (found) {
return TransformResult::noTransform();
}
prjIdx = i;
found = true;
}
if (prjIdx.size() != 1) {
if (!found) {
return TransformResult::noTransform();
}

auto* pool = octx->qctx()->objPool();
// Let the new project generate expr `none_direct_dst($-.tvEdgeAlias)`, and let the new left join
// use it as hash key
// Let the new project generate expr `none_direct_dst($-.tvEdgeAlias)`,
// and let the new left join use it as hash key
auto* args = ArgumentList::make(pool);
args->addArgument(InputPropertyExpression::make(pool, tvEdgeAlias));
auto* newPrjExpr = FunctionCallExpression::make(pool, "none_direct_dst", args);

auto* newYieldColumns = pool->makeAndAdd<YieldColumns>();
for (size_t i = 0; i < project->columns()->size(); ++i) {
if (i == prjIdx[0]) {
newYieldColumns->addColumn(pool->makeAndAdd<YieldColumn>(newPrjExpr, newPrjExpr->toString()));
if (i == prjIdx) {
newYieldColumns->addColumn(new YieldColumn(newPrjExpr, avNodeAlias));
} else {
newYieldColumns->addColumn(project->columns()->columns()[i]);
newYieldColumns->addColumn(project->columns()->columns()[i]->clone().release());
}
}
auto* newProject = graph::Project::make(octx->qctx(), nullptr, newYieldColumns);

auto* newHashExpr = InputPropertyExpression::make(pool, newPrjExpr->toString());
// $-.`avNodeAlias`
auto* newHashExpr = InputPropertyExpression::make(pool, avNodeAlias);
std::vector<Expression*> newHashKeys;
for (size_t i = 0; i < hashKeys.size(); ++i) {
if (i == hashKeyIdx[0]) {
if (i == hashKeyIdx) {
newHashKeys.emplace_back(newHashExpr);
} else {
newHashKeys.emplace_back(hashKeys[i]);
}
}
auto* newLeftJoin =
graph::BiLeftJoin::make(octx->qctx(), nullptr, nullptr, newHashKeys, probeKeys);
graph::HashLeftJoin::make(octx->qctx(), nullptr, nullptr, newHashKeys, probeKeys);

TransformResult result;
result.eraseAll = true;

newProject->setInputVar(appendVertices->inputVar());
newProject->setOutputVar(project->outputVar());
auto newProjectGroup = OptGroup::create(octx);
auto* newProjectGroupNode = newProjectGroup->makeGroupNode(newProject);
newProjectGroupNode->setDeps(projectGroupNode->dependencies());

newLeftJoin->setDep(1, newProject);
newProjectGroupNode->setDeps(appendVerticesGroupNode->dependencies());

newLeftJoin->setLeftVar(leftJoin->leftInputVar());
newLeftJoin->setRightVar(newProject->outputVar());
newLeftJoin->setOutputVar(leftJoin->outputVar());
// LOG the col names of newLeftJoin
auto& newLeftJoinColNames = newLeftJoin->colNames();
LOG(ERROR) << "newLeftJoinColNames.size(): " << newLeftJoinColNames.size();
for (auto& colName : newLeftJoinColNames) {
LOG(ERROR) << "colName: " << colName;
}
auto* newLeftJoinGroupNode = OptGroupNode::create(octx, newLeftJoin, leftJoinGroup);
newLeftJoinGroupNode->dependsOn(leftJoinGroupNode->dependencies()[0]);
newLeftJoinGroupNode->dependsOn(newProjectGroup);
Expand Down
36 changes: 18 additions & 18 deletions src/graph/optimizer/rule/OptimizeLeftJoinPredicateRule.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,24 @@

namespace nebula {
namespace opt {

/*
Before:
BiLeftJoin({id(v)}, id(v))
/ \
... Project
\
AppendVertices(v)
\
Traverse(e)
After:
BiLeftJoin({id(v)}, none_direct_dst(e))
/ \
... Project
\
Traverse(e)
*/
/* Before:
* HashLeftJoin({id(v)}, {id(v)})
* / \
* ... Project
* / \
* AppendVertices(v) AppendVertices(v)
* / \
* ... Traverse(e)
*
* After:
* HashLeftJoin({id(v)}, {$-.v})
* / \
* ... Project(..., none_direct_dst(e) AS v)
* / \
* AppendVertices(v) Traverse(e)
* /
* ...
*/
class OptimizeLeftJoinPredicateRule final : public OptRule {
public:
const Pattern &pattern() const override;
Expand Down
68 changes: 34 additions & 34 deletions tests/tck/features/optimizer/OptimizeLeftJoinPredicateRule.feature
Original file line number Diff line number Diff line change
Expand Up @@ -18,47 +18,47 @@ Feature: Optimize left join predicate
id(friendTeam) AS teamId,
friendTeam.team.name AS teamName,
numFriends
ORDER BY numFriends DESC
LIMIT 20
ORDER BY teamName DESC
"""
Then the result should be, in any order, with relax comparison:
Then the result should be, in order, with relax comparison:
| teamId | teamName | numFriends |
| "Clippers" | "Clippers" | 0 |
| "Bulls" | "Bulls" | 0 |
| "Spurs" | "Spurs" | 0 |
| "Thunders" | "Thunders" | 0 |
| "Hornets" | "Hornets" | 0 |
| "Warriors" | "Warriors" | 0 |
| "Hawks" | "Hawks" | 0 |
| "Kings" | "Kings" | 0 |
| "Magic" | "Magic" | 0 |
| "Trail Blazers" | "Trail Blazers" | 0 |
| "Lakers" | "Lakers" | 0 |
| "Grizzlies" | "Grizzlies" | 0 |
| "Thunders" | "Thunders" | 0 |
| "Suns" | "Suns" | 0 |
| "Spurs" | "Spurs" | 0 |
| "Rockets" | "Rockets" | 0 |
| "Cavaliers" | "Cavaliers" | 0 |
| "Raptors" | "Raptors" | 0 |
| "Pistons" | "Pistons" | 0 |
| "Magic" | "Magic" | 0 |
| "Lakers" | "Lakers" | 0 |
| "Kings" | "Kings" | 0 |
| "Jazz" | "Jazz" | 0 |
| "Hornets" | "Hornets" | 0 |
| "Heat" | "Heat" | 0 |
| "Hawks" | "Hawks" | 0 |
| "Grizzlies" | "Grizzlies" | 0 |
| "Clippers" | "Clippers" | 0 |
| "Celtics" | "Celtics" | 0 |
| "Cavaliers" | "Cavaliers" | 0 |
| "Bulls" | "Bulls" | 0 |
| "76ers" | "76ers" | 0 |
| "Heat" | "Heat" | 0 |
| "Jazz" | "Jazz" | 0 |
And the execution plan should be:
| id | name | dependencies | profiling data | operator info |
| 21 | TopN | 18 | | |
| 18 | Project | 17 | | |
| 17 | Aggregate | 16 | | |
| 16 | BiLeftJoin | 10,15 | | |
| 10 | Dedup | 28 | | |
| 28 | Project | 22 | | |
| 22 | Filter | 26 | | |
| 26 | AppendVertices | 25 | | |
| 25 | Traverse | 24 | | |
| 24 | Traverse | 2 | | |
| 2 | Dedup | 1 | | |
| 1 | PassThrough | 3 | | |
| 3 | Start | | | |
| 15 | Project | 14 | | |
| 14 | Traverse | 12 | | |
| 12 | Traverse | 11 | | |
| 11 | Argument | | | |
| id | name | dependencies | operator info |
| 21 | Sort | 18 | |
| 18 | Project | 17 | |
| 17 | Aggregate | 16 | |
| 16 | HashLeftJoin | 10,15 | {"probeKeys": ["_joinkey($-.friendTeam)", "_joinkey($-.friend)"], "hashKeys": ["$-.friendTeam", "_joinkey($-.friend)"]} |
| 10 | Dedup | 28 | |
| 28 | Project | 22 | |
| 22 | Filter | 26 | |
| 26 | AppendVertices | 25 | |
| 25 | Traverse | 24 | |
| 24 | Traverse | 2 | |
| 2 | Dedup | 1 | |
| 1 | PassThrough | 3 | |
| 3 | Start | | |
| 15 | Project | 14 | {"columns": ["$-.friend AS friend", "$-.friend2 AS friend2", "none_direct_dst($-.__VAR_3) AS friendTeam"]} |
| 14 | Traverse | 12 | |
| 12 | Traverse | 11 | |
| 11 | Argument | | |

0 comments on commit 8629c57

Please sign in to comment.