diff --git a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarWindowExec.scala b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarWindowExec.scala index 4ae836a09..aa1656a60 100644 --- a/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarWindowExec.scala +++ b/native-sql-engine/core/src/main/scala/com/intel/oap/execution/ColumnarWindowExec.scala @@ -29,7 +29,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeReference, Cast, Descending, Expression, Literal, MakeDecimal, NamedExpression, PredicateHelper, Rank, SortOrder, UnscaledValue, WindowExpression, WindowFunction, WindowSpecDefinition} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeReference, Cast, Descending, Expression, Literal, MakeDecimal, NamedExpression, PredicateHelper, Rank, RowNumber, SortOrder, UnscaledValue, WindowExpression, WindowFunction, WindowSpecDefinition} import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -61,7 +61,7 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], override def output: Seq[Attribute] = child.output ++ windowExpression.map(_.toAttribute) - buildCheck() + //buildCheck() override def requiredChildDistribution: Seq[Distribution] = { if (isLocal) { @@ -191,9 +191,15 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], case Some(false) => "rank_asc" case None => "rank_asc" } + case rw: RowNumber => + "row_number" case f => throw new UnsupportedOperationException("unsupported window function: " + f) } - (name, func) + if (name == "row_number") { + (name, orderSpec.head.child) + } else { + (name, func) + } } if (windowFunctions.isEmpty) { throw new UnsupportedOperationException("zero window functions" + @@ -210,7 +216,16 @@ case class ColumnarWindowExec(windowExpression: Seq[NamedExpression], Iterator.empty } else { val prev1 = System.nanoTime() - val gWindowFunctions = windowFunctions.map { case (n, f) => + val gWindowFunctions = windowFunctions.map { + case ("row_number", fc) => + val attr = ConverterUtils.getAttrFromExpr(fc, true) + TreeBuilder.makeFunction("row_number", + List(TreeBuilder.makeField( + Field.nullable(attr.name, + CodeGeneration.getResultType(attr.dataType)))).toList.asJava, + NoneType.NONE_TYPE + ) + case (n, f) => TreeBuilder.makeFunction(n, f.children .flatMap { diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor.cc b/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor.cc index e57642f24..7967815a1 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor.cc +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor.cc @@ -277,10 +277,12 @@ arrow::Status ExprVisitor::MakeWindow( for (const auto& child : node.children()) { auto child_function = std::dynamic_pointer_cast(child); auto child_func_name = child_function->descriptor()->name(); + std::cout << "window func name: " << child_func_name << std::endl; if (child_func_name == "sum" || child_func_name == "avg" || child_func_name == "min" || child_func_name == "max" || child_func_name == "count" || child_func_name == "count_literal" || - child_func_name == "rank_asc" || child_func_name == "rank_desc") { + child_func_name == "rank_asc" || child_func_name == "rank_desc" || + child_func_name == "row_number") { window_functions.push_back(child_function); } else if (child_func_name == "partitionSpec") { partition_spec = child_function; diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor_impl.h b/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor_impl.h index 6229f70eb..39bf46545 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor_impl.h +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/expr_visitor_impl.h @@ -177,6 +177,10 @@ class WindowVisitorImpl : public ExprVisitorImpl { RETURN_NOT_OK(extra::WindowRankKernel::Make(&p_->ctx_, window_function_name, function_param_type_list, &function_kernel, true)); + } else if (window_function_name == "row_number") { + RETURN_NOT_OK(extra::WindowRankKernel::Make(&p_->ctx_, window_function_name, + function_param_type_list, + &function_kernel, true/*FIXME: force decending*/)); } else { return arrow::Status::Invalid("window function not supported: " + window_function_name); diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/kernels_ext.h b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/kernels_ext.h index 520f4b60f..0f2df4960 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/kernels_ext.h +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/kernels_ext.h @@ -315,7 +315,7 @@ class WindowRankKernel : public KernalBase { public: WindowRankKernel(arrow::compute::ExecContext* ctx, std::vector> type_list, - std::shared_ptr sorter, bool desc); + std::shared_ptr sorter, bool desc, bool is_row_number = false); static arrow::Status Make(arrow::compute::ExecContext* ctx, std::string function_name, std::vector> type_list, std::shared_ptr* out, bool desc); @@ -338,6 +338,7 @@ class WindowRankKernel : public KernalBase { std::vector input_cache_; std::vector> type_list_; bool desc_; + bool is_row_number_; }; /*class UniqueArrayKernel : public KernalBase { diff --git a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/window_kernel.cc b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/window_kernel.cc index 3423c8c81..d24a6b5e0 100644 --- a/native-sql-engine/cpp/src/codegen/arrow_compute/ext/window_kernel.cc +++ b/native-sql-engine/cpp/src/codegen/arrow_compute/ext/window_kernel.cc @@ -243,11 +243,12 @@ WindowAggregateFunctionKernel::createBuilder(std::shared_ptr da WindowRankKernel::WindowRankKernel( arrow::compute::ExecContext* ctx, std::vector> type_list, - std::shared_ptr sorter, bool desc) { + std::shared_ptr sorter, bool desc, bool is_row_number) { ctx_ = ctx; type_list_ = type_list; sorter_ = sorter; desc_ = desc; + is_row_number_ = is_row_number; } arrow::Status WindowRankKernel::Make( @@ -296,7 +297,12 @@ arrow::Status WindowRankKernel::Make( throw JniPendingException("Window Sort codegen failed"); } } - *out = std::make_shared(ctx, type_list, sorter, desc); + if (function_name == "row_number") { + *out = std::make_shared(ctx, type_list, sorter, desc, true); + } else { + *out = std::make_shared(ctx, type_list, sorter, desc); + } + return arrow::Status::OK(); } @@ -526,6 +532,10 @@ arrow::Status WindowRankKernel::AreTheSameValue(const std::vector& va std::shared_ptr i, std::shared_ptr j, bool* out) { + if (is_row_number_) { + *out = false; + return arrow::Status::OK(); + } auto typed_array_i = std::dynamic_pointer_cast(values.at(i->array_id).at(column)); auto typed_array_j =