Skip to content

Commit

Permalink
Fix semi join output type and support existence join (facebookincubat…
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo authored and zhejiangxiaomai committed Apr 20, 2023
1 parent cb8f366 commit 0b7ae5b
Showing 1 changed file with 145 additions and 57 deletions.
202 changes: 145 additions & 57 deletions velox/substrait/SubstraitToVeloxPlan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,106 @@ const std::string sNot = "not";
// Substrait types.
const std::string sI32 = "i32";
const std::string sI64 = "i64";

/// @brief Return whether a config is set as true in AdvancedExtension
/// optimization.
/// @param extension Substrait advanced extension.
/// @param config the key string of a config.
/// @return Whether the config is set as true.
bool configSetInOptimization(
const ::substrait::extensions::AdvancedExtension& extension,
const std::string& config) {
if (extension.has_optimization()) {
std::string msg = extension.optimization().value();
std::size_t pos = msg.find(config);
if ((pos != std::string::npos) &&
(msg.substr(pos + config.size(), 1) == "1")) {
return true;
}
}
return false;
}

/// @brief Get the input type from both sides of join.
/// @param leftNode the plan node of left side.
/// @param rightNode the plan node of right side.
/// @return the input type.
RowTypePtr getJoinInputType(
const core::PlanNodePtr& leftNode,
const core::PlanNodePtr& rightNode) {
auto outputSize =
leftNode->outputType()->size() + rightNode->outputType()->size();
std::vector<std::string> outputNames;
std::vector<std::shared_ptr<const Type>> outputTypes;
outputNames.reserve(outputSize);
outputTypes.reserve(outputSize);
for (const auto& node : {leftNode, rightNode}) {
const auto& names = node->outputType()->names();
outputNames.insert(outputNames.end(), names.begin(), names.end());
const auto& types = node->outputType()->children();
outputTypes.insert(outputTypes.end(), types.begin(), types.end());
}
return std::make_shared<const RowType>(
std::move(outputNames), std::move(outputTypes));
}

/// @brief Get the direct output type of join.
/// @param leftNode the plan node of left side.
/// @param rightNode the plan node of right side.
/// @param joinType the join type.
/// @return the output type.
RowTypePtr getJoinOutputType(
const core::PlanNodePtr& leftNode,
const core::PlanNodePtr& rightNode,
const core::JoinType& joinType) {
// Decide output type.
// Output of right semi join cannot include columns from the left side.
bool outputMayIncludeLeftColumns =
!(core::isRightSemiFilterJoin(joinType) ||
core::isRightSemiProjectJoin(joinType));

// Output of left semi and anti joins cannot include columns from the right
// side.
bool outputMayIncludeRightColumns =
!(core::isLeftSemiFilterJoin(joinType) ||
core::isLeftSemiProjectJoin(joinType) || core::isAntiJoin(joinType) ||
core::isNullAwareAntiJoin(joinType));

if (outputMayIncludeLeftColumns && outputMayIncludeRightColumns) {
return getJoinInputType(leftNode, rightNode);
}

if (outputMayIncludeLeftColumns) {
if (core::isLeftSemiProjectJoin(joinType)) {
auto outputSize = leftNode->outputType()->size() + 1;
std::vector<std::string> outputNames = leftNode->outputType()->names();
std::vector<std::shared_ptr<const Type>> outputTypes =
leftNode->outputType()->children();
outputNames.emplace_back("exists");
outputTypes.emplace_back(BOOLEAN());
return std::make_shared<const RowType>(
std::move(outputNames), std::move(outputTypes));
} else {
return leftNode->outputType();
}
}

if (outputMayIncludeRightColumns) {
if (core::isRightSemiProjectJoin(joinType)) {
auto outputSize = rightNode->outputType()->size() + 1;
std::vector<std::string> outputNames = rightNode->outputType()->names();
std::vector<std::shared_ptr<const Type>> outputTypes =
rightNode->outputType()->children();
outputNames.emplace_back("exists");
outputTypes.emplace_back(BOOLEAN());
return std::make_shared<const RowType>(
std::move(outputNames), std::move(outputTypes));
} else {
return rightNode->outputType();
}
}
VELOX_FAIL("Output should include left or right columns.");
}
} // namespace

core::PlanNodePtr SubstraitVeloxPlanConverter::processEmit(
Expand Down Expand Up @@ -158,46 +258,7 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
auto leftNode = toVeloxPlan(sJoin.left());
auto rightNode = toVeloxPlan(sJoin.right());

auto outputSize =
leftNode->outputType()->size() + rightNode->outputType()->size();
std::vector<std::string> outputNames;
std::vector<std::shared_ptr<const Type>> outputTypes;
outputNames.reserve(outputSize);
outputTypes.reserve(outputSize);
for (const auto& node : {leftNode, rightNode}) {
const auto& names = node->outputType()->names();
outputNames.insert(outputNames.end(), names.begin(), names.end());
const auto& types = node->outputType()->children();
outputTypes.insert(outputTypes.end(), types.begin(), types.end());
}
auto outputRowType = std::make_shared<const RowType>(
std::move(outputNames), std::move(outputTypes));

// extract join keys from join expression
std::vector<const ::substrait::Expression::FieldReference*> leftExprs,
rightExprs;
extractJoinKeys(sJoin.expression(), leftExprs, rightExprs);
VELOX_CHECK_EQ(leftExprs.size(), rightExprs.size());
size_t numKeys = leftExprs.size();

std::vector<std::shared_ptr<const core::FieldAccessTypedExpr>> leftKeys,
rightKeys;
leftKeys.reserve(numKeys);
rightKeys.reserve(numKeys);
for (size_t i = 0; i < numKeys; ++i) {
leftKeys.emplace_back(
exprConverter_->toVeloxExpr(*leftExprs[i], outputRowType));
rightKeys.emplace_back(
exprConverter_->toVeloxExpr(*rightExprs[i], outputRowType));
}

std::shared_ptr<const core::ITypedExpr> filter;
if (sJoin.has_post_join_filter()) {
filter =
exprConverter_->toVeloxExpr(sJoin.post_join_filter(), outputRowType);
}

// Map join type
// Map join type.
core::JoinType joinType;
switch (sJoin.type()) {
case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_INNER:
Expand All @@ -213,25 +274,30 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
joinType = core::JoinType::kRight;
break;
case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI:
joinType = core::JoinType::kLeftSemi;
// Determine the semi join type based on extracted information.
if (sJoin.has_advanced_extension() &&
configSetInOptimization(
sJoin.advanced_extension(), "isExistenceJoin=")) {
joinType = core::JoinType::kLeftSemiProject;
} else {
joinType = core::JoinType::kLeftSemi;
}
break;
case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_RIGHT_SEMI:
joinType = core::JoinType::kRightSemi;
// Determine the semi join type based on extracted information.
if (sJoin.has_advanced_extension() &&
configSetInOptimization(
sJoin.advanced_extension(), "isExistenceJoin=")) {
joinType = core::JoinType::kRightSemiProject;
} else {
joinType = core::JoinType::kRightSemi;
}
break;
case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_ANTI: {
// Determine the anti join type based on extracted information.
bool isNullAwareAntiJoin = false;
if (sJoin.has_advanced_extension() &&
sJoin.advanced_extension().has_optimization()) {
std::string msg = sJoin.advanced_extension().optimization().value();
std::string nullAwareKey = "isNullAwareAntiJoin=";
std::size_t pos = msg.find(nullAwareKey);
if ((pos != std::string::npos) &&
(msg.substr(pos + nullAwareKey.size(), 1) == "1")) {
isNullAwareAntiJoin = true;
}
}
if (isNullAwareAntiJoin) {
configSetInOptimization(
sJoin.advanced_extension(), "isNullAwareAntiJoin=")) {
joinType = core::JoinType::kNullAwareAnti;
} else {
joinType = core::JoinType::kAnti;
Expand All @@ -242,6 +308,31 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
VELOX_NYI("Unsupported Join type: {}", sJoin.type());
}

// extract join keys from join expression
std::vector<const ::substrait::Expression::FieldReference*> leftExprs,
rightExprs;
extractJoinKeys(sJoin.expression(), leftExprs, rightExprs);
VELOX_CHECK_EQ(leftExprs.size(), rightExprs.size());
size_t numKeys = leftExprs.size();

std::vector<std::shared_ptr<const core::FieldAccessTypedExpr>> leftKeys,
rightKeys;
leftKeys.reserve(numKeys);
rightKeys.reserve(numKeys);
auto inputRowType = getJoinInputType(leftNode, rightNode);
for (size_t i = 0; i < numKeys; ++i) {
leftKeys.emplace_back(
exprConverter_->toVeloxExpr(*leftExprs[i], inputRowType));
rightKeys.emplace_back(
exprConverter_->toVeloxExpr(*rightExprs[i], inputRowType));
}

std::shared_ptr<const core::ITypedExpr> filter;
if (sJoin.has_post_join_filter()) {
filter =
exprConverter_->toVeloxExpr(sJoin.post_join_filter(), inputRowType);
}

// Create join node
return std::make_shared<core::HashJoinNode>(
nextPlanNodeId(),
Expand All @@ -251,7 +342,7 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
filter,
leftNode,
rightNode,
outputRowType);
getJoinOutputType(leftNode, rightNode, joinType));
}

core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
Expand Down Expand Up @@ -418,9 +509,6 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
childNode);
}

std::pair<
std::vector<core::FieldAccessTypedExprPtr>,
std::vector<core::SortOrder>>
SubstraitVeloxPlanConverter::processSortField(
const ::google::protobuf::RepeatedPtrField<::substrait::SortField>&
sortFields,
Expand Down

0 comments on commit 0b7ae5b

Please sign in to comment.