Skip to content

Commit

Permalink
fix(query): Pass ci Test when enable `enable_experimental_aggregate_h…
Browse files Browse the repository at this point in the history
…ashtable `. (databendlabs#14544)
  • Loading branch information
sundy-li authored and yufan022 committed Jun 18, 2024
1 parent 21bb809 commit 3840f2e
Show file tree
Hide file tree
Showing 15 changed files with 244 additions and 108 deletions.
46 changes: 34 additions & 12 deletions src/query/expression/src/aggregate/aggregate_hashtable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,11 @@ impl AggregateHashTable {
state: &mut ProbeState,
group_columns: &[Column],
params: &[Vec<Column>],
agg_states: &[Column],
row_count: usize,
) -> Result<usize> {
if row_count <= BATCH_ADD_SIZE {
self.add_groups_inner(state, group_columns, params, row_count)
self.add_groups_inner(state, group_columns, params, agg_states, row_count)
} else {
let mut new_count = 0;
for start in (0..row_count).step_by(BATCH_ADD_SIZE) {
Expand All @@ -104,9 +105,18 @@ impl AggregateHashTable {
.iter()
.map(|c| c.iter().map(|x| x.slice(start..end)).collect())
.collect::<Vec<_>>();
let agg_states = agg_states
.iter()
.map(|c| c.slice(start..end))
.collect::<Vec<_>>();

new_count +=
self.add_groups_inner(state, &step_group_columns, &step_params, end - start)?;
new_count += self.add_groups_inner(
state,
&step_group_columns,
&step_params,
&agg_states,
end - start,
)?;
}
Ok(new_count)
}
Expand All @@ -118,6 +128,7 @@ impl AggregateHashTable {
state: &mut ProbeState,
group_columns: &[Column],
params: &[Vec<Column>],
agg_states: &[Column],
row_count: usize,
) -> Result<usize> {
state.row_count = row_count;
Expand All @@ -132,19 +143,30 @@ impl AggregateHashTable {
state.addresses[i].add(self.payload.state_offset) as _,
) as usize)
};
debug_assert_eq!(usize::from(state.state_places[i]) % 8, 0);
}

let state_places = &state.state_places.as_slice()[0..row_count];

for ((aggr, params), addr_offset) in self
.payload
.aggrs
.iter()
.zip(params.iter())
.zip(self.payload.state_addr_offsets.iter())
{
aggr.accumulate_keys(state_places, *addr_offset, params, row_count)?;
if agg_states.is_empty() {
for ((aggr, params), addr_offset) in self
.payload
.aggrs
.iter()
.zip(params.iter())
.zip(self.payload.state_addr_offsets.iter())
{
aggr.accumulate_keys(state_places, *addr_offset, params, row_count)?;
}
} else {
for ((aggr, agg_state), addr_offset) in self
.payload
.aggrs
.iter()
.zip(agg_states.iter())
.zip(self.payload.state_addr_offsets.iter())
{
aggr.batch_merge(state_places, *addr_offset, agg_state)?;
}
}
}

Expand Down
16 changes: 12 additions & 4 deletions src/query/expression/src/aggregate/group_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use ethnum::i256;
use ordered_float::OrderedFloat;

use crate::types::decimal::DecimalType;
use crate::types::AnyType;
use crate::types::ArgType;
use crate::types::BinaryType;
use crate::types::BitmapType;
Expand All @@ -27,9 +28,11 @@ use crate::types::NumberDataType;
use crate::types::NumberType;
use crate::types::StringType;
use crate::types::TimestampType;
use crate::types::ValueType;
use crate::types::VariantType;
use crate::with_number_mapped_type;
use crate::Column;
use crate::ScalarRef;

const NULL_HASH_VAL: u64 = 0xd1cefa08eb382d69;

Expand Down Expand Up @@ -94,14 +97,12 @@ pub fn combine_group_hash_column<const IS_FIRST: bool>(c: &Column, values: &mut
}
}
}
DataType::Tuple(_) => todo!(),
DataType::Array(_) => todo!(),
DataType::Map(_) => todo!(),
DataType::Generic(_) => unreachable!(),
_ => combine_group_hash_type_column::<IS_FIRST, AnyType>(c, values),
}
}

fn combine_group_hash_type_column<const IS_FIRST: bool, T: ArgType>(
fn combine_group_hash_type_column<const IS_FIRST: bool, T: ValueType>(
col: &Column,
values: &mut [u64],
) where
Expand Down Expand Up @@ -244,3 +245,10 @@ impl AggHash for OrderedFloat<f64> {
}
}
}

impl AggHash for ScalarRef<'_> {
#[inline(always)]
fn agg_hash(&self) -> u64 {
self.to_string().as_bytes().agg_hash()
}
}
15 changes: 7 additions & 8 deletions src/query/expression/src/aggregate/payload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,21 +224,20 @@ impl Payload {
for col in group_columns {
if let Column::Nullable(c) = col {
let bitmap = &c.validity;
if bitmap.unset_bits() == 0 {
if bitmap.unset_bits() == 0 || bitmap.unset_bits() == bitmap.len() {
let val: u8 = if bitmap.unset_bits() == 0 { 1 } else { 0 };
// faster path
for idx in select_vector.iter().take(new_group_rows).copied() {
unsafe {
let dst = address[idx].add(write_offset);
store(1, dst as *mut u8);
store(val, dst as *mut u8);
}
}
} else if bitmap.unset_bits() != bitmap.len() {
} else {
for idx in select_vector.iter().take(new_group_rows).copied() {
if bitmap.get_bit(idx) {
unsafe {
let dst = address[idx].add(write_offset);
store(1, dst as *mut u8);
}
unsafe {
let dst = address[idx].add(write_offset);
store(bitmap.get_bit(idx) as u8, dst as *mut u8);
}
}
}
Expand Down
61 changes: 53 additions & 8 deletions src/query/expression/src/aggregate/payload_flush.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,31 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use databend_common_io::prelude::bincode_deserialize_from_slice;
use ethnum::i256;

use super::partitioned_payload::PartitionedPayload;
use super::payload::Payload;
use super::probe_state::ProbeState;
use crate::types::binary::BinaryColumn;
use crate::types::binary::BinaryColumnBuilder;
use crate::types::decimal::Decimal;
use crate::types::decimal::DecimalType;
use crate::types::nullable::NullableColumn;
use crate::types::string::StringColumn;
use crate::types::ArgType;
use crate::types::BooleanType;
use crate::types::DataType;
use crate::types::DateType;
use crate::types::DecimalSize;
use crate::types::NumberDataType;
use crate::types::NumberType;
use crate::types::TimestampType;
use crate::types::ValueType;
use crate::with_number_mapped_type;
use crate::Column;
use crate::ColumnBuilder;
use crate::Scalar;
use crate::StateAddr;
use crate::BATCH_SIZE;

Expand Down Expand Up @@ -160,11 +166,11 @@ impl Payload {
self.flush_type_column::<NumberType<NUM_TYPE>>(col_offset, state),
}),
DataType::Decimal(v) => match v {
crate::types::DecimalDataType::Decimal128(_) => {
self.flush_type_column::<DecimalType<i128>>(col_offset, state)
crate::types::DecimalDataType::Decimal128(s) => {
self.flush_decimal_column::<i128>(col_offset, state, s)
}
crate::types::DecimalDataType::Decimal256(_) => {
self.flush_type_column::<DecimalType<i256>>(col_offset, state)
crate::types::DecimalDataType::Decimal256(s) => {
self.flush_decimal_column::<i256>(col_offset, state, s)
}
},
DataType::Timestamp => self.flush_type_column::<TimestampType>(col_offset, state),
Expand All @@ -174,10 +180,7 @@ impl Payload {
DataType::Bitmap => Column::Bitmap(self.flush_binary_column(col_offset, state)),
DataType::Variant => Column::Variant(self.flush_binary_column(col_offset, state)),
DataType::Nullable(_) => unreachable!(),
DataType::Array(_) => todo!(),
DataType::Map(_) => todo!(),
DataType::Tuple(_) => todo!(),
DataType::Generic(_) => unreachable!(),
other => self.flush_generic_column(&other, col_offset, state),
};

let validity_offset = self.validity_offsets[col_index];
Expand Down Expand Up @@ -207,6 +210,22 @@ impl Payload {
T::upcast_column(col)
}

fn flush_decimal_column<Num: Decimal>(
&self,
col_offset: usize,
state: &mut PayloadFlushState,
decimal_size: DecimalSize,
) -> Column {
let len = state.probe_state.row_count;
let iter = (0..len).map(|idx| unsafe {
core::ptr::read::<<DecimalType<Num> as ValueType>::Scalar>(
state.addresses[idx].add(col_offset) as _,
)
});
let col = DecimalType::<Num>::column_from_iter(iter, &[]);
Num::upcast_column(col, decimal_size)
}

fn flush_binary_column(
&self,
col_offset: usize,
Expand Down Expand Up @@ -239,4 +258,30 @@ impl Payload {
) -> StringColumn {
unsafe { StringColumn::from_binary_unchecked(self.flush_binary_column(col_offset, state)) }
}

fn flush_generic_column(
&self,
data_type: &DataType,
col_offset: usize,
state: &mut PayloadFlushState,
) -> Column {
let len = state.probe_state.row_count;
let mut builder = ColumnBuilder::with_capacity(data_type, len);

unsafe {
for idx in 0..len {
let str_len =
core::ptr::read::<u32>(state.addresses[idx].add(col_offset) as _) as usize;
let data_address =
core::ptr::read::<u64>(state.addresses[idx].add(col_offset + 4) as _) as usize
as *const u8;

let scalar = std::slice::from_raw_parts(data_address, str_len);
let scalar: Scalar = bincode_deserialize_from_slice(scalar).unwrap();

builder.push(scalar.as_ref());
}
}
builder.build()
}
}
Loading

0 comments on commit 3840f2e

Please sign in to comment.