diff --git a/src/graph/optimizer/rule/OptimizeLeftJoinPredicateRule.cpp b/src/graph/optimizer/rule/OptimizeLeftJoinPredicateRule.cpp index 2b61a62867c..796e846a4c5 100644 --- a/src/graph/optimizer/rule/OptimizeLeftJoinPredicateRule.cpp +++ b/src/graph/optimizer/rule/OptimizeLeftJoinPredicateRule.cpp @@ -1,7 +1,6 @@ -/* Copyright (c) 2022 vesoft inc. All rights reserved. - * - * This source code is licensed under Apache 2.0 License. - */ +// Copyright (c) 2022 vesoft inc. All rights reserved. +// +// This source code is licensed under Apache 2.0 License. #include "graph/optimizer/rule/OptimizeLeftJoinPredicateRule.h" @@ -26,7 +25,7 @@ OptimizeLeftJoinPredicateRule::OptimizeLeftJoinPredicateRule() { const Pattern& OptimizeLeftJoinPredicateRule::pattern() const { static Pattern pattern = Pattern::create( - PlanNode::Kind::kBiLeftJoin, + PlanNode::Kind::kHashLeftJoin, {Pattern::create(PlanNode::Kind::kUnknown), Pattern::create(PlanNode::Kind::kProject, {Pattern::create(PlanNode::Kind::kAppendVertices, @@ -38,45 +37,35 @@ StatusOr OptimizeLeftJoinPredicateRule::transform( OptContext* octx, const MatchedResult& matched) const { auto* leftJoinGroupNode = matched.node; auto* leftJoinGroup = leftJoinGroupNode->group(); - auto* leftJoin = static_cast(leftJoinGroupNode->node()); + auto* leftJoin = static_cast(leftJoinGroupNode->node()); auto* projectGroupNode = matched.dependencies[1].node; - auto* projectGroup = projectGroupNode->group(); - UNUSED(projectGroup); - auto* project = static_cast(projectGroupNode->node()); - auto* appendVerticesGroup = matched.dependencies[1].dependencies[0].node->group(); - UNUSED(appendVerticesGroup); + auto* appendVerticesGroupNode = matched.dependencies[1].dependencies[0].node; auto appendVertices = static_cast(matched.dependencies[1].dependencies[0].node->node()); - auto* traverseGroup = matched.dependencies[1].dependencies[0].node->group(); - UNUSED(traverseGroup); auto traverse = static_cast( matched.dependencies[1].dependencies[0].dependencies[0].node->node()); - auto& avColNames = appendVertices->colNames(); - DCHECK_GE(avColNames.size(), 1); - auto& avNodeAlias = avColNames.back(); + auto& avNodeAlias = appendVertices->nodeAlias(); - auto& tvColNames = traverse->colNames(); - DCHECK_GE(tvColNames.size(), 1); - auto& tvEdgeAlias = traverse->colNames().back(); + auto& tvEdgeAlias = traverse->edgeAlias(); auto& hashKeys = leftJoin->hashKeys(); auto& probeKeys = leftJoin->probeKeys(); // Use visitor to collect all function `id` in the hashKeys - - std::vector hashKeyIdx; + bool found = false; + size_t hashKeyIdx; for (size_t i = 0; i < hashKeys.size(); ++i) { - auto* key = hashKeys[i]; - if (key->kind() != Expression::Kind::kFunctionCall) { + auto* hashKey = hashKeys[i]; + if (hashKey->kind() != Expression::Kind::kFunctionCall) { continue; } - auto* func = static_cast(key); - if (func->name() != "id" || func->name() != "_joinkey") { + auto* func = static_cast(hashKey); + if (func->name() != "id" && func->name() != "_joinkey") { continue; } auto& args = func->args()->args(); @@ -87,14 +76,22 @@ StatusOr OptimizeLeftJoinPredicateRule::transform( } auto& alias = static_cast(arg)->prop(); if (alias != avNodeAlias) continue; - // FIXME(jie): Must check if probe keys contain the same key - hashKeyIdx.emplace_back(i); + // Must check if probe keys contain the same key + if (*probeKeys[i] != *hashKey) { + return TransformResult::noTransform(); + } + if (found) { + return TransformResult::noTransform(); + } + hashKeyIdx = i; + found = true; } - if (hashKeyIdx.size() != 1) { + if (!found) { return TransformResult::noTransform(); } - std::vector prjIdx; + found = false; + size_t prjIdx; for (size_t i = 0; i < project->columns()->size(); ++i) { const auto* col = project->columns()->columns()[i]; if (col->expr()->kind() != Expression::Kind::kInputProperty) { @@ -102,51 +99,63 @@ StatusOr OptimizeLeftJoinPredicateRule::transform( } auto* inputProp = static_cast(col->expr()); if (inputProp->prop() != avNodeAlias) continue; - prjIdx.push_back(i); + if (found) { + return TransformResult::noTransform(); + } + prjIdx = i; + found = true; } - if (prjIdx.size() != 1) { + if (!found) { return TransformResult::noTransform(); } auto* pool = octx->qctx()->objPool(); - // Let the new project generate expr `none_direct_dst($-.tvEdgeAlias)`, and let the new left join - // use it as hash key + // Let the new project generate expr `none_direct_dst($-.tvEdgeAlias)`, + // and let the new left join use it as hash key auto* args = ArgumentList::make(pool); args->addArgument(InputPropertyExpression::make(pool, tvEdgeAlias)); auto* newPrjExpr = FunctionCallExpression::make(pool, "none_direct_dst", args); auto* newYieldColumns = pool->makeAndAdd(); for (size_t i = 0; i < project->columns()->size(); ++i) { - if (i == prjIdx[0]) { - newYieldColumns->addColumn(pool->makeAndAdd(newPrjExpr, newPrjExpr->toString())); + if (i == prjIdx) { + newYieldColumns->addColumn(new YieldColumn(newPrjExpr, avNodeAlias)); } else { - newYieldColumns->addColumn(project->columns()->columns()[i]); + newYieldColumns->addColumn(project->columns()->columns()[i]->clone().release()); } } auto* newProject = graph::Project::make(octx->qctx(), nullptr, newYieldColumns); - auto* newHashExpr = InputPropertyExpression::make(pool, newPrjExpr->toString()); + // $-.`avNodeAlias` + auto* newHashExpr = InputPropertyExpression::make(pool, avNodeAlias); std::vector newHashKeys; for (size_t i = 0; i < hashKeys.size(); ++i) { - if (i == hashKeyIdx[0]) { + if (i == hashKeyIdx) { newHashKeys.emplace_back(newHashExpr); } else { newHashKeys.emplace_back(hashKeys[i]); } } auto* newLeftJoin = - graph::BiLeftJoin::make(octx->qctx(), nullptr, nullptr, newHashKeys, probeKeys); + graph::HashLeftJoin::make(octx->qctx(), nullptr, nullptr, newHashKeys, probeKeys); TransformResult result; result.eraseAll = true; newProject->setInputVar(appendVertices->inputVar()); - newProject->setOutputVar(project->outputVar()); auto newProjectGroup = OptGroup::create(octx); auto* newProjectGroupNode = newProjectGroup->makeGroupNode(newProject); - newProjectGroupNode->setDeps(projectGroupNode->dependencies()); - - newLeftJoin->setDep(1, newProject); + newProjectGroupNode->setDeps(appendVerticesGroupNode->dependencies()); + + newLeftJoin->setLeftVar(leftJoin->leftInputVar()); + newLeftJoin->setRightVar(newProject->outputVar()); + newLeftJoin->setOutputVar(leftJoin->outputVar()); + // LOG the col names of newLeftJoin + auto& newLeftJoinColNames = newLeftJoin->colNames(); + LOG(ERROR) << "newLeftJoinColNames.size(): " << newLeftJoinColNames.size(); + for (auto& colName : newLeftJoinColNames) { + LOG(ERROR) << "colName: " << colName; + } auto* newLeftJoinGroupNode = OptGroupNode::create(octx, newLeftJoin, leftJoinGroup); newLeftJoinGroupNode->dependsOn(leftJoinGroupNode->dependencies()[0]); newLeftJoinGroupNode->dependsOn(newProjectGroup); diff --git a/src/graph/optimizer/rule/OptimizeLeftJoinPredicateRule.h b/src/graph/optimizer/rule/OptimizeLeftJoinPredicateRule.h index d075aefa7c1..bf2ef9495a0 100644 --- a/src/graph/optimizer/rule/OptimizeLeftJoinPredicateRule.h +++ b/src/graph/optimizer/rule/OptimizeLeftJoinPredicateRule.h @@ -10,23 +10,24 @@ namespace nebula { namespace opt { -/* -Before: - BiLeftJoin({id(v)}, id(v)) - / \ - ... Project - \ - AppendVertices(v) - \ - Traverse(e) - -After: - BiLeftJoin({id(v)}, none_direct_dst(e)) - / \ - ... Project - \ - Traverse(e) -*/ +// Before: +// HashLeftJoin({id(v)}, {id(v)}) +// / \ +// ... Project +// / \ +// AppendVertices(v) AppendVertices(v) +// / \ +// ... Traverse(e) +// +// After: +// HashLeftJoin({id(v)}, {$-.v}) +// / \ +// ... Project(..., none_direct_dst(e) AS v) +// / \ +// AppendVertices(v) Traverse(e) +// / +// ... +// class OptimizeLeftJoinPredicateRule final : public OptRule { public: const Pattern &pattern() const override; diff --git a/tests/tck/features/optimizer/OptimizeLeftJoinPredicateRule.feature b/tests/tck/features/optimizer/OptimizeLeftJoinPredicateRule.feature index 3386e1313d8..fac5dd3408e 100644 --- a/tests/tck/features/optimizer/OptimizeLeftJoinPredicateRule.feature +++ b/tests/tck/features/optimizer/OptimizeLeftJoinPredicateRule.feature @@ -18,47 +18,47 @@ Feature: Optimize left join predicate id(friendTeam) AS teamId, friendTeam.team.name AS teamName, numFriends - ORDER BY numFriends DESC - LIMIT 20 + ORDER BY teamName DESC """ - Then the result should be, in any order, with relax comparison: + Then the result should be, in order, with relax comparison: | teamId | teamName | numFriends | - | "Clippers" | "Clippers" | 0 | - | "Bulls" | "Bulls" | 0 | - | "Spurs" | "Spurs" | 0 | - | "Thunders" | "Thunders" | 0 | - | "Hornets" | "Hornets" | 0 | | "Warriors" | "Warriors" | 0 | - | "Hawks" | "Hawks" | 0 | - | "Kings" | "Kings" | 0 | - | "Magic" | "Magic" | 0 | | "Trail Blazers" | "Trail Blazers" | 0 | - | "Lakers" | "Lakers" | 0 | - | "Grizzlies" | "Grizzlies" | 0 | + | "Thunders" | "Thunders" | 0 | | "Suns" | "Suns" | 0 | + | "Spurs" | "Spurs" | 0 | | "Rockets" | "Rockets" | 0 | - | "Cavaliers" | "Cavaliers" | 0 | | "Raptors" | "Raptors" | 0 | + | "Pistons" | "Pistons" | 0 | + | "Magic" | "Magic" | 0 | + | "Lakers" | "Lakers" | 0 | + | "Kings" | "Kings" | 0 | + | "Jazz" | "Jazz" | 0 | + | "Hornets" | "Hornets" | 0 | + | "Heat" | "Heat" | 0 | + | "Hawks" | "Hawks" | 0 | + | "Grizzlies" | "Grizzlies" | 0 | + | "Clippers" | "Clippers" | 0 | | "Celtics" | "Celtics" | 0 | + | "Cavaliers" | "Cavaliers" | 0 | + | "Bulls" | "Bulls" | 0 | | "76ers" | "76ers" | 0 | - | "Heat" | "Heat" | 0 | - | "Jazz" | "Jazz" | 0 | And the execution plan should be: - | id | name | dependencies | profiling data | operator info | - | 21 | TopN | 18 | | | - | 18 | Project | 17 | | | - | 17 | Aggregate | 16 | | | - | 16 | BiLeftJoin | 10,15 | | | - | 10 | Dedup | 28 | | | - | 28 | Project | 22 | | | - | 22 | Filter | 26 | | | - | 26 | AppendVertices | 25 | | | - | 25 | Traverse | 24 | | | - | 24 | Traverse | 2 | | | - | 2 | Dedup | 1 | | | - | 1 | PassThrough | 3 | | | - | 3 | Start | | | | - | 15 | Project | 14 | | | - | 14 | Traverse | 12 | | | - | 12 | Traverse | 11 | | | - | 11 | Argument | | | | + | id | name | dependencies | operator info | + | 21 | Sort | 18 | | + | 18 | Project | 17 | | + | 17 | Aggregate | 16 | | + | 16 | HashLeftJoin | 10,15 | {"probeKeys": ["_joinkey($-.friendTeam)", "_joinkey($-.friend)"], "hashKeys": ["$-.friendTeam", "_joinkey($-.friend)"]} | + | 10 | Dedup | 28 | | + | 28 | Project | 22 | | + | 22 | Filter | 26 | | + | 26 | AppendVertices | 25 | | + | 25 | Traverse | 24 | | + | 24 | Traverse | 2 | | + | 2 | Dedup | 1 | | + | 1 | PassThrough | 3 | | + | 3 | Start | | | + | 15 | Project | 14 | {"columns": ["$-.friend AS friend", "$-.friend2 AS friend2", "none_direct_dst($-.__VAR_3) AS friendTeam"]} | + | 14 | Traverse | 12 | | + | 12 | Traverse | 11 | | + | 11 | Argument | | |