Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

[NSE-1166] Cover peers' values in sum window function in range mode #1167

Merged
merged 7 commits into from
Nov 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,8 @@ class WindowSortBase : public KernalBase {

std::vector<std::shared_ptr<arrow::DataType>> order_type_list_;

std::vector<ArrayList> values_; // The window function input.
std::vector<ArrayList> values_; // The window function input.
std::vector<ArrayList> sort_values_; // Sort input.
std::vector<std::shared_ptr<arrow::Int32Array>> group_ids_;
int32_t max_group_id_ = 0;
std::vector<std::vector<std::shared_ptr<ArrayItemIndexS>>> sorted_partitions_;
Expand Down Expand Up @@ -414,6 +415,13 @@ class WindowSumKernel : public WindowSortBase {

arrow::Status Finish(ArrayList* out) override;

template <typename ArrayType>
bool isSameSortValue(std::shared_ptr<ArrayItemIndexS> curr_array_index,
std::shared_ptr<ArrayItemIndexS> next_array_index, int col);

int getLastPeerIndex(std::vector<std::shared_ptr<ArrayItemIndexS>>& sorted_partition,
int curr_index);

template <typename VALUE_TYPE, typename CType, typename BuilderType, typename ArrayType,
typename OP>
arrow::Status HandleSortedPartition(
Expand Down
128 changes: 107 additions & 21 deletions native-sql-engine/cpp/src/codegen/arrow_compute/ext/window_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,6 @@ arrow::Status WindowSortBase::SortToIndicesFinish(
}

arrow::Status WindowSortBase::prepareFinish() {
std::vector<ArrayList> sort_values; // Sort input.
#ifdef DEBUG
std::cout << "[window kernel] Entering Rank Kernel's finish method... " << std::endl;
#endif
Expand Down Expand Up @@ -322,7 +321,7 @@ arrow::Status WindowSortBase::prepareFinish() {
auto column_slice = batch.at(i);
sort_values_batch.push_back(column_slice);
}
sort_values.push_back(sort_values_batch);
sort_values_.push_back(sort_values_batch);
}
#ifdef DEBUG
std::cout << "[window kernel] Finished. " << std::endl;
Expand Down Expand Up @@ -371,7 +370,7 @@ arrow::Status WindowSortBase::prepareFinish() {
std::cout << "[window kernel] Finished. " << std::endl;
#endif

RETURN_NOT_OK(SortToIndicesPrepare(sort_values));
RETURN_NOT_OK(SortToIndicesPrepare(sort_values_));
for (int i = 0; i <= max_group_id_; i++) {
std::vector<std::shared_ptr<ArrayItemIndexS>> partition = partitions_to_sort.at(i);
std::vector<std::shared_ptr<ArrayItemIndexS>> sorted_partition;
Expand Down Expand Up @@ -903,6 +902,77 @@ arrow::Status WindowSumKernel::Finish(ArrayList* out) {
return arrow::Status::OK();
}

template <typename ArrayType>
bool WindowSumKernel::isSameSortValue(std::shared_ptr<ArrayItemIndexS> curr_array_index,
std::shared_ptr<ArrayItemIndexS> next_array_index,
int col) {
auto curr_typed_array = std::dynamic_pointer_cast<ArrayType>(
sort_values_.at(curr_array_index->array_id).at(col));
auto next_typed_array = std::dynamic_pointer_cast<ArrayType>(
sort_values_.at(next_array_index->array_id).at(col));
return (curr_typed_array->GetView(curr_array_index->id) ==
next_typed_array->GetView(next_array_index->id));
}

// Get the final peer index. In range mode, rows are peers if they have the same values
// for the ORDER BY fields. A frame start of CURRENT ROW refers to the first peer row of
// the current row, while a frame end of CURRENT ROW refers to the last peer row of the
// current row.
int WindowSumKernel::getLastPeerIndex(
std::vector<std::shared_ptr<ArrayItemIndexS>>& sorted_partition, int curr_index) {
bool isSame = true;
int lastPeerIndex = curr_index;
std::shared_ptr<ArrayItemIndexS> curr_array_index = sorted_partition.at(curr_index);
for (int i = curr_index + 1; i < sorted_partition.size(); i++) {
std::shared_ptr<ArrayItemIndexS> next_array_index = sorted_partition.at(i);
// Compare sort key.
for (int col = 0; col < order_type_list_.size(); col++) {
std::shared_ptr<arrow::DataType> value_type = order_type_list_[col];
switch (value_type->id()) {
#define PROCESS_SUPPORTED_COMMON_TYPES_SORT(PROC) \
PROC(arrow::UInt8Type, arrow::UInt8Array) \
PROC(arrow::Int8Type, arrow::Int8Array) \
PROC(arrow::UInt16Type, arrow::UInt16Array) \
PROC(arrow::Int16Type, arrow::Int16Array) \
PROC(arrow::UInt32Type, arrow::UInt32Array) \
PROC(arrow::Int32Type, arrow::Int32Array) \
PROC(arrow::UInt64Type, arrow::UInt64Array) \
PROC(arrow::Int64Type, arrow::Int64Array) \
PROC(arrow::FloatType, arrow::FloatArray) \
PROC(arrow::DoubleType, arrow::DoubleArray)

#define PROCESS(VALUE_TYPE, ARRAY_TYPE) \
case VALUE_TYPE::type_id: { \
isSame = \
isSame && isSameSortValue<ARRAY_TYPE>(curr_array_index, next_array_index, col); \
} break;
PROCESS_SUPPORTED_COMMON_TYPES_SORT(PROCESS)
#undef PROCESS
#undef PROCESS_SUPPORTED_COMMON_TYPES_SORT
case arrow::StringType::type_id: {
isSame = isSame && isSameSortValue<arrow::StringArray>(curr_array_index,
next_array_index, col);
} break;
default: {
throw std::runtime_error("window function: unsupported input type: " +
value_type->name());
} break;
} // switch
// Jump from the sort col loop.
if (!isSame) {
break;
}
} // sort col loop.

if (isSame) {
lastPeerIndex = i;
} else {
break;
}
} // sorted_partition loop
return lastPeerIndex;
}

// ArrayType: input ArrayType. CType: result CType. BuilderType: result BuilderType.
// ResArrayType: Result ArrayType.
template <typename ArrayType, typename CType, typename BuilderType, typename ResArrayType,
Expand All @@ -928,29 +998,45 @@ arrow::Status WindowSumKernel::HandleSortedPartition(
sorted_partitions.at(i);
CType parition_sum_by_current = (CType)0;
bool is_valid_value_found = false;
for (int j = 0; j < sorted_partition.size(); j++) {
int j = 0;
while (j < sorted_partition.size()) {
std::shared_ptr<ArrayItemIndexS> index = sorted_partition.at(j);
for (int column_id = 0; column_id < type_list_.size(); column_id++) {
auto typed_array = std::dynamic_pointer_cast<ArrayType>(
values.at(index->array_id).at(column_id));
// If the first value in one partition (ordered) is null, the result is null.
// If there is valid value before null, the result for null is as same as the
// above. So for same value in ordered col, the sum result may be different from
// vanilla's.
if (typed_array->null_count() > 0 && typed_array->IsNull(index->id)) {
if (!is_valid_value_found) {
validity[index->array_id][index->id] = false;
} else {
sum_array[index->array_id][index->id] = parition_sum_by_current;
validity[index->array_id][index->id] = true;
}
int column_id = 0; // One col input.
auto typed_array =
std::dynamic_pointer_cast<ArrayType>(values.at(index->array_id).at(column_id));
// If the first value in one partition (ordered) is null, the result is null.
// If there is valid value before null, the result for null is as same as the
// above.
if (typed_array->null_count() > 0 && typed_array->IsNull(index->id)) {
if (!is_valid_value_found) {
validity[index->array_id][index->id] = false;
} else {
is_valid_value_found = true;
parition_sum_by_current =
parition_sum_by_current + (CType)op(typed_array, index->id);
sum_array[index->array_id][index->id] = parition_sum_by_current;
validity[index->array_id][index->id] = true;
}
j++;
} else {
is_valid_value_found = true;
parition_sum_by_current =
parition_sum_by_current + (CType)op(typed_array, index->id);
int lastPeerIndex = getLastPeerIndex(sorted_partition, j);
// Calculate values with peers considered.
for (int k = j + 1; k <= lastPeerIndex; k++) {
std::shared_ptr<ArrayItemIndexS> peer_index = sorted_partition.at(k);
auto peer_typed_array = std::dynamic_pointer_cast<ArrayType>(
values.at(peer_index->array_id).at(column_id));
parition_sum_by_current =
parition_sum_by_current + (CType)op(peer_typed_array, peer_index->id);
}
// Set values for all peers whose sort keys are same in a group.
for (int k = j; k <= lastPeerIndex; k++) {
std::shared_ptr<ArrayItemIndexS> peer_index = sorted_partition.at(k);
auto peer_typed_array = std::dynamic_pointer_cast<ArrayType>(
values.at(peer_index->array_id).at(column_id));
sum_array[peer_index->array_id][peer_index->id] = parition_sum_by_current;
validity[peer_index->array_id][peer_index->id] = true;
}
j = lastPeerIndex + 1;
}
}
}
Expand Down
41 changes: 41 additions & 0 deletions native-sql-engine/cpp/src/tests/arrow_compute_test_window.cc
Original file line number Diff line number Diff line change
Expand Up @@ -507,5 +507,46 @@ TEST(TestArrowComputeWindow, SumOrderedTest) {
ASSERT_NOT_OK(Equals(*expected_result.get(), *(out.at(0).get())));
}

// Test case: sort key has repeat values and there are multiple peers need to be
// considered in range mode.
TEST(TestArrowComputeWindow, SumOrderedWithMultiplePeersTest) {
std::shared_ptr<arrow::RecordBatch> input_batch;
auto sch =
arrow::schema({field("col_int", arrow::int32()), field("col_dec", arrow::int32())});
std::vector<std::string> input_data = {"[1, 2, 1]", "[39, 37, 39]"};
MakeInputBatch(input_data, sch, &input_batch);

std::shared_ptr<Field> res = field("window_res", arrow::int64());

auto f_window = TreeExprBuilder::MakeExpression(
TreeExprBuilder::MakeFunction(
"window",
{TreeExprBuilder::MakeFunction(
"sum_desc", {TreeExprBuilder::MakeField(field("col_dec", arrow::int32()))},
null()),
TreeExprBuilder::MakeFunction(
"partitionSpec",
{TreeExprBuilder::MakeField(field("col_int", arrow::int32()))}, null()),
TreeExprBuilder::MakeFunction(
"orderSpec",
{TreeExprBuilder::MakeField(field("col_dec", arrow::int32()))}, null())},
binary()),
res);

arrow::compute::ExecContext ctx;
std::shared_ptr<CodeGenerator> expr;
std::vector<std::shared_ptr<arrow::RecordBatch>> out;
ASSERT_NOT_OK(
CreateCodeGenerator(ctx.memory_pool(), sch, {f_window}, {res}, &expr, true))
ASSERT_NOT_OK(expr->evaluate(input_batch, nullptr))
ASSERT_NOT_OK(expr->finish(&out))

std::shared_ptr<arrow::RecordBatch> expected_result;
std::vector<std::string> expected_output_data = {"[78, 37, 78]"};

MakeInputBatch(expected_output_data, arrow::schema({res}), &expected_result);
ASSERT_NOT_OK(Equals(*expected_result.get(), *(out.at(0).get())));
}

} // namespace codegen
} // namespace sparkcolumnarplugin