Skip to content

Commit

Permalink
[OPPRO-10] Enable hash join in Substrait-to-Velox conversion (faceboo…
Browse files Browse the repository at this point in the history
…kincubator#9)

* hash join

* remove extra projection
  • Loading branch information
marin-ma authored and zhejiangxiaomai committed Aug 18, 2022
1 parent c8c32a7 commit 5854284
Show file tree
Hide file tree
Showing 5 changed files with 236 additions and 24 deletions.
2 changes: 1 addition & 1 deletion velox/substrait/SubstraitToVeloxExpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class SubstraitVeloxExprConverter {
/// subParser: A Substrait parser used to convert Substrait representations
/// into recognizable representations. functionMap: A pre-constructed map
/// storing the relations between the function id and the function name.
SubstraitVeloxExprConverter(
explicit SubstraitVeloxExprConverter(
const std::unordered_map<uint64_t, std::string>& functionMap)
: functionMap_(functionMap) {}

Expand Down
162 changes: 143 additions & 19 deletions velox/substrait/SubstraitToVeloxPlan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,100 @@ VectorPtr setVectorFromVariants(
}
} // namespace

core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
const ::substrait::AggregateRel& aggRel,
memory::MemoryPool* pool) {
core::PlanNodePtr childNode;
if (aggRel.has_input()) {
childNode = toVeloxPlan(aggRel.input(), pool);
/// Builds a Velox HashJoinNode from a Substrait JoinRel.
///
/// Both inputs are converted first. The join's output row type is the left
/// input's columns followed by the right input's columns, and the join keys
/// as well as the optional post-join filter are resolved against that
/// combined row type.
///
/// Fails if either input is missing and reports NYI for a join type that has
/// no mapping below.
std::shared_ptr<const core::PlanNode> SubstraitVeloxPlanConverter::toVeloxPlan(
    const ::substrait::JoinRel& sJoin) {
  if (!sJoin.has_left()) {
    VELOX_FAIL("Left Rel is expected in JoinRel.");
  }
  if (!sJoin.has_right()) {
    VELOX_FAIL("Right Rel is expected in JoinRel.");
  }

  auto leftInput = toVeloxPlan(sJoin.left());
  auto rightInput = toVeloxPlan(sJoin.right());

  // Concatenate the column names and types of both inputs, left side first.
  std::vector<std::string> columnNames;
  std::vector<std::shared_ptr<const Type>> columnTypes;
  const auto columnCount =
      leftInput->outputType()->size() + rightInput->outputType()->size();
  columnNames.reserve(columnCount);
  columnTypes.reserve(columnCount);
  const auto appendColumns = [&columnNames, &columnTypes](const auto& input) {
    const auto& inputNames = input->outputType()->names();
    columnNames.insert(
        columnNames.end(), inputNames.begin(), inputNames.end());
    const auto& inputTypes = input->outputType()->children();
    columnTypes.insert(
        columnTypes.end(), inputTypes.begin(), inputTypes.end());
  };
  appendColumns(leftInput);
  appendColumns(rightInput);
  auto outputRowType = std::make_shared<const RowType>(
      std::move(columnNames), std::move(columnTypes));

  // Pull the paired key field references out of the join condition.
  std::vector<const ::substrait::Expression::FieldReference*> leftFields;
  std::vector<const ::substrait::Expression::FieldReference*> rightFields;
  extractJoinKeys(sJoin.expression(), leftFields, rightFields);
  VELOX_CHECK_EQ(leftFields.size(), rightFields.size());
  const size_t keyCount = leftFields.size();

  // Convert every key pair into Velox field accesses.
  std::vector<std::shared_ptr<const core::FieldAccessTypedExpr>> leftKeys;
  std::vector<std::shared_ptr<const core::FieldAccessTypedExpr>> rightKeys;
  leftKeys.reserve(keyCount);
  rightKeys.reserve(keyCount);
  for (size_t keyIdx = 0; keyIdx != keyCount; ++keyIdx) {
    leftKeys.emplace_back(
        exprConverter_->toVeloxExpr(*leftFields[keyIdx], outputRowType));
    rightKeys.emplace_back(
        exprConverter_->toVeloxExpr(*rightFields[keyIdx], outputRowType));
  }

  // The filter stays null when the plan carries no post-join filter.
  std::shared_ptr<const core::ITypedExpr> postJoinFilter;
  if (sJoin.has_post_join_filter()) {
    postJoinFilter =
        exprConverter_->toVeloxExpr(sJoin.post_join_filter(), outputRowType);
  }

  // Translate the Substrait join type into the Velox one.
  core::JoinType veloxJoinType;
  switch (sJoin.type()) {
    case ::substrait::JoinRel_JoinType_JOIN_TYPE_INNER:
      veloxJoinType = core::JoinType::kInner;
      break;
    case ::substrait::JoinRel_JoinType_JOIN_TYPE_OUTER:
      veloxJoinType = core::JoinType::kFull;
      break;
    case ::substrait::JoinRel_JoinType_JOIN_TYPE_LEFT:
      veloxJoinType = core::JoinType::kLeft;
      break;
    case ::substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT:
      veloxJoinType = core::JoinType::kRight;
      break;
    case ::substrait::JoinRel_JoinType_JOIN_TYPE_SEMI:
      veloxJoinType = core::JoinType::kSemi;
      break;
    case ::substrait::JoinRel_JoinType_JOIN_TYPE_ANTI:
      veloxJoinType = core::JoinType::kAnti;
      break;
    default:
      VELOX_NYI("Unsupported Join type: {}", sJoin.type());
  }

  // The node id is requested last, after both inputs have been converted.
  return std::make_shared<core::HashJoinNode>(
      nextPlanNodeId(),
      veloxJoinType,
      leftKeys,
      rightKeys,
      postJoinFilter,
      leftInput,
      rightInput,
      outputRowType);
}

std::shared_ptr<const core::PlanNode> SubstraitVeloxPlanConverter::toVeloxPlan(
const ::substrait::AggregateRel& sAgg) {
std::shared_ptr<const core::PlanNode> childNode;
if (sAgg.has_input()) {
childNode = toVeloxPlan(sAgg.input());

} else {
VELOX_FAIL("Child Rel is expected in AggregateRel.");
}
Expand Down Expand Up @@ -549,22 +637,22 @@ SubstraitVeloxPlanConverter::toVeloxAggWithRowConstruct(
}

core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
const ::substrait::Rel& rel, memory::MemoryPool* pool) {
if (rel.has_aggregate()) {
return toVeloxPlan(rel.aggregate(), pool);
const ::substrait::Rel& sRel) {
if (sRel.has_aggregate()) {
return toVeloxPlan(sRel.aggregate());
}
if (rel.has_project()) {
return toVeloxPlan(rel.project(), pool);
if (sRel.has_project()) {
return toVeloxPlan(sRel.project());
}
if (rel.has_filter()) {
return toVeloxPlan(rel.filter(), pool);
if (sRel.has_filter()) {
return toVeloxPlan(sRel.filter());
}
if (rel.has_read()) {
auto splitInfo = std::make_shared<SplitInfo>();

auto planNode = toVeloxPlan(rel.read(), pool, splitInfo);
splitInfoMap_[planNode->id()] = splitInfo;
return planNode;
if (sRel.has_join()) {
return toVeloxPlan(sRel.join());
}
if (sRel.has_read()) {
return toVeloxPlan(
sRel.read(), partitionIndex_, paths_, starts_, lengths_);
}
VELOX_NYI("Substrait conversion not supported for Rel.");
}
Expand Down Expand Up @@ -913,4 +1001,40 @@ const std::string& SubstraitVeloxPlanConverter::findFunction(
return substraitParser_->findFunctionSpec(functionMap_, id);
}

/// Collects the equi-join key pairs from a Substrait join condition.
///
/// The condition is walked iteratively with an explicit stack. Every visited
/// node must be a scalar function:
///  - "and" is expanded into all of its arguments;
///  - "equal" contributes one key pair: its first operand is appended to
///    leftExprs and its second to rightExprs, so leftExprs[i] and
///    rightExprs[i] always come from the same comparison.
/// Because the traversal is stack based, key pairs under an "and" are emitted
/// in reverse order of appearance; the pairing itself is unaffected.
///
/// Fails with a Velox error when a node is not a scalar function, when "and"
/// has fewer than two arguments, when "equal" does not have exactly two
/// field-reference arguments, or when any other function is encountered.
/// Rejecting other functions is deliberate: silently skipping them would
/// drop part of the join condition.
void SubstraitVeloxPlanConverter::extractJoinKeys(
    const ::substrait::Expression& joinExpression,
    std::vector<const ::substrait::Expression::FieldReference*>& leftExprs,
    std::vector<const ::substrait::Expression::FieldReference*>& rightExprs) {
  std::vector<const ::substrait::Expression*> expressions;
  expressions.push_back(&joinExpression);
  while (!expressions.empty()) {
    const auto* visited = expressions.back();
    expressions.pop_back();
    if (visited->rex_type_case() !=
        ::substrait::Expression::RexTypeCase::kScalarFunction) {
      VELOX_FAIL(
          "Unable to parse from join expression: {}",
          joinExpression.DebugString());
    }

    const auto& funcName =
        subParser_->getSubFunctionName(subParser_->findVeloxFunction(
            functionMap_, visited->scalar_function().function_reference()));
    const auto& args = visited->scalar_function().args();
    if (funcName == "and") {
      // Guard the argument count before touching the arguments, and keep
      // every conjunct rather than only the first two.
      VELOX_CHECK_GE(
          args.size(),
          2,
          "'and' in a join expression requires at least two arguments.");
      for (const auto& arg : args) {
        expressions.push_back(&arg);
      }
    } else if (funcName == "equal") {
      // all_of() is vacuously true for an empty list, so the arity has to be
      // checked separately before indexing.
      VELOX_CHECK_EQ(
          args.size(),
          2,
          "'equal' in a join expression requires exactly two arguments.");
      VELOX_CHECK(
          std::all_of(
              args.cbegin(),
              args.cend(),
              [](const ::substrait::Expression& arg) {
                return arg.has_selection();
              }),
          "Join keys must be field references.");
      leftExprs.push_back(&args[0].selection());
      rightExprs.push_back(&args[1].selection());
    } else {
      VELOX_NYI(
          "Function '{}' is not supported in join expression.", funcName);
    }
  }
}

} // namespace facebook::velox::substrait
21 changes: 17 additions & 4 deletions velox/substrait/SubstraitToVeloxPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,13 @@ class SubstraitVeloxPlanConverter {
dwio::common::FileFormat format;
};

/// Convert Substrait AggregateRel into Velox PlanNode.
core::PlanNodePtr toVeloxPlan(
const ::substrait::AggregateRel& aggRel,
memory::MemoryPool* pool);
/// Used to convert Substrait JoinRel into Velox PlanNode.
std::shared_ptr<const core::PlanNode> toVeloxPlan(
const ::substrait::JoinRel& sJoin);

/// Used to convert Substrait AggregateRel into Velox PlanNode.
std::shared_ptr<const core::PlanNode> toVeloxPlan(
const ::substrait::AggregateRel& sAgg);

/// Convert Substrait ProjectRel into Velox PlanNode.
core::PlanNodePtr toVeloxPlan(
Expand Down Expand Up @@ -141,6 +144,16 @@ class SubstraitVeloxPlanConverter {
/// Used to find the function specification in the constructed function map.
std::string findFuncSpec(uint64_t id);

/// Extract join keys from joinExpression.
/// joinExpression is a boolean condition that describes whether a record
/// from the left set matches a record from the right set. The condition
/// must only include the following operations: AND, ==, field references.
/// Field references correspond to the direct output order of the data.
/// For every equality found, the first operand is appended to leftExprs and
/// the second operand to rightExprs, so both vectors stay index-aligned.
/// Fails with a Velox error if a visited node of the condition is not a
/// scalar function.
void extractJoinKeys(
const ::substrait::Expression& joinExpression,
std::vector<const ::substrait::Expression::FieldReference*>& leftExprs,
std::vector<const ::substrait::Expression::FieldReference*>& rightExprs);

private:
/// The Partition index.
u_int32_t partitionIndex_;
Expand Down
72 changes: 72 additions & 0 deletions velox/substrait/SubstraitToVeloxPlanValidator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,75 @@ bool SubstraitToVeloxPlanValidator::validate(
return false;
}

/// Checks whether a Substrait JoinRel can be handled by the Velox conversion.
///
/// Returns false when:
///  - a present left or right input fails its own validation;
///  - the join type is not one of INNER, OUTER, LEFT, RIGHT, SEMI, ANTI;
///  - the advanced extension carrying the input types is missing or cannot
///    be parsed;
///  - extracting the join keys from the join expression throws;
///  - the post-join filter cannot be converted or compiled.
/// Returns true otherwise. The reason for a rejection is printed to stdout.
bool SubstraitToVeloxPlanValidator::validate(
    const ::substrait::JoinRel& sJoin) {
  // NOTE(review): a JoinRel without a left/right input, or without a join
  // expression, passes the checks below, while the plan converter's JoinRel
  // path appears to fail on such input - confirm whether these plans should
  // be rejected here as well.
  if (sJoin.has_left() && !validate(sJoin.left())) {
    return false;
  }
  if (sJoin.has_right() && !validate(sJoin.right())) {
    return false;
  }

  switch (sJoin.type()) {
    case ::substrait::JoinRel_JoinType_JOIN_TYPE_INNER:
    case ::substrait::JoinRel_JoinType_JOIN_TYPE_OUTER:
    case ::substrait::JoinRel_JoinType_JOIN_TYPE_LEFT:
    case ::substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT:
    case ::substrait::JoinRel_JoinType_JOIN_TYPE_SEMI:
    case ::substrait::JoinRel_JoinType_JOIN_TYPE_ANTI:
      break;
    default:
      return false;
  }

  // Validate input types.
  if (!sJoin.has_advanced_extension()) {
    std::cout << "Input types are expected in JoinRel." << std::endl;
    return false;
  }

  const auto& extension = sJoin.advanced_extension();
  std::vector<TypePtr> types;
  if (!validateInputTypes(extension, types)) {
    std::cout << "Validation failed for input types in JoinRel" << std::endl;
    return false;
  }

  // Build a row type with generated column names so the post-join filter can
  // be resolved against the input types.
  int32_t inputPlanNodeId = 0;
  std::vector<std::string> names;
  names.reserve(types.size());
  // Take the count once as int to avoid a signed/unsigned comparison while
  // still passing an int column index, as before.
  const auto numColumns = static_cast<int>(types.size());
  for (int colIdx = 0; colIdx < numColumns; ++colIdx) {
    names.emplace_back(subParser_->makeNodeName(inputPlanNodeId, colIdx));
  }
  auto rowType = std::make_shared<RowType>(std::move(names), std::move(types));

  if (sJoin.has_expression()) {
    // Only the ability to extract the keys is checked; the keys themselves
    // are discarded.
    std::vector<const ::substrait::Expression::FieldReference*> leftExprs,
        rightExprs;
    try {
      planConverter_->extractJoinKeys(
          sJoin.expression(), leftExprs, rightExprs);
    } catch (const VeloxException& err) {
      std::cout << "Validation failed for expression in JoinRel due to:"
                << err.message() << std::endl;
      return false;
    }
  }

  if (sJoin.has_post_join_filter()) {
    try {
      auto expression =
          exprConverter_->toVeloxExpr(sJoin.post_join_filter(), rowType);
      // Constructing the ExprSet is the check; the object itself is unused.
      exec::ExprSet exprSet({std::move(expression)}, &execCtx_);
    } catch (const VeloxException& err) {
      std::cout << "Validation failed for post-join filter in JoinRel due to:"
                << err.message() << std::endl;
      return false;
    }
  }
  return true;
}

bool SubstraitToVeloxPlanValidator::validate(
const ::substrait::AggregateRel& sAgg) {
if (sAgg.has_input() && !validate(sAgg.input())) {
Expand Down Expand Up @@ -304,6 +373,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::Rel& sRel) {
if (sRel.has_filter()) {
return validate(sRel.filter());
}
if (sRel.has_join()) {
return validate(sRel.join());
}
if (sRel.has_read()) {
return validate(sRel.read());
}
Expand Down
3 changes: 3 additions & 0 deletions velox/substrait/SubstraitToVeloxPlanValidator.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class SubstraitToVeloxPlanValidator {
/// Used to validate whether the computing of this Filter is supported.
bool validate(const ::substrait::FilterRel& sFilter);

/// Used to validate whether the conversion of this Join is supported. Checks
/// both inputs, the join type, the input types carried in the advanced
/// extension, the join expression and the post-join filter.
bool validate(const ::substrait::JoinRel& sJoin);

/// Used to validate whether the computing of this Read is supported.
bool validate(const ::substrait::ReadRel& sRead);

Expand Down

0 comments on commit 5854284

Please sign in to comment.