Skip to content

Commit

Permalink
Add window support (facebookincubator#61)
Browse files Browse the repository at this point in the history
  • Loading branch information
JkSelf authored and zhejiangxiaomai committed Mar 6, 2023
1 parent bc5ddd2 commit f5797a0
Show file tree
Hide file tree
Showing 7 changed files with 327 additions and 14 deletions.
6 changes: 3 additions & 3 deletions velox/functions/prestosql/window/Rank.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,14 @@ void registerRankInternal(
const std::vector<exec::WindowFunctionArg>& /*args*/,
const TypePtr& resultType,
velox::memory::MemoryPool* /*pool*/,
HashStringAllocator* /*stringAllocator*/)
-> std::unique_ptr<exec::WindowFunction> {
HashStringAllocator *
/*stringAllocator*/) -> std::unique_ptr<exec::WindowFunction> {
return std::make_unique<RankFunction<TRank, TResult>>(resultType);
});
}

void registerRank(const std::string& name) {
registerRankInternal<RankType::kRank, int64_t>(name, "bigint");
registerRankInternal<RankType::kRank, int32_t>(name, "integer");
}
void registerDenseRank(const std::string& name) {
registerRankInternal<RankType::kDenseRank, int64_t>(name, "bigint");
Expand Down
6 changes: 3 additions & 3 deletions velox/functions/prestosql/window/RowNumber.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class RowNumberFunction : public exec::WindowFunction {
vector_size_t resultOffset,
const VectorPtr& result) override {
int numRows = peerGroupStarts->size() / sizeof(vector_size_t);
auto* rawValues = result->asFlatVector<int64_t>()->mutableRawValues();
auto* rawValues = result->asFlatVector<int32_t>()->mutableRawValues();
for (int i = 0; i < numRows; i++) {
rawValues[resultOffset + i] = rowNumber_++;
}
Expand All @@ -65,8 +65,8 @@ void registerRowNumber(const std::string& name) {
const std::vector<exec::WindowFunctionArg>& /*args*/,
const TypePtr& /*resultType*/,
velox::memory::MemoryPool* /*pool*/,
HashStringAllocator* /*stringAllocator*/)
-> std::unique_ptr<exec::WindowFunction> {
HashStringAllocator *
/*stringAllocator*/) -> std::unique_ptr<exec::WindowFunction> {
return std::make_unique<RowNumberFunction>();
});
}
Expand Down
152 changes: 152 additions & 0 deletions velox/substrait/SubstraitToVeloxPlan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,155 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
childNode);
}

const core::WindowNode::Frame createWindowFrame(
const ::substrait::Expression_WindowFunction_Bound& lower_bound,
const ::substrait::Expression_WindowFunction_Bound& upper_bound,
const ::substrait::WindowType& type) {
core::WindowNode::Frame frame;
switch (type) {
case ::substrait::WindowType::ROWS:
frame.type = core::WindowNode::WindowType::kRows;
break;
case ::substrait::WindowType::RANGE:

frame.type = core::WindowNode::WindowType::kRange;
break;
default:
VELOX_FAIL(
"the window type only support ROWS and RANGE, and the input type is ",
type);
}

auto boundTypeConversion =
[](::substrait::Expression_WindowFunction_Bound boundType)
-> core::WindowNode::BoundType {
if (boundType.has_current_row()) {
return core::WindowNode::BoundType::kCurrentRow;
} else if (boundType.has_unbounded_following()) {
return core::WindowNode::BoundType::kUnboundedFollowing;
} else if (boundType.has_unbounded_preceding()) {
return core::WindowNode::BoundType::kUnboundedPreceding;
} else {
VELOX_FAIL("The BoundType is not supported.");
}
};
frame.startType = boundTypeConversion(lower_bound);
frame.startValue = nullptr;
frame.endType = boundTypeConversion(upper_bound);
frame.endValue = nullptr;
return frame;
}

core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
const ::substrait::WindowRel& windowRel) {
core::PlanNodePtr childNode;
if (windowRel.has_input()) {
childNode = toVeloxPlan(windowRel.input());
} else {
VELOX_FAIL("Child Rel is expected in WindowRel.");
}

const auto& inputType = childNode->outputType();

// Parse measures and get the window expressions.
// Each measure represents one window expression.
bool ignoreNullKeys = false;
std::vector<core::WindowNode::Function> windowNodeFunctions;
std::vector<std::string> windowColumnNames;

windowNodeFunctions.reserve(windowRel.measures().size());
for (const auto& smea : windowRel.measures()) {
const auto& windowFunction = smea.measure();
std::string funcName = subParser_->findVeloxFunction(
functionMap_, windowFunction.function_reference());
std::vector<std::shared_ptr<const core::ITypedExpr>> windowParams;
windowParams.reserve(windowFunction.arguments().size());
for (const auto& arg : windowFunction.arguments()) {
windowParams.emplace_back(
exprConverter_->toVeloxExpr(arg.value(), inputType));
}
auto windowVeloxType =
toVeloxType(subParser_->parseType(windowFunction.output_type())->type);
auto windowCall = std::make_shared<const core::CallTypedExpr>(
windowVeloxType, std::move(windowParams), funcName);
auto upperBound = windowFunction.upper_bound();
auto lowerBound = windowFunction.lower_bound();
auto type = windowFunction.window_type();

windowColumnNames.push_back(windowFunction.column_name());

windowNodeFunctions.push_back(
{std::move(windowCall),
createWindowFrame(lowerBound, upperBound, type),
ignoreNullKeys});
}

// Construct partitionKeys
std::vector<core::FieldAccessTypedExprPtr> partitionKeys;
const auto& partitions = windowRel.partition_expressions();
partitionKeys.reserve(partitions.size());
for (const auto& partition : partitions) {
auto expression = exprConverter_->toVeloxExpr(partition, inputType);
auto expr_field =
dynamic_cast<const core::FieldAccessTypedExpr*>(expression.get());
VELOX_CHECK(
expr_field != nullptr,
" the partition key in Window Operator only support field")

partitionKeys.emplace_back(
std::dynamic_pointer_cast<const core::FieldAccessTypedExpr>(
expression));
}

std::vector<core::FieldAccessTypedExprPtr> sortingKeys;
std::vector<core::SortOrder> sortingOrders;

const auto& sorts = windowRel.sorts();
sortingKeys.reserve(sorts.size());
sortingOrders.reserve(sorts.size());

for (const auto& sort : sorts) {
switch (sort.direction()) {
case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_FIRST:
sortingOrders.emplace_back(core::kAscNullsFirst);
break;
case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_LAST:
sortingOrders.emplace_back(core::kAscNullsLast);
break;
case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_FIRST:
sortingOrders.emplace_back(core::kDescNullsFirst);
break;
case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_LAST:
sortingOrders.emplace_back(core::kDescNullsLast);
break;
default:
VELOX_FAIL("Sort direction is not support in WindowRel");
}

if (sort.has_expr()) {
auto expression = exprConverter_->toVeloxExpr(sort.expr(), inputType);
auto expr_field =
dynamic_cast<const core::FieldAccessTypedExpr*>(expression.get());
VELOX_CHECK(
expr_field != nullptr,
" the sorting key in Window Operator only support field")

sortingKeys.emplace_back(
std::dynamic_pointer_cast<const core::FieldAccessTypedExpr>(
expression));
}
}

return std::make_shared<core::WindowNode>(
nextPlanNodeId(),
partitionKeys,
sortingKeys,
sortingOrders,
windowColumnNames,
windowNodeFunctions,
childNode);
}

core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
const ::substrait::SortRel& sortRel) {
auto childNode = convertSingleInput<::substrait::SortRel>(sortRel);
Expand Down Expand Up @@ -970,6 +1119,9 @@ core::PlanNodePtr SubstraitVeloxPlanConverter::toVeloxPlan(
if (sRel.has_fetch()) {
return toVeloxPlan(sRel.fetch());
}
if (sRel.has_window()) {
return toVeloxPlan(sRel.window());
}
VELOX_NYI("Substrait conversion not supported for Rel.");
}

Expand Down
3 changes: 3 additions & 0 deletions velox/substrait/SubstraitToVeloxPlan.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ class SubstraitVeloxPlanConverter {
/// Used to convert Substrait ExpandRel into Velox PlanNode.
core::PlanNodePtr toVeloxPlan(const ::substrait::ExpandRel& sExpand);

/// Used to convert Substrait SortRel into Velox PlanNode.
core::PlanNodePtr toVeloxPlan(const ::substrait::WindowRel& sWindow);

/// Used to convert Substrait JoinRel into Velox PlanNode.
core::PlanNodePtr toVeloxPlan(const ::substrait::JoinRel& sJoin);

Expand Down
136 changes: 136 additions & 0 deletions velox/substrait/SubstraitToVeloxPlanValidator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,139 @@ bool SubstraitToVeloxPlanValidator::validate(
return true;
}

bool validateBoundType(::substrait::Expression_WindowFunction_Bound boundType) {
switch (boundType.kind_case()) {
case ::substrait::Expression_WindowFunction_Bound::kUnboundedFollowing:
case ::substrait::Expression_WindowFunction_Bound::kUnboundedPreceding:
case ::substrait::Expression_WindowFunction_Bound::kCurrentRow:
break;
default:
std::cout << "The Bound Type is not supported. "
<< "\n";
return false;
}
return true;
}

bool SubstraitToVeloxPlanValidator::validate(
const ::substrait::WindowRel& sWindow) {
if (sWindow.has_input() && !validate(sWindow.input())) {
return false;
}

// Get and validate the input types from extension.
if (!sWindow.has_advanced_extension()) {
std::cout << "Input types are expected in WindowRel." << std::endl;
return false;
}
const auto& extension = sWindow.advanced_extension();
std::vector<TypePtr> types;
if (!validateInputTypes(extension, types)) {
std::cout << "Validation failed for input types in WindowRel." << std::endl;
return false;
}

int32_t inputPlanNodeId = 0;
std::vector<std::string> names;
names.reserve(types.size());
for (auto colIdx = 0; colIdx < types.size(); colIdx++) {
names.emplace_back(subParser_->makeNodeName(inputPlanNodeId, colIdx));
}
auto rowType = std::make_shared<RowType>(std::move(names), std::move(types));

// Validate WindowFunction
std::vector<std::string> funcSpecs;
funcSpecs.reserve(sWindow.measures().size());
for (const auto& smea : sWindow.measures()) {
try {
const auto& windowFunction = smea.measure();
funcSpecs.emplace_back(
planConverter_->findFuncSpec(windowFunction.function_reference()));
toVeloxType(subParser_->parseType(windowFunction.output_type())->type);
for (const auto& arg : windowFunction.arguments()) {
auto typeCase = arg.value().rex_type_case();
switch (typeCase) {
case ::substrait::Expression::RexTypeCase::kSelection:
case ::substrait::Expression::RexTypeCase::kLiteral:
break;
default:
std::cout << "Only field is supported in window functions."
<< std::endl;
return false;
}
}
// Validate BoundType and Frame Type
switch (windowFunction.window_type()) {
case ::substrait::WindowType::ROWS:
case ::substrait::WindowType::RANGE:
break;
default:
VELOX_FAIL(
"the window type only support ROWS and RANGE, and the input type is ",
windowFunction.window_type());
}

validateBoundType(windowFunction.upper_bound());
validateBoundType(windowFunction.lower_bound());

} catch (const VeloxException& err) {
std::cout << "Validation failed for window function due to: "
<< err.message() << std::endl;
return false;
}
}

// Validate groupby expression
const auto& groupByExprs = sWindow.partition_expressions();
std::vector<std::shared_ptr<const core::ITypedExpr>> expressions;
expressions.reserve(groupByExprs.size());
try {
for (const auto& expr : groupByExprs) {
expressions.emplace_back(exprConverter_->toVeloxExpr(expr, rowType));
}
// Try to compile the expressions. If there is any unregistred funciton or
// mismatched type, exception will be thrown.
exec::ExprSet exprSet(std::move(expressions), execCtx_);
} catch (const VeloxException& err) {
std::cout << "Validation failed for expression in ProjectRel due to:"
<< err.message() << std::endl;
return false;
}

// Validate Sort expression
const auto& sorts = sWindow.sorts();
for (const auto& sort : sorts) {
switch (sort.direction()) {
case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_FIRST:
case ::substrait::SortField_SortDirection_SORT_DIRECTION_ASC_NULLS_LAST:
case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_FIRST:
case ::substrait::SortField_SortDirection_SORT_DIRECTION_DESC_NULLS_LAST:
break;
default:
return false;
}

if (sort.has_expr()) {
try {
auto expression = exprConverter_->toVeloxExpr(sort.expr(), rowType);
auto expr_field =
dynamic_cast<const core::FieldAccessTypedExpr*>(expression.get());
VELOX_CHECK(
expr_field != nullptr,
" the sorting key in Sort Operator only support field")

exec::ExprSet exprSet({std::move(expression)}, execCtx_);
} catch (const VeloxException& err) {
std::cout << "Validation failed for expression in SortRel due to:"
<< err.message() << std::endl;
return false;
}
}
}

return true;
}

bool SubstraitToVeloxPlanValidator::validate(
const ::substrait::SortRel& sSort) {
if (sSort.has_input() && !validate(sSort.input())) {
Expand Down Expand Up @@ -582,6 +715,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::Rel& sRel) {
if (sRel.has_fetch()) {
return validate(sRel.fetch());
}
if (sRel.has_window()) {
return validate(sRel.window());
}
return false;
}

Expand Down
3 changes: 3 additions & 0 deletions velox/substrait/SubstraitToVeloxPlanValidator.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ class SubstraitToVeloxPlanValidator {
/// Used to validate whether the computing of this Sort is supported.
bool validate(const ::substrait::SortRel& sSort);

/// Used to validate whether the computing of this Window is supported.
bool validate(const ::substrait::WindowRel& sWindow);

/// Used to validate whether the computing of this Aggregation is supported.
bool validate(const ::substrait::AggregateRel& sAgg);

Expand Down
Loading

0 comments on commit f5797a0

Please sign in to comment.