Skip to content

Commit

Permalink
[OPPRO-10] Enable hash join in Substrait-to-Velox conversion (faceboo…
Browse files Browse the repository at this point in the history
…kincubator#9)

* hash join

* remove extra projection
  • Loading branch information
marin-ma authored and zhejiangxiaomai committed Apr 20, 2023
1 parent d1771cf commit 962104a
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 17 deletions.
154 changes: 139 additions & 15 deletions velox/substrait/SubstraitToVeloxPlan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,97 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::processEmit(
}

// Converts a Substrait JoinRel into a Velox HashJoinNode.
//
// Both inputs are converted first. The join output row type is the
// concatenation of the left input's output columns followed by the right
// input's output columns. The join keys extracted from the join expression and
// the optional post-join filter are both resolved against that concatenated
// row type.
//
// Fails if the left input, the right input or the join expression is missing,
// if extractJoinKeys rejects the join expression, or if the Substrait join
// type has no mapping below.
core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
    const ::substrait::JoinRel& sJoin) {
  if (!sJoin.has_left()) {
    VELOX_FAIL("Left Rel is expected in JoinRel.");
  }
  if (!sJoin.has_right()) {
    VELOX_FAIL("Right Rel is expected in JoinRel.");
  }
  // Report a missing join expression explicitly instead of letting key
  // extraction fail on a default-constructed expression.
  if (!sJoin.has_expression()) {
    VELOX_FAIL("Join expression is expected in JoinRel.");
  }

  auto leftNode = toVeloxPlan(sJoin.left());
  auto rightNode = toVeloxPlan(sJoin.right());

  // Output columns: all left columns followed by all right columns.
  auto outputSize =
      leftNode->outputType()->size() + rightNode->outputType()->size();
  std::vector<std::string> outputNames;
  std::vector<std::shared_ptr<const Type>> outputTypes;
  outputNames.reserve(outputSize);
  outputTypes.reserve(outputSize);
  for (const auto& node : {leftNode, rightNode}) {
    const auto& names = node->outputType()->names();
    outputNames.insert(outputNames.end(), names.begin(), names.end());
    const auto& types = node->outputType()->children();
    outputTypes.insert(outputTypes.end(), types.begin(), types.end());
  }
  auto outputRowType = std::make_shared<const RowType>(
      std::move(outputNames), std::move(outputTypes));

  // Extract join keys from the join expression. Keys come in pairs, so both
  // sides must have the same number of entries.
  std::vector<const ::substrait::Expression::FieldReference*> leftExprs,
      rightExprs;
  extractJoinKeys(sJoin.expression(), leftExprs, rightExprs);
  VELOX_CHECK_EQ(leftExprs.size(), rightExprs.size());
  size_t numKeys = leftExprs.size();

  std::vector<std::shared_ptr<const core::FieldAccessTypedExpr>> leftKeys,
      rightKeys;
  leftKeys.reserve(numKeys);
  rightKeys.reserve(numKeys);
  for (size_t i = 0; i < numKeys; ++i) {
    leftKeys.emplace_back(
        exprConverter_->toVeloxExpr(*leftExprs[i], outputRowType));
    rightKeys.emplace_back(
        exprConverter_->toVeloxExpr(*rightExprs[i], outputRowType));
  }

  // The post-join filter is optional; a null filter is passed when absent.
  std::shared_ptr<const core::ITypedExpr> filter;
  if (sJoin.has_post_join_filter()) {
    filter =
        exprConverter_->toVeloxExpr(sJoin.post_join_filter(), outputRowType);
  }

  // Map the Substrait join type onto the Velox join type.
  core::JoinType joinType;
  switch (sJoin.type()) {
    case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_INNER:
      joinType = core::JoinType::kInner;
      break;
    case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_OUTER:
      joinType = core::JoinType::kFull;
      break;
    case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_LEFT:
      joinType = core::JoinType::kLeft;
      break;
    case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_RIGHT:
      joinType = core::JoinType::kRight;
      break;
    case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_SEMI:
      joinType = core::JoinType::kLeftSemi;
      break;
    case ::substrait::JoinRel_JoinType::JoinRel_JoinType_JOIN_TYPE_ANTI:
      joinType = core::JoinType::kNullAwareAnti;
      break;
    default:
      VELOX_NYI("Unsupported Join type: {}", sJoin.type());
  }

  // Create the join node.
  return std::make_shared<core::HashJoinNode>(
      nextPlanNodeId(),
      joinType,
      leftKeys,
      rightKeys,
      filter,
      leftNode,
      rightNode,
      outputRowType);
}

// Converts a Substrait AggregateRel into a Velox plan node. The child node is
// obtained through convertSingleInput, the aggregation step through
// toAggregationStep, and the final node is built by toVeloxAgg.
core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
    const ::substrait::AggregateRel& sAgg) {
  auto childNode = convertSingleInput<::substrait::AggregateRel>(sAgg);
  core::AggregationNode::Step aggStep = toAggregationStep(sAgg);
  return toVeloxAgg(sAgg, childNode, aggStep);
}

Expand Down Expand Up @@ -583,22 +671,24 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
}

core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
const ::substrait::Rel& rel) {
if (rel.has_aggregate()) {
return toVeloxPlan(rel.aggregate());
const ::substrait::Rel& sRel) {
if (sRel.has_aggregate()) {
return toVeloxPlan(sRel.aggregate());
}
if (rel.has_project()) {
return toVeloxPlan(rel.project());
if (sRel.has_project()) {
return toVeloxPlan(sRel.project());
}
if (rel.has_filter()) {
return toVeloxPlan(rel.filter());
if (sRel.has_filter()) {
return toVeloxPlan(sRel.filter());
}
if (rel.has_read()) {
auto splitInfo = std::make_shared<SplitInfo>();

auto planNode = toVeloxPlan(rel.read(), splitInfo);
splitInfoMap_[planNode->id()] = splitInfo;
return planNode;
if (sRel.has_join()) {
return toVeloxPlan(sRel.join());
}
if (sRel.has_read()) {
return toVeloxPlan(sRel.read());
}
if (sRel.has_sort()) {
return toVeloxPlan(sRel.sort());
}
if (rel.has_fetch()) {
return toVeloxPlan(rel.fetch());
Expand Down Expand Up @@ -935,4 +1025,38 @@ const std::string& SubstraitVeloxPlanConverter::findFunction(
return substraitParser_->findFunctionSpec(functionMap_, id);
}

// Decomposes a join condition into pairs of field references.
//
// The expression tree is walked iteratively with an explicit stack. Only
// scalar functions are accepted:
//   - "and": every operand is visited in turn.
//   - "equal": must have exactly two operands, both field selections; the
//     first is appended to leftExprs and the second to rightExprs.
// Any other function, or any non-function node, fails the conversion so that
// an unsupported condition is never silently dropped from the join.
//
// The stored pointers refer into joinExpression, which must outlive the
// output vectors.
void SubstraitVeloxPlanConverter::extractJoinKeys(
    const ::substrait::Expression& joinExpression,
    std::vector<const ::substrait::Expression::FieldReference*>& leftExprs,
    std::vector<const ::substrait::Expression::FieldReference*>& rightExprs) {
  std::vector<const ::substrait::Expression*> expressions;
  expressions.push_back(&joinExpression);
  while (!expressions.empty()) {
    const auto* visited = expressions.back();
    expressions.pop_back();
    if (visited->rex_type_case() ==
        ::substrait::Expression::RexTypeCase::kScalarFunction) {
      const auto& funcName =
          subParser_->getSubFunctionName(subParser_->findVeloxFunction(
              functionMap_, visited->scalar_function().function_reference()));
      const auto& args = visited->scalar_function().args();
      if (funcName == "and") {
        // Push operands in declaration order; for the two-operand case this
        // yields the same traversal order as before.
        for (const auto& arg : args) {
          expressions.push_back(&arg);
        }
      } else if (funcName == "equal") {
        // Guard the indexed accesses below.
        VELOX_CHECK_EQ(
            args.size(),
            2,
            "Exactly two arguments are expected for equal in join expression.");
        VELOX_CHECK(std::all_of(
            args.cbegin(), args.cend(), [](const ::substrait::Expression& arg) {
              return arg.has_selection();
            }));
        leftExprs.push_back(&args[0].selection());
        rightExprs.push_back(&args[1].selection());
      } else {
        VELOX_NYI("Join condition {} not supported.", funcName);
      }
    } else {
      VELOX_FAIL(
          "Unable to parse from join expression: {}",
          joinExpression.DebugString());
    }
  }
}

} // namespace facebook::velox::substrait
17 changes: 15 additions & 2 deletions velox/substrait/SubstraitToVeloxPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,11 @@ class SubstraitVeloxPlanConverter {
dwio::common::FileFormat format;
};

/// Convert Substrait AggregateRel into Velox PlanNode.
core::PlanNodePtr toVeloxPlan(const ::substrait::AggregateRel& aggRel);
/// Used to convert Substrait JoinRel into Velox PlanNode.
core::PlanNodePtr toVeloxPlan(const ::substrait::JoinRel& sJoin);

/// Used to convert Substrait AggregateRel into Velox PlanNode.
core::PlanNodePtr toVeloxPlan(const ::substrait::AggregateRel& sAgg);

/// Convert Substrait ProjectRel into Velox PlanNode.
core::PlanNodePtr toVeloxPlan(const ::substrait::ProjectRel& projectRel);
Expand Down Expand Up @@ -163,6 +166,16 @@ class SubstraitVeloxPlanConverter {
/// Used to find the function specification in the constructed function map.
std::string findFuncSpec(uint64_t id);

/// Extract join keys from joinExpression.
/// joinExpression is a boolean condition that describes whether each record
/// from the left set matches a record from the right set. The condition
/// must only include the following operations: AND, ==, field references.
/// Field references correspond to the direct output order of the data.
void extractJoinKeys(
const ::substrait::Expression& joinExpression,
std::vector<const ::substrait::Expression::FieldReference*>& leftExprs,
std::vector<const ::substrait::Expression::FieldReference*>& rightExprs);

private:
/// The Partition index.
u_int32_t partitionIndex_;
Expand Down
72 changes: 72 additions & 0 deletions velox/substrait/SubstraitToVeloxPlanValidator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,75 @@ bool SubstraitToVeloxPlanValidator::validate(
return false;
}

// Validates whether a Substrait JoinRel can be converted.
//
// Returns false when:
//   - a present left or right input fails validation,
//   - the join type is not one of INNER/OUTER/LEFT/RIGHT/SEMI/ANTI,
//   - the advanced extension carrying the input types is missing or invalid,
//   - join keys cannot be extracted from the join expression, or
//   - the post-join filter cannot be converted and compiled.
// The reason for a failure is printed to stdout.
bool SubstraitToVeloxPlanValidator::validate(
    const ::substrait::JoinRel& sJoin) {
  if (sJoin.has_left() && !validate(sJoin.left())) {
    return false;
  }
  if (sJoin.has_right() && !validate(sJoin.right())) {
    return false;
  }

  switch (sJoin.type()) {
    case ::substrait::JoinRel_JoinType_JOIN_TYPE_INNER:
    case ::substrait::JoinRel_JoinType_JOIN_TYPE_OUTER:
    case ::substrait::JoinRel_JoinType_JOIN_TYPE_LEFT:
    case ::substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT:
    case ::substrait::JoinRel_JoinType_JOIN_TYPE_SEMI:
    case ::substrait::JoinRel_JoinType_JOIN_TYPE_ANTI:
      break;
    default:
      return false;
  }

  // Validate input types.
  if (!sJoin.has_advanced_extension()) {
    std::cout << "Input types are expected in JoinRel." << std::endl;
    return false;
  }

  const auto& extension = sJoin.advanced_extension();
  std::vector<TypePtr> types;
  if (!validateInputTypes(extension, types)) {
    std::cout << "Validation failed for input types in JoinRel" << std::endl;
    return false;
  }

  // Build a row type with generated column names for the input types; it is
  // used to resolve the post-join filter below.
  int32_t inputPlanNodeId = 0;
  std::vector<std::string> names;
  names.reserve(types.size());
  const auto numTypes = static_cast<int32_t>(types.size());
  for (int32_t colIdx = 0; colIdx < numTypes; ++colIdx) {
    names.emplace_back(subParser_->makeNodeName(inputPlanNodeId, colIdx));
  }
  auto rowType = std::make_shared<RowType>(std::move(names), std::move(types));

  if (sJoin.has_expression()) {
    // Only checks that keys can be extracted; the keys themselves are unused.
    std::vector<const ::substrait::Expression::FieldReference*> leftExprs,
        rightExprs;
    try {
      planConverter_->extractJoinKeys(
          sJoin.expression(), leftExprs, rightExprs);
    } catch (const VeloxException& err) {
      std::cout << "Validation failed for expression in JoinRel due to:"
                << err.message() << std::endl;
      return false;
    }
  }

  if (sJoin.has_post_join_filter()) {
    try {
      auto expression =
          exprConverter_->toVeloxExpr(sJoin.post_join_filter(), rowType);
      // Constructing the ExprSet compiles the expression; the object itself is
      // not needed afterwards.
      exec::ExprSet exprSet({std::move(expression)}, &execCtx_);
    } catch (const VeloxException& err) {
      std::cout << "Validation failed for post join filter in JoinRel due to:"
                << err.message() << std::endl;
      return false;
    }
  }
  return true;
}

bool SubstraitToVeloxPlanValidator::validate(
const ::substrait::AggregateRel& sAgg) {
if (sAgg.has_input() && !validate(sAgg.input())) {
Expand Down Expand Up @@ -304,6 +373,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::Rel& sRel) {
if (sRel.has_filter()) {
return validate(sRel.filter());
}
if (sRel.has_join()) {
return validate(sRel.join());
}
if (sRel.has_read()) {
return validate(sRel.read());
}
Expand Down
3 changes: 3 additions & 0 deletions velox/substrait/SubstraitToVeloxPlanValidator.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class SubstraitToVeloxPlanValidator {
/// Used to validate whether the computing of this Filter is supported.
bool validate(const ::substrait::FilterRel& sFilter);

/// Used to validate Join.
bool validate(const ::substrait::JoinRel& sJoin);

/// Used to validate whether the computing of this Read is supported.
bool validate(const ::substrait::ReadRel& sRead);

Expand Down

0 comments on commit 962104a

Please sign in to comment.