Skip to content

Commit

Permalink
Feature duckdb#1272: Window Distinct Merging
Browse files Browse the repository at this point in the history
Fully parallelise the sorting for distinct aggregates.
  • Loading branch information
Richard Wesley committed Aug 6, 2024
1 parent df6faca commit 49f2cfb
Showing 1 changed file with 100 additions and 13 deletions.
113 changes: 100 additions & 13 deletions src/execution/window_segment_tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "duckdb/common/algorithm.hpp"
#include "duckdb/common/helper.hpp"
#include "duckdb/common/sort/partition_state.hpp"
#include "duckdb/common/vector_operations/vector_operations.hpp"
#include "duckdb/execution/merge_sort_tree.hpp"
#include "duckdb/planner/expression/bound_constant_expression.hpp"
Expand Down Expand Up @@ -1398,25 +1399,42 @@ WindowDistinctAggregator::WindowDistinctAggregator(AggregateObject aggr, const v
: WindowAggregator(std::move(aggr), arg_types, result_type, exclude_mode_p), context(context) {
}

class WindowDistinctAggregatorLocalState;

class WindowDistinctAggregatorGlobalState : public WindowAggregatorGlobalState {
public:
using GlobalSortStatePtr = unique_ptr<GlobalSortState>;
class DistinctSortTree;

WindowDistinctAggregatorGlobalState(const WindowDistinctAggregator &aggregator, idx_t group_count);

bool TryPrepareNextStage(WindowDistinctAggregatorLocalState &lstate);

void Finalize(const FrameStats &stats);

// Single threaded sorting for now
ClientContext &context;
GlobalSortStatePtr global_sort;
idx_t memory_per_thread;

//! Finalize guard
mutex lock;
//! Finalize stage
PartitionSortStage stage = PartitionSortStage::INIT;
//! Total number of tasks in the current merge round
idx_t total_tasks = 0;
//! Tasks launched
idx_t tasks_assigned = 0;
//! Tasks landed
mutable atomic<idx_t> tasks_completed;

//! The sorted payload data types (partition index)
vector<LogicalType> payload_types;
//! The aggregate arguments + partition index
vector<LogicalType> sort_types;

//! Sorting operations
GlobalSortStatePtr global_sort;

//! The merge sort tree for the aggregate.
unique_ptr<DistinctSortTree> merge_sort_tree;

Expand All @@ -1428,7 +1446,7 @@ class WindowDistinctAggregatorGlobalState : public WindowAggregatorGlobalState {

WindowDistinctAggregatorGlobalState::WindowDistinctAggregatorGlobalState(const WindowDistinctAggregator &aggregator,
idx_t group_count)
: WindowAggregatorGlobalState(aggregator, group_count), context(aggregator.context),
: WindowAggregatorGlobalState(aggregator, group_count), context(aggregator.context), tasks_completed(0),
levels_flat_native(aggregator.aggr) {
payload_types.emplace_back(LogicalType::UBIGINT);

Expand Down Expand Up @@ -1460,10 +1478,14 @@ class WindowDistinctAggregatorLocalState : public WindowAggregatorState {
explicit WindowDistinctAggregatorLocalState(const WindowDistinctAggregatorGlobalState &aggregator);

void Sink(DataChunk &arg_chunk, idx_t input_idx, optional_ptr<SelectionVector> filter_sel, idx_t filtered);
void ExecuteTask();
void Evaluate(const WindowDistinctAggregatorGlobalState &gdstate, const DataChunk &bounds, Vector &result,
idx_t count, idx_t row_idx);

//! Thread-local sorting data
LocalSortState local_sort;
//! Finalize stage
PartitionSortStage stage = PartitionSortStage::INIT;

protected:
//! Flush the accumulated intermediate states into the result states
Expand Down Expand Up @@ -1544,13 +1566,87 @@ void WindowDistinctAggregatorLocalState::Sink(DataChunk &arg_chunk, idx_t input_
}
}

void WindowDistinctAggregatorLocalState::ExecuteTask() {
auto &global_sort = *gastate.global_sort;
switch (stage) {
case PartitionSortStage::INIT:
// AddLocalState is thread-safe
global_sort.AddLocalState(local_sort);
break;
case PartitionSortStage::MERGE: {
MergeSorter merge_sorter(global_sort, global_sort.buffer_manager);
merge_sorter.PerformInMergeRound();
break;
}
default:
break;
}

++gastate.tasks_completed;
}

//! Picks the next piece of sorting work for the calling thread while holding the stage lock.
//! Returns false if the caller must wait for outstanding tasks to land;
//! otherwise sets lstate.stage to the stage the caller should execute and returns true.
bool WindowDistinctAggregatorGlobalState::TryPrepareNextStage(WindowDistinctAggregatorLocalState &lstate) {
	lock_guard<mutex> stage_guard(lock);

	if (stage == PartitionSortStage::INIT) {
		// Every local sort has to be handed over before merging can be set up
		if (tasks_completed < locals) {
			return false;
		}
		global_sort->PrepareMergePhase();
	} else if (stage == PartitionSortStage::MERGE) {
		// Unassigned tasks remain in this round: hand one out
		if (tasks_assigned < total_tasks) {
			lstate.stage = PartitionSortStage::MERGE;
			++tasks_assigned;
			return true;
		}
		// Everything is assigned, but some tasks are still running
		if (tasks_completed < tasks_assigned) {
			return false;
		}
		global_sort->CompleteMergeRound(true);
	} else {
		// Any other stage simply reports SORTED
		lstate.stage = stage = PartitionSortStage::SORTED;
		return true;
	}

	// One task per pair of sorted blocks; with no pairs left the sort is done
	total_tasks = global_sort->sorted_blocks.size() / 2;
	if (!total_tasks) {
		lstate.stage = stage = PartitionSortStage::SORTED;
		return true;
	}

	// Open a new merge round; the caller takes its first task
	global_sort->InitializeMergeRound();
	lstate.stage = stage = PartitionSortStage::MERGE;
	tasks_assigned = 1;
	tasks_completed = 0;
	return true;
}

void WindowDistinctAggregator::Finalize(WindowAggregatorState &gsink, WindowAggregatorState &lstate,
const FrameStats &stats) {
auto &gdsink = gsink.Cast<WindowDistinctAggregatorGlobalState>();
auto &ldstate = lstate.Cast<WindowDistinctAggregatorLocalState>();

// AddLocalState is thread-safe
gdsink.global_sort->AddLocalState(ldstate.local_sort);
// 5: Sort sorted lexicographically increasing
ldstate.ExecuteTask();

// Merge in parallel
while (gdsink.stage != PartitionSortStage::SORTED) {
if (gdsink.TryPrepareNextStage(ldstate)) {
ldstate.ExecuteTask();
} else {
std::this_thread::yield();
}
}

// Last one out turns off the lights!
if (++gdsink.finalized == gdsink.locals) {
Expand All @@ -1568,15 +1664,6 @@ class WindowDistinctAggregatorGlobalState::DistinctSortTree : public MergeSortTr
};

void WindowDistinctAggregatorGlobalState::Finalize(const FrameStats &stats) {
// 5: Sort sorted lexicographically increasing
global_sort->PrepareMergePhase();
while (global_sort->sorted_blocks.size() > 1) {
global_sort->InitializeMergeRound();
MergeSorter merge_sorter(*global_sort, global_sort->buffer_manager);
merge_sorter.PerformInMergeRound();
global_sort->CompleteMergeRound(true);
}

DataChunk scan_chunk;
scan_chunk.Initialize(Allocator::DefaultAllocator(), payload_types);

Expand Down

0 comments on commit 49f2cfb

Please sign in to comment.