Vectorize aggregate FILTER clause #7458

Merged: 9 commits, Dec 16, 2024

1 change: 1 addition & 0 deletions .unreleased/vectorized-agg-filter
@@ -0,0 +1 @@
+Implements: #7458 Support vectorized aggregation with aggregate FILTER clauses that are also vectorizable

58 changes: 53 additions & 5 deletions tsl/src/compression/arrow_c_data_interface.h
@@ -184,25 +184,73 @@ arrow_set_row_validity(uint64 *bitmap, size_t row_number, bool value)
 }
 
 /*
- * AND two optional arrow validity bitmaps into the given storage.
+ * Combine the validity bitmaps into the given storage.
  */
 static inline const uint64 *
 arrow_combine_validity(size_t num_words, uint64 *restrict storage, const uint64 *filter1,
-                       const uint64 *filter2)
+                       const uint64 *filter2, const uint64 *filter3)
 {
+    /*
+     * Any and all of the filters can be null. For simplicity, move the non-null
+     * filters to the leading positions.
+     */
+    const uint64 *tmp;
+#define SWAP(X, Y)                                                                                 \
+    tmp = (X);                                                                                     \
+    (X) = (Y);                                                                                     \
+    (Y) = tmp;
+
     if (filter1 == NULL)
     {
-        return filter2;
+        /*
+         * We have at least one NULL that goes to the last position.
+         */
+        SWAP(filter1, filter3);
+
+        if (filter1 == NULL)
+        {
+            /*
+             * We have another NULL that goes to the second position.
+             */
+            SWAP(filter1, filter2);
+        }
     }
+    else
+    {
+        if (filter2 == NULL)
+        {
+            /*
+             * We have at least one NULL that goes to the last position.
+             */
+            SWAP(filter2, filter3);
+        }
+    }
+#undef SWAP
+
+    Assert(filter2 == NULL || filter1 != NULL);
+    Assert(filter3 == NULL || filter2 != NULL);
 
     if (filter2 == NULL)
     {
+        /* Either have one non-null filter, or all of them are null. */
        return filter1;
     }
 
-    for (size_t i = 0; i < num_words; i++)
+    if (filter3 == NULL)
     {
-        storage[i] = filter1[i] & filter2[i];
+        /* Have two non-null filters. */
+        for (size_t i = 0; i < num_words; i++)
+        {
+            storage[i] = filter1[i] & filter2[i];
+        }
     }
+    else
+    {
+        /* Have three non-null filters. */
+        for (size_t i = 0; i < num_words; i++)
+        {
+            storage[i] = filter1[i] & filter2[i] & filter3[i];
+        }
+    }
 
     return storage;

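To make the three-way combination concrete: each uint64 word covers 64 rows, and a row reaches the aggregate only if it passes the vectorized WHERE qual, the aggregate FILTER clause, and the argument validity (non-NULL) bitmap, with a NULL bitmap meaning "all rows pass". Below is a minimal standalone sketch of that semantics, not TimescaleDB code; the names are illustrative and the popcount builtin assumes GCC/Clang:

    #include <stdint.h>
    #include <stdio.h>

    int main(void)
    {
        /* Bit i set means row i passes; one word covers rows 0..63. */
        uint64_t vector_qual = ~UINT64_C(0);              /* WHERE qual: all rows */
        uint64_t agg_filter = UINT64_C(0xFFFFFFFF);       /* FILTER clause: rows 0..31 */
        uint64_t validity = UINT64_C(0x0F0F0F0F0F0F0F0F); /* non-NULL argument rows */

        uint64_t combined = vector_qual & agg_filter & validity;

        /* Counting the surviving rows is a popcount over the combined words. */
        printf("rows aggregated: %d\n", __builtin_popcountll(combined)); /* 16 */
        return 0;
    }
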
16 changes: 1 addition & 15 deletions tsl/src/nodes/decompress_chunk/compressed_batch.c
@@ -20,18 +20,6 @@
 #include "nodes/decompress_chunk/vector_predicates.h"
 #include "nodes/decompress_chunk/vector_quals.h"
 
-/*
- * VectorQualState for a compressed batch used to pass
- * DecompressChunk-specific data to vector qual functions that are shared
- * across scan nodes.
- */
-typedef struct CompressedBatchVectorQualState
-{
-    VectorQualState vqstate;
-    DecompressBatchState *batch_state;
-    DecompressContext *dcontext;
-} CompressedBatchVectorQualState;
-
 /*
  * Create a single-value ArrowArray of an arithmetic type. This is a specialized
  * function because arithmetic types have a particular layout of ArrowArrays.
@@ -312,7 +300,7 @@ decompress_column(DecompressContext *dcontext, DecompressBatchState *batch_state
  * VectorQualState->get_arrow_array() function used to interface with the
  * vector qual code across different scan nodes.
  */
-static const ArrowArray *
+const ArrowArray *
 compressed_batch_get_arrow_array(VectorQualState *vqstate, Expr *expr, bool *is_default_value)
 {
     CompressedBatchVectorQualState *cbvqstate = (CompressedBatchVectorQualState *) vqstate;
@@ -360,8 +348,6 @@ compressed_batch_get_arrow_array(VectorQualState *vqstate, Expr *expr, bool *is_default_value)
                      var->varattno);
     Assert(column_description != NULL);
     Assert(column_description->typid == var->vartype);
-    Ensure(column_description->type == COMPRESSED_COLUMN,
-           "only compressed columns are supported in vectorized quals");
 
     CompressedColumnValues *column_values = &batch_state->compressed_columns[column_index];
 

16 changes: 16 additions & 0 deletions tsl/src/nodes/decompress_chunk/compressed_batch.h
@@ -7,6 +7,7 @@
 
 #include "compression/compression.h"
 #include "nodes/decompress_chunk/decompress_context.h"
+#include "nodes/decompress_chunk/vector_quals.h"
 #include <executor/tuptable.h>
 
 typedef struct ArrowArray ArrowArray;
@@ -172,3 +173,18 @@
     Assert(batch_state->per_batch_context != NULL);
     return &batch_state->decompressed_scan_slot_data.base;
 }
+
+/*
+ * VectorQualState for a compressed batch used to pass
+ * DecompressChunk-specific data to vector qual functions that are shared
+ * across scan nodes.
+ */
+typedef struct CompressedBatchVectorQualState
+{
+    VectorQualState vqstate;
+    DecompressBatchState *batch_state;
+    DecompressContext *dcontext;
+} CompressedBatchVectorQualState;
+
+const ArrowArray *compressed_batch_get_arrow_array(VectorQualState *vqstate, Expr *expr,
+                                                   bool *is_default_value);

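The struct moved into this header relies on a common C idiom that the cast inside compressed_batch_get_arrow_array depends on: because vqstate is the first member of CompressedBatchVectorQualState, a pointer to the embedded VectorQualState can be converted back to the containing struct. A standalone sketch of the idiom with made-up names:

    #include <assert.h>

    typedef struct Base
    {
        int num_results;
    } Base;

    typedef struct Derived
    {
        Base base; /* must be the first member for the cast below to be valid */
        int extra;
    } Derived;

    static int
    get_extra(Base *b)
    {
        /* A struct pointer points at its first member, so this downcast is
         * well-defined C as long as b really is embedded in a Derived. */
        Derived *d = (Derived *) b;
        return d->extra;
    }

    int main(void)
    {
        Derived d = { .base = { .num_results = 1 }, .extra = 42 };
        assert(get_extra(&d.base) == 42);
        return 0;
    }
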
6 changes: 0 additions & 6 deletions tsl/src/nodes/decompress_chunk/planner.c
@@ -148,12 +148,6 @@ typedef struct
 
 } DecompressionMapContext;
 
-typedef struct VectorQualInfoDecompressChunk
-{
-    VectorQualInfo vqinfo;
-    const UncompressedColumnInfo *colinfo;
-} VectorQualInfoDecompressChunk;
-
 static bool *
 build_vector_attrs_array(const UncompressedColumnInfo *colinfo, const CompressionInfo *info)
 {

50 changes: 50 additions & 0 deletions tsl/src/nodes/vector_agg/exec.c
@@ -11,13 +11,15 @@
 #include <nodes/extensible.h>
 #include <nodes/makefuncs.h>
 #include <nodes/nodeFuncs.h>
+#include <optimizer/optimizer.h>
 
 #include "nodes/vector_agg/exec.h"
 
 #include "compression/arrow_c_data_interface.h"
 #include "guc.h"
 #include "nodes/decompress_chunk/compressed_batch.h"
 #include "nodes/decompress_chunk/exec.h"
+#include "nodes/decompress_chunk/vector_quals.h"
 #include "nodes/vector_agg.h"
 
 static int
@@ -67,6 +69,17 @@ vector_agg_begin(CustomScanState *node, EState *estate, int eflags)
     DecompressChunkState *decompress_state =
         (DecompressChunkState *) linitial(vector_agg_state->custom.custom_ps);
 
+    /*
+     * Set up the helper structures used to evaluate stable expressions in
+     * vectorized FILTER clauses.
+     */
+    PlannerGlobal glob = {
+        .boundParams = node->ss.ps.state->es_param_list_info,
+    };
+    PlannerInfo root = {
+        .glob = &glob,
+    };
+
     /*
      * The aggregated targetlist with Aggrefs is in the custom scan targetlist
      * of the custom scan node that is performing the vectorized aggregation.
@@ -149,6 +162,12 @@
         {
             def->input_offset = -1;
         }
+
+        if (aggref->aggfilter != NULL)
+        {
+            Node *constified = estimate_expression_value(&root, (Node *) aggref->aggfilter);
+            def->filter_clauses = list_make1(constified);
+        }
     }
     else
     {
@@ -293,6 +312,37 @@ vector_agg_exec(CustomScanState *node)
             dcontext->ps->instrument->tuplecount += not_filtered_rows;
         }
 
+        /*
+         * Compute the vectorized filters for the aggregate function FILTER
+         * clauses.
+         */
+        const int naggs = vector_agg_state->num_agg_defs;
+        for (int i = 0; i < naggs; i++)
+        {
+            VectorAggDef *agg_def = &vector_agg_state->agg_defs[i];
+            if (agg_def->filter_clauses == NIL)
+            {
+                continue;
+            }
+            CompressedBatchVectorQualState cbvqstate = {
+                .vqstate = {
+                    .vectorized_quals_constified = agg_def->filter_clauses,
+                    .num_results = batch_state->total_batch_rows,
+                    .per_vector_mcxt = batch_state->per_batch_context,
+                    .slot = compressed_slot,
+                    .get_arrow_array = compressed_batch_get_arrow_array,
+                },
+                .batch_state = batch_state,
+                .dcontext = dcontext,
+            };
+            VectorQualState *vqstate = &cbvqstate.vqstate;
+            vector_qual_compute(vqstate);
+            agg_def->filter_result = vqstate->vector_qual_result;
+        }
+
+        /*
+         * Finally, pass the compressed batch to the grouping policy.
+         */
         grouping->gp_add_batch(grouping, batch_state);
     }
 

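A toy model of the per-batch flow added above, in plain C rather than extension code (all names are made up): each aggregate definition carries its own FILTER result, computed once per batch from its constified clause, and a row is accumulated only when both the batch-level qual and that aggregate's filter pass:

    #include <stdbool.h>
    #include <stdint.h>
    #include <stdio.h>

    #define ROWS 8

    typedef struct ToyAggDef
    {
        int64_t sum;
        bool filter_result[ROWS]; /* stand-in for the per-aggregate bitmap */
    } ToyAggDef;

    int main(void)
    {
        int64_t values[ROWS] = { 1, 2, 3, 4, 5, 6, 7, 8 };
        bool vector_qual_result[ROWS] = { 1, 1, 1, 1, 1, 1, 0, 0 };

        /* Analog of sum(x) FILTER (WHERE x % 2 = 0): evaluate the filter
         * over the whole batch first... */
        ToyAggDef agg = { .sum = 0 };
        for (int i = 0; i < ROWS; i++)
        {
            agg.filter_result[i] = (values[i] % 2 == 0);
        }

        /* ...then accumulate only rows that pass both bitmaps. */
        for (int i = 0; i < ROWS; i++)
        {
            if (vector_qual_result[i] && agg.filter_result[i])
            {
                agg.sum += values[i];
            }
        }

        printf("sum = %lld\n", (long long) agg.sum); /* 2 + 4 + 6 = 12 */
        return 0;
    }
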
2 changes: 2 additions & 0 deletions tsl/src/nodes/vector_agg/exec.h
@@ -18,6 +18,8 @@ typedef struct VectorAggDef
     VectorAggFunctions func;
     int input_offset;
     int output_offset;
+    List *filter_clauses;
+    uint64 *filter_result;
 } VectorAggDef;
 
 typedef struct GroupingColumn

18 changes: 10 additions & 8 deletions tsl/src/nodes/vector_agg/grouping_policy_batch.c
@@ -151,6 +151,7 @@ compute_single_aggregate(GroupingPolicyBatch *policy, DecompressBatchState *batch_state,
     const uint64 *filter = arrow_combine_validity(num_words,
                                                   policy->tmp_filter,
                                                   batch_state->vector_qual_result,
+                                                  agg_def->filter_result,
                                                   arg_validity_bitmap);
 
     /*
@@ -166,15 +167,16 @@
         /*
          * Scalar argument, or count(*). Have to also count the valid rows in
          * the batch.
-         */
-        const int n = arrow_num_valid(filter, batch_state->total_batch_rows);
-
-        /*
+         *
          * The batches that are fully filtered out by vectorized quals should
-         * have been skipped by the caller.
+         * have been skipped by the caller, but we also have to check for the
+         * case when no rows match the aggregate FILTER clause.
          */
-        Assert(n > 0);
-        agg_def->func.agg_scalar(agg_state, arg_datum, arg_isnull, n, agg_extra_mctx);
+        const int n = arrow_num_valid(filter, batch_state->total_batch_rows);
+        if (n > 0)
+        {
+            agg_def->func.agg_scalar(agg_state, arg_datum, arg_isnull, n, agg_extra_mctx);
+        }
     }
 }
 
@@ -185,7 +187,7 @@
 
     /*
     * Allocate the temporary filter array for computing the combined results of
-    * batch filter and column validity.
+    * batch filter, aggregate filter and column validity.
     */
     const size_t num_words = (batch_state->total_batch_rows + 63) / 64;
     if (num_words > policy->num_tmp_filter_words)

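The Assert-to-if change matters because a batch can now pass the vectorized quals while no row matches a particular aggregate's FILTER clause, so n can legitimately be zero. A standalone sketch of the counting step (illustrative only, not the extension's arrow_num_valid; assumes the GCC/Clang popcount builtin), including the tail-word masking needed when the row count is not a multiple of 64:

    #include <stdint.h>
    #include <stdio.h>

    static int
    count_valid(const uint64_t *bitmap, int total_rows)
    {
        int n = 0;
        const int num_words = (total_rows + 63) / 64; /* one word per 64 rows, rounded up */
        for (int i = 0; i < num_words; i++)
        {
            uint64_t word = bitmap[i];
            if (i == num_words - 1 && total_rows % 64 != 0)
            {
                /* Mask away bits past the end of the batch in the last word. */
                word &= (UINT64_C(1) << (total_rows % 64)) - 1;
            }
            n += __builtin_popcountll(word);
        }
        return n;
    }

    int main(void)
    {
        uint64_t bitmap[2] = { ~UINT64_C(0), UINT64_C(0xF) };
        printf("%d\n", count_valid(bitmap, 70)); /* 64 + 4 = 68 */

        uint64_t empty[1] = { 0 };
        printf("%d\n", count_valid(empty, 10)); /* 0: skip the scalar aggregation */
        return 0;
    }
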