Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit

Permalink
[NSE-862] improve row_number() (#1000)
Browse files Browse the repository at this point in the history
* Revert "disable row_number() temporary (#994)"

This reverts commit b973977.

* improve row_number()

Signed-off-by: Yuan Zhou <yuan.zhou@intel.com>
  • Loading branch information
zhouyuan authored Jul 11, 2022
1 parent e3d2a63 commit a6be543
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,10 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
try {
breakable {
for (func <- validateWindowFunctions()) {
// TODO(): disable row_number() for stability issue
// if (func._1 == "row_number") {
// allLiteral = false
// break
// }
if (func._1.startsWith("row_number")) {
allLiteral = false
break
}
for (child <- func._2.children) {
if (!child.isInstanceOf[Literal]) {
allLiteral = false
Expand Down Expand Up @@ -197,10 +196,29 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
case None => "rank_asc"
}
case rw: RowNumber =>
"row_number"
val desc: Option[Boolean] = orderSpec.foldLeft[Option[Boolean]](None) {
(desc, s) =>
val currentDesc = s.direction match {
case Ascending => false
case Descending => true
case _ => throw new IllegalStateException
}
if (desc.isEmpty) {
Some(currentDesc)
} else if (currentDesc == desc.get) {
Some(currentDesc)
} else {
throw new UnsupportedOperationException("row_number: clashed rank order found")
}
}
desc match {
case Some(true) => "row_number_desc"
case Some(false) => "row_number_asc"
case None => "row_number_asc"
}
case f => throw new UnsupportedOperationException("unsupported window function: " + f)
}
if (name == "row_number") {
if (name.startsWith("row_number")) {
(name, orderSpec.head.child)
} else {
(name, func)
Expand All @@ -222,10 +240,10 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
} else {
val prev1 = System.nanoTime()
val gWindowFunctions = windowFunctions.map {
case ("row_number", spec) =>
case (row_number_func, spec) if row_number_func.startsWith("row_number") =>
//TODO(): should get attr from orderSpec
val attr = ConverterUtils.getAttrFromExpr(orderSpec.head.child, true)
TreeBuilder.makeFunction("row_number",
TreeBuilder.makeFunction(row_number_func,
List(TreeBuilder.makeField(
Field.nullable(attr.name,
CodeGeneration.getResultType(attr.dataType)))).toList.asJava,
Expand Down Expand Up @@ -263,8 +281,12 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression],
val returnType = ArrowType.Binary.INSTANCE
val fieldType = new FieldType(false, returnType, null)
val resultField = new Field("window_res", fieldType,
windowFunctions.map { case (_, f) =>
CodeGeneration.getResultType(f.dataType)
windowFunctions.map {
case (row_number_func, f) if row_number_func.startsWith("row_number")=>
// row_number will return int32-based indices
new ArrowType.Int(32, true)
case (_, f) =>
CodeGeneration.getResultType(f.dataType)
}.zipWithIndex.map { case (t, i) =>
Field.nullable(s"window_res_" + i, t)
}.asJava)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ arrow::Status ExprVisitor::MakeWindow(
child_func_name == "min" || child_func_name == "max" ||
child_func_name == "count" || child_func_name == "count_literal" ||
child_func_name == "rank_asc" || child_func_name == "rank_desc" ||
child_func_name == "row_number") {
child_func_name == "row_number_desc" || child_func_name == "row_number_asc") {
window_functions.push_back(child_function);
} else if (child_func_name == "partitionSpec") {
partition_spec = child_function;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,14 @@ class WindowVisitorImpl : public ExprVisitorImpl {
RETURN_NOT_OK(extra::WindowRankKernel::Make(&p_->ctx_, window_function_name,
function_param_type_list,
&function_kernel, true));
} else if (window_function_name == "row_number") {
RETURN_NOT_OK(extra::WindowRankKernel::Make(
&p_->ctx_, window_function_name, function_param_type_list, &function_kernel,
true /*FIXME: force decending*/));
} else if (window_function_name == "row_number_desc") {
RETURN_NOT_OK(extra::WindowRankKernel::Make(&p_->ctx_, window_function_name,
function_param_type_list,
&function_kernel, true));
} else if (window_function_name == "row_number_asc") {
RETURN_NOT_OK(extra::WindowRankKernel::Make(&p_->ctx_, window_function_name,
function_param_type_list,
&function_kernel, false));
} else {
return arrow::Status::Invalid("window function not supported: " +
window_function_name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ arrow::Status WindowRankKernel::Make(
throw JniPendingException("Window Sort codegen failed");
}
}
if (function_name == "row_number") {
if (function_name.rfind("row_number", 0) == 0) {
*out = std::make_shared<WindowRankKernel>(ctx, type_list, sorter, desc, true);
} else {
*out = std::make_shared<WindowRankKernel>(ctx, type_list, sorter, desc);
Expand Down

0 comments on commit a6be543

Please sign in to comment.