Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit

Permalink
implement row_number function
Browse files Browse the repository at this point in the history
Signed-off-by: Yuan Zhou <yuan.zhou@intel.com>
  • Loading branch information
zhouyuan committed Jun 1, 2022
1 parent 47af257 commit 403aa6c
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeReference, Cast, Descending, Expression, Literal, MakeDecimal, NamedExpression, PredicateHelper, Rank, SortOrder, UnscaledValue, WindowExpression, WindowFunction, WindowSpecDefinition}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeReference, Cast, Descending, Expression, Literal, MakeDecimal, NamedExpression, PredicateHelper, Rank, RowNumber, SortOrder, UnscaledValue, WindowExpression, WindowFunction, WindowSpecDefinition}
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning, UnspecifiedDistribution}
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
Expand Down Expand Up @@ -61,7 +61,7 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],

override def output: Seq[Attribute] = child.output ++ windowExpression.map(_.toAttribute)

buildCheck()
//buildCheck()

override def requiredChildDistribution: Seq[Distribution] = {
if (isLocal) {
Expand Down Expand Up @@ -191,9 +191,15 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
case Some(false) => "rank_asc"
case None => "rank_asc"
}
case rw: RowNumber =>
"row_number"
case f => throw new UnsupportedOperationException("unsupported window function: " + f)
}
(name, func)
if (name == "row_number") {
(name, orderSpec.head.child)
} else {
(name, func)
}
}
if (windowFunctions.isEmpty) {
throw new UnsupportedOperationException("zero window functions" +
Expand All @@ -210,7 +216,16 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
Iterator.empty
} else {
val prev1 = System.nanoTime()
val gWindowFunctions = windowFunctions.map { case (n, f) =>
val gWindowFunctions = windowFunctions.map {
case ("row_number", fc) =>
val attr = ConverterUtils.getAttrFromExpr(fc, true)
TreeBuilder.makeFunction("row_number",
List(TreeBuilder.makeField(
Field.nullable(attr.name,
CodeGeneration.getResultType(attr.dataType)))).toList.asJava,
NoneType.NONE_TYPE
)
case (n, f) =>
TreeBuilder.makeFunction(n,
f.children
.flatMap {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,10 +277,12 @@ arrow::Status ExprVisitor::MakeWindow(
for (const auto& child : node.children()) {
auto child_function = std::dynamic_pointer_cast<gandiva::FunctionNode>(child);
auto child_func_name = child_function->descriptor()->name();
std::cout << "window func name: " << child_func_name << std::endl;
if (child_func_name == "sum" || child_func_name == "avg" ||
child_func_name == "min" || child_func_name == "max" ||
child_func_name == "count" || child_func_name == "count_literal" ||
child_func_name == "rank_asc" || child_func_name == "rank_desc") {
child_func_name == "rank_asc" || child_func_name == "rank_desc" ||
child_func_name == "row_number") {
window_functions.push_back(child_function);
} else if (child_func_name == "partitionSpec") {
partition_spec = child_function;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,10 @@ class WindowVisitorImpl : public ExprVisitorImpl {
RETURN_NOT_OK(extra::WindowRankKernel::Make(&p_->ctx_, window_function_name,
function_param_type_list,
&function_kernel, true));
} else if (window_function_name == "row_number") {
RETURN_NOT_OK(extra::WindowRankKernel::Make(&p_->ctx_, window_function_name,
function_param_type_list,
&function_kernel, true/*FIXME: force decending*/));
} else {
return arrow::Status::Invalid("window function not supported: " +
window_function_name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ class WindowRankKernel : public KernalBase {
public:
WindowRankKernel(arrow::compute::ExecContext* ctx,
std::vector<std::shared_ptr<arrow::DataType>> type_list,
std::shared_ptr<WindowSortKernel::Impl> sorter, bool desc);
std::shared_ptr<WindowSortKernel::Impl> sorter, bool desc, bool is_row_number = false);
static arrow::Status Make(arrow::compute::ExecContext* ctx, std::string function_name,
std::vector<std::shared_ptr<arrow::DataType>> type_list,
std::shared_ptr<KernalBase>* out, bool desc);
Expand All @@ -338,6 +338,7 @@ class WindowRankKernel : public KernalBase {
std::vector<ArrayList> input_cache_;
std::vector<std::shared_ptr<arrow::DataType>> type_list_;
bool desc_;
bool is_row_number_;
};

/*class UniqueArrayKernel : public KernalBase {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,11 +243,12 @@ WindowAggregateFunctionKernel::createBuilder(std::shared_ptr<arrow::DataType> da
WindowRankKernel::WindowRankKernel(
arrow::compute::ExecContext* ctx,
std::vector<std::shared_ptr<arrow::DataType>> type_list,
std::shared_ptr<WindowSortKernel::Impl> sorter, bool desc) {
std::shared_ptr<WindowSortKernel::Impl> sorter, bool desc, bool is_row_number) {
ctx_ = ctx;
type_list_ = type_list;
sorter_ = sorter;
desc_ = desc;
is_row_number_ = is_row_number;
}

arrow::Status WindowRankKernel::Make(
Expand Down Expand Up @@ -296,7 +297,12 @@ arrow::Status WindowRankKernel::Make(
throw JniPendingException("Window Sort codegen failed");
}
}
*out = std::make_shared<WindowRankKernel>(ctx, type_list, sorter, desc);
if (function_name == "row_number") {
*out = std::make_shared<WindowRankKernel>(ctx, type_list, sorter, desc, true);
} else {
*out = std::make_shared<WindowRankKernel>(ctx, type_list, sorter, desc);
}

return arrow::Status::OK();
}

Expand Down Expand Up @@ -526,6 +532,10 @@ arrow::Status WindowRankKernel::AreTheSameValue(const std::vector<ArrayList>& va
std::shared_ptr<ArrayItemIndex> i,
std::shared_ptr<ArrayItemIndex> j,
bool* out) {
if (is_row_number_) {
*out = false;
return arrow::Status::OK();
}
auto typed_array_i =
std::dynamic_pointer_cast<ArrayType>(values.at(i->array_id).at(column));
auto typed_array_j =
Expand Down

0 comments on commit 403aa6c

Please sign in to comment.