Skip to content

Commit

Permalink
fix: fix incorrect agg spill in new agg hashtable (#14995)
Browse files Browse the repository at this point in the history
* fix agg spill

* fix agg spill

---------

Co-authored-by: jw <freejw@gmail.com>
  • Loading branch information
Freejww and jw authored Mar 19, 2024
1 parent 98652d1 commit 7171319
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 39 deletions.
13 changes: 13 additions & 0 deletions src/common/metrics/src/metrics/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,19 @@ pub fn metrics_inc_aggregate_partial_hashtable_allocated_bytes(c: u64) {
AGGREGATE_PARTIAL_HASHTABLE_ALLOCATED_BYTES.inc_by(c);
}

pub fn metrics_inc_group_by_partial_spill_count() {
let labels = &vec![("spill", "group_by_partial_spill".to_string())];
SPILL_COUNT.get_or_create(labels).inc();
}

pub fn metrics_inc_group_by_partial_spill_cell_count(c: u64) {
AGGREGATE_PARTIAL_SPILL_CELL_COUNT.inc_by(c);
}

pub fn metrics_inc_group_by_partial_hashtable_allocated_bytes(c: u64) {
AGGREGATE_PARTIAL_HASHTABLE_ALLOCATED_BYTES.inc_by(c);
}

pub fn metrics_inc_group_by_spill_write_count() {
let labels = &vec![("spill", "group_by_spill".to_string())];
SPILL_WRITE_COUNT.get_or_create(labels).inc();
Expand Down
21 changes: 21 additions & 0 deletions src/query/expression/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ mod payload_row;
mod probe_state;

use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use std::sync::Arc;

pub use aggregate_function::*;
Expand Down Expand Up @@ -108,4 +109,24 @@ impl HashTableConfig {

self
}

pub fn update_current_max_radix_bits(&self) {
loop {
let current_max_radix_bits = self.current_max_radix_bits.load(Ordering::SeqCst);
if current_max_radix_bits < self.max_radix_bits
&& self
.current_max_radix_bits
.compare_exchange(
current_max_radix_bits,
self.max_radix_bits,
Ordering::SeqCst,
Ordering::SeqCst,
)
.is_err()
{
continue;
}
break;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,16 @@ impl SerializedPayload {
&self,
group_types: Vec<DataType>,
aggrs: Vec<Arc<dyn AggregateFunction>>,
radix_bits: u64,
arena: Arc<Bump>,
) -> Result<PartitionedPayload> {
let rows_num = self.data_block.num_rows();
let radix_bits = self.max_partition_count.trailing_zeros() as u64;
let config = HashTableConfig::default().with_initial_radix_bits(radix_bits);
let mut state = ProbeState::default();
let agg_len = aggrs.len();
let group_len = group_types.len();
let mut hashtable = AggregateHashTable::new_directly(
group_types,
aggrs,
config,
rows_num,
Arc::new(Bump::new()),
);
let mut hashtable =
AggregateHashTable::new_directly(group_types, aggrs, config, rows_num, arena);

let agg_states = (0..agg_len)
.map(|i| {
Expand All @@ -96,6 +92,8 @@ impl SerializedPayload {
let _ =
hashtable.add_groups(&mut state, &group_columns, &[vec![]], &agg_states, rows_num)?;

hashtable.payload.mark_min_cardinality();

Ok(hashtable.payload)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ impl<Method: HashMethodBounds> TransformFinalAggregate<Method> {

fn transform_agg_hashtable(&mut self, meta: AggregateMeta<Method, usize>) -> Result<DataBlock> {
let mut agg_hashtable: Option<AggregateHashTable> = None;
if let AggregateMeta::Partitioned { bucket: _, data } = meta {
if let AggregateMeta::Partitioned { bucket, data } = meta {
for bucket_data in data {
match bucket_data {
AggregateMeta::AggregateHashTable(payload) => match agg_hashtable.as_mut() {
Expand All @@ -92,16 +92,24 @@ impl<Method: HashMethodBounds> TransformFinalAggregate<Method> {
},
AggregateMeta::Serialized(payload) => match agg_hashtable.as_mut() {
Some(ht) => {
debug_assert!(bucket == payload.bucket);
let arena = Arc::new(Bump::new());
let payload = payload.convert_to_partitioned_payload(
self.params.group_data_types.clone(),
self.params.aggregate_functions.clone(),
0,
arena,
)?;
ht.combine_payloads(&payload, &mut self.flush_state)?;
}
None => {
debug_assert!(bucket == payload.bucket);
let arena = Arc::new(Bump::new());
let payload = payload.convert_to_partitioned_payload(
self.params.group_data_types.clone(),
self.params.aggregate_functions.clone(),
0,
arena,
)?;
let capacity =
AggregateHashTable::get_capacity_for_count(payload.len());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use databend_common_expression::BlockMetaInfoDowncast;
use databend_common_expression::Column;
use databend_common_expression::DataBlock;
use databend_common_expression::HashTableConfig;
use databend_common_expression::PayloadFlushState;
use databend_common_expression::ProbeState;
use databend_common_functions::aggregates::StateAddr;
use databend_common_functions::aggregates::StateAddrs;
Expand All @@ -50,7 +51,6 @@ use crate::pipelines::processors::transforms::group_by::HashMethodBounds;
use crate::pipelines::processors::transforms::group_by::PartitionedHashMethod;
use crate::pipelines::processors::transforms::group_by::PolymorphicKeysHelper;
use crate::sessions::QueryContext;

#[allow(clippy::enum_variant_names)]
enum HashTable<Method: HashMethodBounds> {
MovedOut,
Expand Down Expand Up @@ -401,9 +401,21 @@ impl<Method: HashMethodBounds> AccumulatingTransform for TransformPartialAggrega

let group_types = v.payload.group_types.clone();
let aggrs = v.payload.aggrs.clone();
let config = v.config.clone();
v.config.update_current_max_radix_bits();
let config = v
.config
.clone()
.with_initial_radix_bits(v.config.max_radix_bits);

let mut state = PayloadFlushState::default();

// repartition to max for normalization
let partitioned_payload = v
.payload
.repartition(1 << config.max_radix_bits, &mut state);

let blocks = vec![DataBlock::empty_with_meta(
AggregateMeta::<Method, usize>::create_agg_spilling(v.payload),
AggregateMeta::<Method, usize>::create_agg_spilling(partitioned_payload),
)];

let arena = Arc::new(Bump::new());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl<Method: HashMethodBounds> TransformFinalGroupBy<Method> {

fn transform_agg_hashtable(&mut self, meta: AggregateMeta<Method, ()>) -> Result<DataBlock> {
let mut agg_hashtable: Option<AggregateHashTable> = None;
if let AggregateMeta::Partitioned { bucket: _, data } = meta {
if let AggregateMeta::Partitioned { bucket, data } = meta {
for bucket_data in data {
match bucket_data {
AggregateMeta::AggregateHashTable(payload) => match agg_hashtable.as_mut() {
Expand All @@ -85,16 +85,24 @@ impl<Method: HashMethodBounds> TransformFinalGroupBy<Method> {
},
AggregateMeta::Serialized(payload) => match agg_hashtable.as_mut() {
Some(ht) => {
debug_assert!(bucket == payload.bucket);
let arena = Arc::new(Bump::new());
let payload = payload.convert_to_partitioned_payload(
self.params.group_data_types.clone(),
self.params.aggregate_functions.clone(),
0,
arena,
)?;
ht.combine_payloads(&payload, &mut self.flush_state)?;
}
None => {
debug_assert!(bucket == payload.bucket);
let arena = Arc::new(Bump::new());
let payload = payload.convert_to_partitioned_payload(
self.params.group_data_types.clone(),
self.params.aggregate_functions.clone(),
0,
arena,
)?;
let capacity =
AggregateHashTable::get_capacity_for_count(payload.len());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ use databend_common_expression::AggregateHashTable;
use databend_common_expression::Column;
use databend_common_expression::DataBlock;
use databend_common_expression::HashTableConfig;
use databend_common_expression::PayloadFlushState;
use databend_common_expression::ProbeState;
use databend_common_hashtable::HashtableLike;
use databend_common_metrics::transform::*;
use databend_common_pipeline_core::processors::InputPort;
use databend_common_pipeline_core::processors::OutputPort;
use databend_common_pipeline_core::processors::Processor;
Expand Down Expand Up @@ -212,6 +214,15 @@ impl<Method: HashMethodBounds> AccumulatingTransform for TransformPartialGroupBy
{
if let HashTable::PartitionedHashTable(v) = std::mem::take(&mut self.hash_table)
{
// perf
{
metrics_inc_group_by_partial_spill_count();
metrics_inc_group_by_partial_spill_cell_count(1);
metrics_inc_group_by_partial_hashtable_allocated_bytes(
v.allocated_bytes() as u64,
);
}

let _dropper = v._dropper.clone();
let blocks = vec![DataBlock::empty_with_meta(
AggregateMeta::<Method, ()>::create_spilling(v),
Expand All @@ -234,11 +245,30 @@ impl<Method: HashMethodBounds> AccumulatingTransform for TransformPartialGroupBy
|| GLOBAL_MEM_STAT.get_memory_usage() as usize >= self.settings.max_memory_usage)
{
if let HashTable::AggregateHashTable(v) = std::mem::take(&mut self.hash_table) {
// perf
{
metrics_inc_group_by_partial_spill_count();
metrics_inc_group_by_partial_spill_cell_count(1);
metrics_inc_group_by_partial_hashtable_allocated_bytes(
v.allocated_bytes() as u64,
);
}

let group_types = v.payload.group_types.clone();
let aggrs = v.payload.aggrs.clone();
let config = v.config.clone();
v.config.update_current_max_radix_bits();
let config = v
.config
.clone()
.with_initial_radix_bits(v.config.max_radix_bits);
let mut state = PayloadFlushState::default();

// repartition to max for normalization
let partitioned_payload = v
.payload
.repartition(1 << config.max_radix_bits, &mut state);
let blocks = vec![DataBlock::empty_with_meta(
AggregateMeta::<Method, ()>::create_agg_spilling(v.payload),
AggregateMeta::<Method, ()>::create_agg_spilling(partitioned_payload),
)];

let arena = Arc::new(Bump::new());
Expand Down
Loading

0 comments on commit 7171319

Please sign in to comment.