From 962104aa44ef21daf099b5579077002bb5ad9db1 Mon Sep 17 00:00:00 2001 From: Rong Ma Date: Wed, 11 May 2022 15:03:45 +0800 Subject: [PATCH] [OPPRO-10] Enable hash join in Substrait-to-Velox conversion (#9) * hash join * remove extra projection --- velox/substrait/SubstraitToVeloxPlan.cpp | 154 ++++++++++++++++-- velox/substrait/SubstraitToVeloxPlan.h | 17 +- .../SubstraitToVeloxPlanValidator.cpp | 72 ++++++++ .../substrait/SubstraitToVeloxPlanValidator.h | 3 + 4 files changed, 229 insertions(+), 17 deletions(-) diff --git a/velox/substrait/SubstraitToVeloxPlan.cpp b/velox/substrait/SubstraitToVeloxPlan.cpp index db98a984fd01..5515a66a4526 100644 --- a/velox/substrait/SubstraitToVeloxPlan.cpp +++ b/velox/substrait/SubstraitToVeloxPlan.cpp @@ -110,9 +110,97 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::processEmit( } core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( - const ::substrait::AggregateRel& aggRel) { + const ::substrait::JoinRel& sJoin) { + if (!sJoin.has_left()) { + VELOX_FAIL("Left Rel is expected in JoinRel."); + } + if (!sJoin.has_right()) { + VELOX_FAIL("Right Rel is expected in JoinRel."); + } + + auto leftNode = toVeloxPlan(sJoin.left()); + auto rightNode = toVeloxPlan(sJoin.right()); + + auto outputSize = + leftNode->outputType()->size() + rightNode->outputType()->size(); + std::vector outputNames; + std::vector> outputTypes; + outputNames.reserve(outputSize); + outputTypes.reserve(outputSize); + for (const auto& node : {leftNode, rightNode}) { + const auto& names = node->outputType()->names(); + outputNames.insert(outputNames.end(), names.begin(), names.end()); + const auto& types = node->outputType()->children(); + outputTypes.insert(outputTypes.end(), types.begin(), types.end()); + } + auto outputRowType = std::make_shared( + std::move(outputNames), std::move(outputTypes)); + + // extract join keys from join expression + std::vector leftExprs, + rightExprs; + extractJoinKeys(sJoin.expression(), leftExprs, rightExprs); + 
VELOX_CHECK_EQ(leftExprs.size(), rightExprs.size()); + size_t numKeys = leftExprs.size(); + + std::vector> leftKeys, + rightKeys; + leftKeys.reserve(numKeys); + rightKeys.reserve(numKeys); + for (size_t i = 0; i < numKeys; ++i) { + leftKeys.emplace_back( + exprConverter_->toVeloxExpr(*leftExprs[i], outputRowType)); + rightKeys.emplace_back( + exprConverter_->toVeloxExpr(*rightExprs[i], outputRowType)); + } + + std::shared_ptr filter; + if (sJoin.has_post_join_filter()) { + filter = + exprConverter_->toVeloxExpr(sJoin.post_join_filter(), outputRowType); + } + + // Map join type + core::JoinType joinType; + switch (sJoin.type()) { + case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_INNER: + joinType = core::JoinType::kInner; + break; + case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_OUTER: + joinType = core::JoinType::kFull; + break; + case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_LEFT: + joinType = core::JoinType::kLeft; + break; + case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_RIGHT: + joinType = core::JoinType::kRight; + break; + case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_SEMI: + joinType = core::JoinType::kLeftSemi; + break; + case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_ANTI: + joinType = core::JoinType::kNullAwareAnti; + break; + default: + VELOX_NYI("Unsupported Join type: {}", sJoin.type()); + } + + // Create join node + return std::make_shared( + nextPlanNodeId(), + joinType, + leftKeys, + rightKeys, + filter, + leftNode, + rightNode, + outputRowType); +} + +core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( + const ::substrait::AggregateRel& sAgg) { auto childNode = convertSingleInput<::substrait::AggregateRel>(aggRel); - core::AggregationNode::Step aggStep = toAggregationStep(aggRel); + core::AggregationNode::Step aggStep = toAggregationStep(sAgg); return toVeloxAgg(sAgg, childNode, aggStep); } @@ -583,22 +671,24 @@ core::PlanNodePtr 
SubstraitVeloxPlanConverter::toVeloxPlan( } core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan( - const ::substrait::Rel& rel) { - if (rel.has_aggregate()) { - return toVeloxPlan(rel.aggregate()); + const ::substrait::Rel& sRel) { + if (sRel.has_aggregate()) { + return toVeloxPlan(sRel.aggregate()); } - if (rel.has_project()) { - return toVeloxPlan(rel.project()); + if (sRel.has_project()) { + return toVeloxPlan(sRel.project()); } - if (rel.has_filter()) { - return toVeloxPlan(rel.filter()); + if (sRel.has_filter()) { + return toVeloxPlan(sRel.filter()); } - if (rel.has_read()) { - auto splitInfo = std::make_shared(); - - auto planNode = toVeloxPlan(rel.read(), splitInfo); - splitInfoMap_[planNode->id()] = splitInfo; - return planNode; + if (sRel.has_join()) { + return toVeloxPlan(sRel.join()); + } + if (sRel.has_read()) { + return toVeloxPlan(sRel.read()); + } + if (sRel.has_sort()) { + return toVeloxPlan(sRel.sort()); } if (rel.has_fetch()) { return toVeloxPlan(rel.fetch()); @@ -935,4 +1025,38 @@ const std::string& SubstraitVeloxPlanConverter::findFunction( return substraitParser_->findFunctionSpec(functionMap_, id); } +void SubstraitVeloxPlanConverter::extractJoinKeys( + const ::substrait::Expression& joinExpression, + std::vector& leftExprs, + std::vector& rightExprs) { + std::vector expressions; + expressions.push_back(&joinExpression); + while (!expressions.empty()) { + auto visited = expressions.back(); + expressions.pop_back(); + if (visited->rex_type_case() == + ::substrait::Expression::RexTypeCase::kScalarFunction) { + const auto& funcName = + subParser_->getSubFunctionName(subParser_->findVeloxFunction( + functionMap_, visited->scalar_function().function_reference())); + const auto& args = visited->scalar_function().args(); + if (funcName == "and") { + expressions.push_back(&args[0]); + expressions.push_back(&args[1]); + } else if (funcName == "equal") { + VELOX_CHECK(std::all_of( + args.cbegin(), args.cend(), [](const ::substrait::Expression& 
arg) { + return arg.has_selection(); + })); + leftExprs.push_back(&args[0].selection()); + rightExprs.push_back(&args[1].selection()); + } + } else { + VELOX_FAIL( + "Unable to parse from join expression: {}", + joinExpression.DebugString()); + } + } +} + } // namespace facebook::velox::substrait diff --git a/velox/substrait/SubstraitToVeloxPlan.h b/velox/substrait/SubstraitToVeloxPlan.h index 29c4bebcb1d8..47065a13c462 100644 --- a/velox/substrait/SubstraitToVeloxPlan.h +++ b/velox/substrait/SubstraitToVeloxPlan.h @@ -44,8 +44,11 @@ class SubstraitVeloxPlanConverter { dwio::common::FileFormat format; }; - /// Convert Substrait AggregateRel into Velox PlanNode. - core::PlanNodePtr toVeloxPlan(const ::substrait::AggregateRel& aggRel); + /// Used to convert Substrait JoinRel into Velox PlanNode. + core::PlanNodePtr toVeloxPlan(const ::substrait::JoinRel& sJoin); + + /// Used to convert Substrait AggregateRel into Velox PlanNode. + core::PlanNodePtr toVeloxPlan(const ::substrait::AggregateRel& sAgg); /// Convert Substrait ProjectRel into Velox PlanNode. core::PlanNodePtr toVeloxPlan(const ::substrait::ProjectRel& projectRel); @@ -163,6 +166,16 @@ class SubstraitVeloxPlanConverter { /// Used to find the function specification in the constructed function map. std::string findFuncSpec(uint64_t id); + /// Extract join keys from joinExpression. + /// joinExpression is a boolean condition that describes whether each record + /// from the left set “matches” the record from the right set. The condition + /// must only include the following operations: AND, ==, field references. + /// Field references correspond to the direct output order of the data. + void extractJoinKeys( + const ::substrait::Expression& joinExpression, + std::vector& leftExprs, + std::vector& rightExprs); + private: /// The Partition index. 
u_int32_t partitionIndex_; diff --git a/velox/substrait/SubstraitToVeloxPlanValidator.cpp b/velox/substrait/SubstraitToVeloxPlanValidator.cpp index 977b43a88402..cc412523fa95 100644 --- a/velox/substrait/SubstraitToVeloxPlanValidator.cpp +++ b/velox/substrait/SubstraitToVeloxPlanValidator.cpp @@ -115,6 +115,75 @@ bool SubstraitToVeloxPlanValidator::validate( return false; } +bool SubstraitToVeloxPlanValidator::validate( + const ::substrait::JoinRel& sJoin) { + if (sJoin.has_left() && !validate(sJoin.left())) { + return false; + } + if (sJoin.has_right() && !validate(sJoin.right())) { + return false; + } + + switch (sJoin.type()) { + case ::substrait::JoinRel_JoinType_JOIN_TYPE_INNER: + case ::substrait::JoinRel_JoinType_JOIN_TYPE_OUTER: + case ::substrait::JoinRel_JoinType_JOIN_TYPE_LEFT: + case ::substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT: + case ::substrait::JoinRel_JoinType_JOIN_TYPE_SEMI: + case ::substrait::JoinRel_JoinType_JOIN_TYPE_ANTI: + break; + default: + return false; + } + + // Validate input types. + if (!sJoin.has_advanced_extension()) { + std::cout << "Input types are expected in JoinRel." 
<< std::endl; + return false; + } + + const auto& extension = sJoin.advanced_extension(); + std::vector types; + if (!validateInputTypes(extension, types)) { + std::cout << "Validation failed for input types in JoinRel" << std::endl; + return false; + } + + int32_t inputPlanNodeId = 0; + std::vector names; + names.reserve(types.size()); + for (auto colIdx = 0; colIdx < types.size(); colIdx++) { + names.emplace_back(subParser_->makeNodeName(inputPlanNodeId, colIdx)); + } + auto rowType = std::make_shared(std::move(names), std::move(types)); + + if (sJoin.has_expression()) { + std::vector leftExprs, + rightExprs; + try { + planConverter_->extractJoinKeys( + sJoin.expression(), leftExprs, rightExprs); + } catch (const VeloxException& err) { + std::cout << "Validation failed for expression in JoinRel due to:" + << err.message() << std::endl; + return false; + } + } + + if (sJoin.has_post_join_filter()) { + try { + auto expression = + exprConverter_->toVeloxExpr(sJoin.post_join_filter(), rowType); + exec::ExprSet exprSet({std::move(expression)}, &execCtx_); + } catch (const VeloxException& err) { + std::cout << "Validation failed for expression in ProjectRel due to:" + << err.message() << std::endl; + return false; + } + } + return true; +} + bool SubstraitToVeloxPlanValidator::validate( const ::substrait::AggregateRel& sAgg) { if (sAgg.has_input() && !validate(sAgg.input())) { @@ -304,6 +373,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::Rel& sRel) { if (sRel.has_filter()) { return validate(sRel.filter()); } + if (sRel.has_join()) { + return validate(sRel.join()); + } if (sRel.has_read()) { return validate(sRel.read()); } diff --git a/velox/substrait/SubstraitToVeloxPlanValidator.h b/velox/substrait/SubstraitToVeloxPlanValidator.h index 705b8ebec3af..52164696685b 100644 --- a/velox/substrait/SubstraitToVeloxPlanValidator.h +++ b/velox/substrait/SubstraitToVeloxPlanValidator.h @@ -36,6 +36,9 @@ class SubstraitToVeloxPlanValidator { /// Used to 
validate whether the computing of this Filter is supported. bool validate(const ::substrait::FilterRel& sFilter); + /// Used to validate whether the computing of this Join is supported. + bool validate(const ::substrait::JoinRel& sJoin); + /// Used to validate whether the computing of this Read is supported. bool validate(const ::substrait::ReadRel& sRead);