feat(query): support spill for new agg hashtable #14905

Merged
merged 11 commits on Mar 13, 2024
Changes from 1 commit
7 changes: 4 additions & 3 deletions src/query/expression/src/aggregate/aggregate_hashtable.rs
@@ -26,6 +26,7 @@ use super::probe_state::ProbeState;
use crate::aggregate::payload_row::row_match_columns;
use crate::group_hash_columns;
use crate::new_sel;
use crate::read;
use crate::types::DataType;
use crate::AggregateFunctionRef;
use crate::Column;
@@ -159,8 +160,8 @@ impl AggregateHashTable {
if !self.payload.aggrs.is_empty() {
for i in 0..row_count {
state.state_places[i] = unsafe {
StateAddr::new(core::ptr::read::<u64>(
state.addresses[i].add(self.payload.state_offset) as _,
StateAddr::new(read::<u64>(
state.addresses[i].add(self.payload.state_offset) as _
) as usize)
};
}
@@ -365,7 +366,7 @@ impl AggregateHashTable {
if !self.payload.aggrs.is_empty() {
for i in 0..row_count {
flush_state.probe_state.state_places[i] = unsafe {
StateAddr::new(core::ptr::read::<u64>(
StateAddr::new(read::<u64>(
flush_state.probe_state.addresses[i].add(self.payload.state_offset)
as _,
) as usize)
4 changes: 2 additions & 2 deletions src/query/expression/src/aggregate/partitioned_payload.rs
@@ -20,6 +20,7 @@ use itertools::Itertools;

use super::payload::Payload;
use super::probe_state::ProbeState;
use crate::read;
use crate::types::DataType;
use crate::AggregateFunctionRef;
use crate::Column;
@@ -217,8 +218,7 @@ impl PartitionedPayload {
for idx in 0..rows {
state.addresses[idx] = other.data_ptr(page, idx + state.flush_page_row);

let hash =
unsafe { core::ptr::read::<u64>(state.addresses[idx].add(self.hash_offset) as _) };
let hash = unsafe { read::<u64>(state.addresses[idx].add(self.hash_offset) as _) };

let partition_idx = ((hash & self.mask_v) >> self.shift_v) as usize;

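The hunk above routes each row during repartitioning by reading the 64-bit group hash back out of its slot in the row and selecting a contiguous run of its bits with `mask_v` and `shift_v`; a later hunk in payload.rs does the same job with `hash % mods` and a strength-reduced modulus. A minimal sketch of the mask/shift variant for a power-of-two partition count; `HashPartitioner` and the particular bit range are assumptions made for this example, not types from the crate:

/// Illustrative sketch: deriving a partition index from a stored group hash
/// when the partition count is a power of two.
struct HashPartitioner {
    mask_v: u64,  // covers log2(partition_count) consecutive hash bits
    shift_v: u64, // moves the masked bits down to 0..partition_count
}

impl HashPartitioner {
    fn new(partition_count: u64, shift: u64) -> Self {
        assert!(partition_count.is_power_of_two());
        Self {
            mask_v: (partition_count - 1) << shift,
            shift_v: shift,
        }
    }

    fn partition_of(&self, hash: u64) -> usize {
        ((hash & self.mask_v) >> self.shift_v) as usize
    }
}

fn main() {
    // 16 partitions keyed on hash bits 8..12.
    let p = HashPartitioner::new(16, 8);
    let hash: u64 = 0x0A5F;
    assert_eq!(p.partition_of(hash), 0xA); // bits 8..12 of 0x0A5F are 0b1010
}

Since the partition index is a pure function of the hash already stored in each row, rows can be re-bucketed during a spill or merge without re-reading the group columns.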
14 changes: 7 additions & 7 deletions src/query/expression/src/aggregate/payload.rs
@@ -24,6 +24,7 @@ use strength_reduce::StrengthReducedU64;
use super::payload_row::rowformat_size;
use super::payload_row::serialize_column_to_rowformat;
use crate::get_layout_offsets;
use crate::read;
use crate::store;
use crate::types::DataType;
use crate::AggregateFunctionRef;
@@ -237,14 +238,14 @@ impl Payload {
for idx in select_vector.iter().take(new_group_rows).copied() {
unsafe {
let dst = address[idx].add(write_offset);
store(val, dst as *mut u8);
store::<u8>(&val, dst as *mut u8);
}
}
} else {
for idx in select_vector.iter().take(new_group_rows).copied() {
unsafe {
let dst = address[idx].add(write_offset);
store(bitmap.get_bit(idx) as u8, dst as *mut u8);
store::<u8>(&(bitmap.get_bit(idx) as u8), dst as *mut u8);
}
}
}
Expand Down Expand Up @@ -275,7 +276,7 @@ impl Payload {
for idx in select_vector.iter().take(new_group_rows).copied() {
unsafe {
let dst = address[idx].add(write_offset);
store(group_hashes[idx], dst as *mut u8);
store::<u64>(&group_hashes[idx], dst as *mut u8);
}
}

@@ -287,7 +288,7 @@ impl Payload {
let place = self.arena.alloc_layout(layout);
unsafe {
let dst = address[idx].add(write_offset);
store(place.as_ptr() as u64, dst as *mut u8);
store::<u64>(&(place.as_ptr() as u64), dst as *mut u8);
}

let place = StateAddr::from(place);
@@ -365,8 +366,7 @@ impl Payload {
for idx in 0..rows {
state.addresses[idx] = self.data_ptr(page, idx + state.flush_page_row);

let hash =
unsafe { core::ptr::read::<u64>(state.addresses[idx].add(self.hash_offset) as _) };
let hash = unsafe { read::<u64>(state.addresses[idx].add(self.hash_offset) as _) };

let partition_idx = (hash % mods) as usize;

@@ -403,7 +403,7 @@ impl Drop for Payload {
for page in self.pages.iter() {
for row in 0..page.rows {
unsafe {
let state_place = StateAddr::new(core::ptr::read::<u64>(
let state_place = StateAddr::new(read::<u64>(
self.data_ptr(page, row).add(self.state_offset) as _,
)
as usize);
34 changes: 15 additions & 19 deletions src/query/expression/src/aggregate/payload_flush.rs
@@ -19,6 +19,7 @@ use ethnum::i256;
use super::partitioned_payload::PartitionedPayload;
use super::payload::Payload;
use super::probe_state::ProbeState;
use crate::read;
use crate::types::binary::BinaryColumn;
use crate::types::binary::BinaryColumnBuilder;
use crate::types::decimal::Decimal;
@@ -196,13 +197,13 @@ impl Payload {
for idx in 0..rows {
state.addresses[idx] = self.data_ptr(page, idx + state.flush_page_row);
state.probe_state.group_hashes[idx] =
unsafe { core::ptr::read::<u64>(state.addresses[idx].add(self.hash_offset) as _) };
unsafe { read::<u64>(state.addresses[idx].add(self.hash_offset) as _) };

if !self.aggrs.is_empty() {
state.state_places[idx] = unsafe {
StateAddr::new(core::ptr::read::<u64>(
state.addresses[idx].add(self.state_offset) as _,
) as usize)
StateAddr::new(
read::<u64>(state.addresses[idx].add(self.state_offset) as _) as usize,
)
};
}
}
@@ -268,9 +269,8 @@ impl Payload {
state: &mut PayloadFlushState,
) -> Column {
let len = state.probe_state.row_count;
let iter = (0..len).map(|idx| unsafe {
core::ptr::read::<T::Scalar>(state.addresses[idx].add(col_offset) as _)
});
let iter = (0..len)
.map(|idx| unsafe { read::<T::Scalar>(state.addresses[idx].add(col_offset) as _) });
let col = T::column_from_iter(iter, &[]);
T::upcast_column(col)
}
@@ -283,8 +283,8 @@
) -> Column {
let len = state.probe_state.row_count;
let iter = (0..len).map(|idx| unsafe {
core::ptr::read::<<DecimalType<Num> as ValueType>::Scalar>(
state.addresses[idx].add(col_offset) as _,
read::<<DecimalType<Num> as ValueType>::Scalar>(
state.addresses[idx].add(col_offset) as _
)
});
let col = DecimalType::<Num>::column_from_iter(iter, &[]);
@@ -301,11 +301,9 @@

unsafe {
for idx in 0..len {
let str_len =
core::ptr::read::<u32>(state.addresses[idx].add(col_offset) as _) as usize;
let data_address =
core::ptr::read::<u64>(state.addresses[idx].add(col_offset + 4) as _) as usize
as *const u8;
let str_len = read::<u32>(state.addresses[idx].add(col_offset) as _) as usize;
let data_address = read::<u64>(state.addresses[idx].add(col_offset + 4) as _)
as usize as *const u8;

let scalar = std::slice::from_raw_parts(data_address, str_len);

@@ -335,11 +333,9 @@ impl Payload {

unsafe {
for idx in 0..len {
let str_len =
core::ptr::read::<u32>(state.addresses[idx].add(col_offset) as _) as usize;
let data_address =
core::ptr::read::<u64>(state.addresses[idx].add(col_offset + 4) as _) as usize
as *const u8;
let str_len = read::<u32>(state.addresses[idx].add(col_offset) as _) as usize;
let data_address = read::<u64>(state.addresses[idx].add(col_offset + 4) as _)
as usize as *const u8;

let scalar = std::slice::from_raw_parts(data_address, str_len);
let scalar: Scalar = bincode_deserialize_from_slice(scalar).unwrap();
56 changes: 30 additions & 26 deletions src/query/expression/src/aggregate/payload_row.rs
@@ -18,6 +18,7 @@ use databend_common_io::prelude::bincode_deserialize_from_slice;
use databend_common_io::prelude::bincode_serialize_into_buf;
use ethnum::i256;

use crate::read;
use crate::store;
use crate::types::binary::BinaryColumn;
use crate::types::decimal::DecimalColumn;
@@ -76,15 +77,18 @@ pub unsafe fn serialize_column_to_rowformat(
Column::Number(v) => with_number_mapped_type!(|NUM_TYPE| match v {
NumberColumn::NUM_TYPE(buffer) => {
for index in select_vector.iter().take(rows).copied() {
store(buffer[index], address[index].add(offset) as *mut u8);
store::<NUM_TYPE>(&buffer[index], address[index].add(offset) as *mut u8);
}
}
}),
Column::Decimal(v) => {
with_decimal_mapped_type!(|DECIMAL_TYPE| match v {
DecimalColumn::DECIMAL_TYPE(buffer, _) => {
for index in select_vector.iter().take(rows).copied() {
store(buffer[index], address[index].add(offset) as *mut u8);
store::<DECIMAL_TYPE>(
&buffer[index],
address[index].add(offset) as *mut u8,
);
}
}
})
@@ -94,12 +98,12 @@
let val: u8 = if v.unset_bits() == 0 { 1 } else { 0 };
// faster path
for index in select_vector.iter().take(rows).copied() {
store(val, address[index].add(offset) as *mut u8);
store::<u8>(&val, address[index].add(offset) as *mut u8);
}
} else {
for index in select_vector.iter().take(rows).copied() {
store(
v.get_bit(index) as u8,
store::<u8>(
&(v.get_bit(index) as u8),
address[index].add(offset) as *mut u8,
);
}
@@ -108,31 +112,31 @@
Column::Binary(v) | Column::Bitmap(v) | Column::Variant(v) | Column::Geometry(v) => {
for index in select_vector.iter().take(rows).copied() {
let data = arena.alloc_slice_copy(v.index_unchecked(index));
store(data.len() as u32, address[index].add(offset) as *mut u8);
store(
data.as_ptr() as u64,
store::<u32>(&(data.len() as u32), address[index].add(offset) as *mut u8);
store::<u64>(
&(data.as_ptr() as u64),
address[index].add(offset + 4) as *mut u8,
);
}
}
Column::String(v) => {
for index in select_vector.iter().take(rows).copied() {
let data = arena.alloc_str(v.index_unchecked(index));
store(data.len() as u32, address[index].add(offset) as *mut u8);
store(
data.as_ptr() as u64,
store::<u32>(&(data.len() as u32), address[index].add(offset) as *mut u8);
store::<u64>(
&(data.as_ptr() as u64),
address[index].add(offset + 4) as *mut u8,
);
}
}
Column::Timestamp(buffer) => {
for index in select_vector.iter().take(rows).copied() {
store(buffer[index], address[index].add(offset) as *mut u8);
store::<i64>(&buffer[index], address[index].add(offset) as *mut u8);
}
}
Column::Date(buffer) => {
for index in select_vector.iter().take(rows).copied() {
store(buffer[index], address[index].add(offset) as *mut u8);
store::<i32>(&buffer[index], address[index].add(offset) as *mut u8);
}
}
Column::Nullable(c) => serialize_column_to_rowformat(
@@ -153,9 +157,9 @@
bincode_serialize_into_buf(scratch, &s).unwrap();

let data = arena.alloc_slice_copy(scratch);
store(data.len() as u32, address[index].add(offset) as *mut u8);
store(
data.as_ptr() as u64,
store::<u32>(&(data.len() as u32), address[index].add(offset) as *mut u8);
store::<u64>(
&(data.as_ptr() as u64),
address[index].add(offset + 4) as *mut u8,
);
}
@@ -362,19 +366,19 @@ unsafe fn row_match_binary_column(
for idx in select_vector[..*count].iter() {
let idx = *idx;
let validity_address = address[idx].add(validity_offset);
let is_set2 = core::ptr::read::<u8>(validity_address as _) != 0;
let is_set2 = read::<u8>(validity_address as _) != 0;
let is_set = is_all_set || validity.get_bit_unchecked(idx);

if is_set && is_set2 {
let len_address = address[idx].add(col_offset);
let address = address[idx].add(col_offset + 4);
let len = core::ptr::read::<u32>(len_address as _) as usize;
let len = read::<u32>(len_address as _) as usize;

let value = BinaryType::index_column_unchecked(col, idx);
if len != value.len() {
equal = false;
} else {
let data_address = core::ptr::read::<u64>(address as _) as usize as *const u8;
let data_address = read::<u64>(address as _) as usize as *const u8;
let scalar = std::slice::from_raw_parts(data_address, len);
equal = databend_common_hashtable::fast_memcmp(scalar, value);
}
@@ -396,13 +400,13 @@
let len_address = address[idx].add(col_offset);
let address = address[idx].add(col_offset + 4);

let len = core::ptr::read::<u32>(len_address as _) as usize;
let len = read::<u32>(len_address as _) as usize;

let value = BinaryType::index_column_unchecked(col, idx);
if len != value.len() {
equal = false;
} else {
let data_address = core::ptr::read::<u64>(address as _) as usize as *const u8;
let data_address = read::<u64>(address as _) as usize as *const u8;
let scalar = std::slice::from_raw_parts(data_address, len);

equal = databend_common_hashtable::fast_memcmp(scalar, value);
@@ -444,11 +448,11 @@ unsafe fn row_match_column_type<T: ArgType>(
for idx in select_vector[..*count].iter() {
let idx = *idx;
let validity_address = address[idx].add(validity_offset);
let is_set2 = core::ptr::read::<u8>(validity_address as _) != 0;
let is_set2 = read::<u8>(validity_address as _) != 0;
let is_set = is_all_set || validity.get_bit_unchecked(idx);
if is_set && is_set2 {
let address = address[idx].add(col_offset);
let scalar = core::ptr::read::<<T as ValueType>::Scalar>(address as _);
let scalar = read::<<T as ValueType>::Scalar>(address as _);
let value = T::index_column_unchecked(&col, idx);
let value = T::to_owned_scalar(value);

@@ -470,7 +474,7 @@
let idx = *idx;
let value = T::index_column_unchecked(&col, idx);
let address = address[idx].add(col_offset);
let scalar = core::ptr::read::<<T as ValueType>::Scalar>(address as _);
let scalar = read::<<T as ValueType>::Scalar>(address as _);
let value = T::to_owned_scalar(value);

if scalar.eq(&value) {
@@ -502,12 +506,12 @@ unsafe fn row_match_generic_column(
for idx in select_vector[..*count].iter() {
let idx = *idx;
let len_address = address[idx].add(col_offset);
let len = core::ptr::read::<u32>(len_address as _) as usize;
let len = read::<u32>(len_address as _) as usize;

let address = address[idx].add(col_offset + 4);

let value = AnyType::index_column_unchecked(col, idx);
let data_address = core::ptr::read::<u64>(address as _) as usize as *const u8;
let data_address = read::<u64>(address as _) as usize as *const u8;

let scalar = std::slice::from_raw_parts(data_address, len);
let scalar: Scalar = bincode_deserialize_from_slice(scalar).unwrap();
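Across payload_row.rs, fixed-width values are copied into their row slots with the unaligned `store`, while variable-length values (Binary, String, and the bincode-serialized fallback) occupy a slot holding a 4-byte length followed by an 8-byte pointer into the arena; the row_match functions read that pair back and compare the pointed-to bytes. A self-contained sketch of this slot encoding, with hypothetical `write_var_slot`/`read_var_slot` helpers and a plain `Vec<u8>` standing in for the bump arena:

use std::mem::size_of;

/// Write a variable-length value into a 12-byte row slot as (u32 length, u64 pointer).
/// `data` must stay alive for as long as the slot may be read back.
unsafe fn write_var_slot(slot: *mut u8, data: &[u8]) {
    let len = data.len() as u32;
    core::ptr::copy_nonoverlapping(&len as *const u32 as *const u8, slot, size_of::<u32>());
    let addr = data.as_ptr() as u64;
    core::ptr::copy_nonoverlapping(&addr as *const u64 as *const u8, slot.add(4), size_of::<u64>());
}

/// Read the (length, pointer) pair back and reconstruct the byte slice.
unsafe fn read_var_slot<'a>(slot: *const u8) -> &'a [u8] {
    let len = core::ptr::read_unaligned::<u32>(slot as _) as usize;
    let ptr = core::ptr::read_unaligned::<u64>(slot.add(4) as _) as usize as *const u8;
    std::slice::from_raw_parts(ptr, len)
}

fn main() {
    let arena_backed = b"group key".to_vec(); // stand-in for arena.alloc_slice_copy
    let mut row = vec![0u8; 12]; // one slot: 4-byte length + 8-byte pointer
    unsafe {
        write_var_slot(row.as_mut_ptr(), &arena_backed);
        assert_eq!(read_var_slot(row.as_ptr()), arena_backed.as_slice());
    }
}

The real code copies the bytes into a bump arena owned by the payload so the stored pointer stays valid for the table's lifetime; in the sketch the `Vec` simply outlives the read.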
13 changes: 10 additions & 3 deletions src/query/expression/src/kernels/utils.rs
@@ -66,10 +66,17 @@ pub unsafe fn set_vec_len_by_ptr<T>(vec: &mut Vec<T>, ptr: *const T) {
}

/// # Safety
/// # As: core::ptr::write
/// # As: core::ptr::copy_nonoverlapping
#[inline]
pub unsafe fn store<T: Copy>(val: T, ptr: *mut u8) {
core::ptr::write(ptr as _, val)
pub unsafe fn store<T: Copy>(val: &T, ptr: *mut u8) {
core::ptr::copy_nonoverlapping(val as *const T as *const u8, ptr, std::mem::size_of::<T>());
}

/// # Safety
/// # As: core::ptr::read_unaligned
#[inline]
pub unsafe fn read<T>(ptr: *const u8) -> T {
core::ptr::read_unaligned::<T>(ptr as _)
}

/// Iterates over an arbitrarily aligned byte buffer
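The utils.rs change is what the rest of the diff leans on: `store` now takes a reference and copies raw bytes with `copy_nonoverlapping`, and the new `read` wraps `read_unaligned`, so neither call requires its address to be aligned for `T`. Payload rows pack validity bytes, column slots, the group hash, and the state pointer back to back, so a `u64` slot is not guaranteed to sit on an 8-byte boundary. A minimal standalone copy of the two helpers with a round trip at a deliberately misaligned offset (the 9-byte buffer is illustrative, not the crate's row layout):

/// Unaligned store: copy the bytes of `val` to `ptr`; `ptr` need not be aligned for `T`.
#[inline]
unsafe fn store<T: Copy>(val: &T, ptr: *mut u8) {
    core::ptr::copy_nonoverlapping(val as *const T as *const u8, ptr, std::mem::size_of::<T>());
}

/// Unaligned read: read a `T` from `ptr` without any alignment requirement.
#[inline]
unsafe fn read<T>(ptr: *const u8) -> T {
    core::ptr::read_unaligned::<T>(ptr as _)
}

fn main() {
    // A packed row fragment: one validity byte immediately followed by a u64 hash,
    // so the hash lives at offset 1 and is misaligned for u64.
    let mut row = vec![0u8; 9];
    let hash: u64 = 0xDEAD_BEEF_CAFE_F00D;
    unsafe {
        store::<u8>(&1, row.as_mut_ptr());
        store::<u64>(&hash, row.as_mut_ptr().add(1));
        assert_eq!(read::<u8>(row.as_ptr()), 1);
        assert_eq!(read::<u64>(row.as_ptr().add(1)), hash);
    }
}

On x86-64 both helpers compile to ordinary moves; on targets that fault on misaligned access they avoid the undefined behaviour that a plain `core::ptr::read`/`write` has whenever such an offset is misaligned.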