diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs
index ce56ca4f7dfd..e595fe4aa245 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs
@@ -40,8 +40,9 @@ pub(crate) use single_group_by::primitive::HashValue;
 use crate::aggregates::{
     group_values::single_group_by::{
-        bytes::GroupValuesByes, bytes_view::GroupValuesBytesView,
-        primitive::GroupValuesPrimitive,
+        bytes::GroupValuesByes,
+        bytes_view::GroupValuesBytesView,
+        primitive::{GroupValuesLargePrimitive, GroupValuesPrimitive},
     },
     order::GroupOrdering,
 };
@@ -134,6 +135,12 @@ pub(crate) fn new_group_values(
     if schema.fields.len() == 1 {
         let d = schema.fields[0].data_type();
 
+        macro_rules! large_downcast_helper {
+            ($t:ty, $d:ident) => {
+                return Ok(Box::new(GroupValuesLargePrimitive::<$t>::new($d.clone())))
+            };
+        }
+
         macro_rules! downcast_helper {
             ($t:ty, $d:ident) => {
                 return Ok(Box::new(GroupValuesPrimitive::<$t>::new($d.clone())))
@@ -169,7 +176,7 @@ pub(crate) fn new_group_values(
             TimeUnit::Nanosecond => downcast_helper!(TimestampNanosecondType, d),
         },
         DataType::Decimal128(_, _) => {
-            downcast_helper!(Decimal128Type, d);
+            large_downcast_helper!(Decimal128Type, d);
         }
         DataType::Utf8 => {
             return Ok(Box::new(GroupValuesByes::<i32>::new(OutputType::Utf8)));
diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/large_primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/large_primitive.rs
new file mode 100644
index 000000000000..6e8cc23ee240
--- /dev/null
+++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/large_primitive.rs
@@ -0,0 +1,139 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::aggregates::group_values::single_group_by::primitive::{
+    emit_internal, HashValue,
+};
+use crate::aggregates::group_values::GroupValues;
+use ahash::RandomState;
+use arrow::array::{
+    cast::AsArray, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray,
+};
+use arrow::datatypes::DataType;
+use arrow::record_batch::RecordBatch;
+use datafusion_common::Result;
+use datafusion_execution::memory_pool::proxy::VecAllocExt;
+use datafusion_expr::EmitTo;
+use hashbrown::hash_table::HashTable;
+use std::mem::size_of;
+
+/// A [`GroupValues`] storing a single column of large primitive values (bits > 64)
+///
+/// This specialization is significantly faster than using the more general
+/// purpose `Row`s format
+pub struct GroupValuesLargePrimitive<T: ArrowPrimitiveType> {
+    /// The data type of the output array
+    data_type: DataType,
+    /// Stores the `(group_index, hash)` based on the hash of its value
+    ///
+    /// We also store the `hash` to reduce the cost of rehashing, which is
+    /// significant for high-cardinality group bys.
+    /// More details can be found at:
+    ///
+    ///
+    map: HashTable<(usize, u64)>,
+    /// The group index of the null value if any
+    null_group: Option<usize>,
+    /// The values for each group index
+    values: Vec<T::Native>,
+    /// The random state used to generate hashes
+    random_state: RandomState,
+}
+
+impl<T: ArrowPrimitiveType> GroupValuesLargePrimitive<T> {
+    pub fn new(data_type: DataType) -> Self {
+        assert!(PrimitiveArray::<T>::is_compatible(&data_type));
+        Self {
+            data_type,
+            map: HashTable::with_capacity(128),
+            values: Vec::with_capacity(128),
+            null_group: None,
+            random_state: Default::default(),
+        }
+    }
+}
+
+impl<T: ArrowPrimitiveType> GroupValues for GroupValuesLargePrimitive<T>
+where
+    T::Native: HashValue,
+{
+    fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
+        assert_eq!(cols.len(), 1);
+        groups.clear();
+
+        for v in cols[0].as_primitive::<T>() {
+            let group_id = match v {
+                None => *self.null_group.get_or_insert_with(|| {
+                    let group_id = self.values.len();
+                    self.values.push(Default::default());
+                    group_id
+                }),
+                Some(key) => {
+                    let state = &self.random_state;
+                    let hash = key.hash(state);
+                    let insert = self.map.entry(
+                        hash,
+                        |&(g, _)| unsafe { self.values.get_unchecked(g).is_eq(key) },
+                        |&(_, h)| h,
+                    );
+
+                    match insert {
+                        hashbrown::hash_table::Entry::Occupied(o) => o.get().0,
+                        hashbrown::hash_table::Entry::Vacant(v) => {
+                            let g = self.values.len();
+                            v.insert((g, hash));
+                            self.values.push(key);
+                            g
+                        }
+                    }
+                }
+            };
+            groups.push(group_id)
+        }
+        Ok(())
+    }
+
+    fn size(&self) -> usize {
+        self.map.capacity() * size_of::<(usize, u64)>() + self.values.allocated_size()
+    }
+
+    fn is_empty(&self) -> bool {
+        self.values.is_empty()
+    }
+
+    fn len(&self) -> usize {
+        self.values.len()
+    }
+
+    fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
+        emit_internal::<T, _>(
+            emit_to,
+            &mut self.values,
+            &mut self.null_group,
+            &mut self.map,
+            self.data_type.clone(),
+        )
+    }
+
+    fn clear_shrink(&mut self, batch: &RecordBatch) {
+        let count = batch.num_rows();
+        self.values.clear();
+        self.values.shrink_to(count);
+        self.map.clear();
+        self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared
+    }
+}
diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/mod.rs
similarity index 57%
rename from datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs
rename to datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/mod.rs
index 279caa50b0a6..7fa4638d12e2 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/mod.rs
@@ -32,6 +32,9 @@ use hashbrown::hash_table::HashTable;
 use std::mem::size_of;
 use std::sync::Arc;
 
+mod large_primitive;
+pub use large_primitive::GroupValuesLargePrimitive;
+
 /// A trait to allow hashing of floating point numbers
 pub(crate) trait HashValue {
     fn hash(&self, state: &RandomState) -> u64;
@@ -74,27 +77,29 @@ macro_rules! hash_float {
 hash_float!(f16, f32, f64);
 
-/// A [`GroupValues`] storing a single column of primitive values
+/// A [`GroupValues`] storing a single column of normal primitive values (bits <= 64)
 ///
 /// This specialization is significantly faster than using the more general
 /// purpose `Row`s format
 pub struct GroupValuesPrimitive<T: ArrowPrimitiveType> {
     /// The data type of the output array
     data_type: DataType,
-    /// Stores the `(group_index, hash)` based on the hash of its value
+    /// Stores the `(group_index, group_value)`
     ///
-    /// We also store `hash` is for reducing cost of rehashing. Such cost
-    /// is obvious in high cardinality group by situation.
+    /// We directly store a copy of `group_value`, which makes both
+    /// rehashing and probing efficient.
     /// More details can see:
     ///
     ///
-    map: HashTable<(usize, u64)>,
+    map: HashTable<(usize, T::Native)>,
     /// The group index of the null value if any
     null_group: Option<usize>,
     /// The values for each group index
     values: Vec<T::Native>,
     /// The random state used to generate hashes
     random_state: RandomState,
+
+    append_row_indices: Vec<u32>,
 }
 
 impl<T: ArrowPrimitiveType> GroupValuesPrimitive<T> {
@@ -106,6 +111,7 @@ impl<T: ArrowPrimitiveType> GroupValuesPrimitive<T> {
             values: Vec::with_capacity(128),
             null_group: None,
             random_state: Default::default(),
+            append_row_indices: Vec::new(),
         }
     }
 }
@@ -116,13 +122,18 @@ where
 {
     fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> Result<()> {
         assert_eq!(cols.len(), 1);
+        let col = cols[0].as_primitive::<T>();
+
         groups.clear();
+        self.append_row_indices.clear();
 
-        for v in cols[0].as_primitive::<T>() {
+        let mut num_total_groups = self.values.len();
+        for (row_index, v) in col.iter().enumerate() {
             let group_id = match v {
                 None => *self.null_group.get_or_insert_with(|| {
-                    let group_id = self.values.len();
-                    self.values.push(Default::default());
+                    let group_id = num_total_groups;
+                    self.append_row_indices.push(row_index as u32);
+                    num_total_groups += 1;
                     group_id
                 }),
                 Some(key) => {
@@ -130,16 +141,17 @@ where
                     let state = &self.random_state;
                     let hash = key.hash(state);
                     let insert = self.map.entry(
                         hash,
-                        |&(g, _)| unsafe { self.values.get_unchecked(g).is_eq(key) },
-                        |&(_, h)| h,
+                        |&(_, v)| v.is_eq(key),
+                        |&(_, v)| v.hash(state),
                     );
 
                     match insert {
                         hashbrown::hash_table::Entry::Occupied(o) => o.get().0,
                         hashbrown::hash_table::Entry::Vacant(v) => {
-                            let g = self.values.len();
-                            v.insert((g, hash));
-                            self.values.push(key);
+                            let g = num_total_groups;
+                            v.insert((g, key));
+                            self.append_row_indices.push(row_index as u32);
+                            num_total_groups += 1;
                             g
                         }
                     }
@@ -147,11 +159,23 @@ where
             };
             groups.push(group_id)
         }
+
+        // If every row created a new group, extend with the whole column at once
+        if self.append_row_indices.len() == col.len() {
+            self.values.extend_from_slice(col.values());
+        } else {
+            let col_values = col.values();
+            for &row_index in self.append_row_indices.iter() {
+                self.values.push(col_values[row_index as usize]);
+            }
+        }
+
         Ok(())
    }
 
     fn size(&self) -> usize {
-        self.map.capacity() * size_of::<(usize, u64)>() + self.values.allocated_size()
+        self.map.capacity() * size_of::<(usize, T::Native)>()
+            + self.values.allocated_size()
     }
 
     fn is_empty(&self) -> bool {
@@ -163,55 +187,13 @@ where
     }
 
     fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
-        fn build_primitive<T: ArrowPrimitiveType>(
-            values: Vec<T::Native>,
-            null_idx: Option<usize>,
-        ) -> PrimitiveArray<T> {
-            let nulls = null_idx.map(|null_idx| {
-                let mut buffer = NullBufferBuilder::new(values.len());
-                buffer.append_n_non_nulls(null_idx);
-                buffer.append_null();
-                buffer.append_n_non_nulls(values.len() - null_idx - 1);
-                // NOTE: The inner builder must be constructed as there is at least one null
-                buffer.finish().unwrap()
-            });
-            PrimitiveArray::<T>::new(values.into(), nulls)
-        }
-
-        let array: PrimitiveArray<T> = match emit_to {
-            EmitTo::All => {
-                self.map.clear();
-                build_primitive(std::mem::take(&mut self.values), self.null_group.take())
-            }
-            EmitTo::First(n) => {
-                self.map.retain(|entry| {
-                    // Decrement group index by n
-                    let group_idx = entry.0;
-                    match group_idx.checked_sub(n) {
-                        // Group index was >= n, shift value down
-                        Some(sub) => {
-                            entry.0 = sub;
-                            true
-                        }
-                        // Group index was < n, so remove from table
-                        None => false,
-                    }
-                });
-                let null_group = match &mut self.null_group {
-                    Some(v) if *v >= n => {
-                        *v -= n;
-                        None
-                    }
-                    Some(_) => self.null_group.take(),
-                    None => None,
-                };
-                let mut split = self.values.split_off(n);
-                std::mem::swap(&mut self.values, &mut split);
-                build_primitive(split, null_group)
-            }
-        };
-
-        Ok(vec![Arc::new(array.with_data_type(self.data_type.clone()))])
+        emit_internal::<T, _>(
+            emit_to,
+            &mut self.values,
+            &mut self.null_group,
+            &mut self.map,
+            self.data_type.clone(),
+        )
     }
 
     fn clear_shrink(&mut self, batch: &RecordBatch) {
@@ -222,3 +204,61 @@ where
         self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared
     }
 }
+
+pub(crate) fn emit_internal<T: ArrowPrimitiveType, K>(
+    emit_to: EmitTo,
+    values: &mut Vec<T::Native>,
+    null_group: &mut Option<usize>,
+    map: &mut HashTable<(usize, K)>,
+    data_type: DataType,
+) -> Result<Vec<ArrayRef>> {
+    fn build_primitive<T: ArrowPrimitiveType>(
+        values: Vec<T::Native>,
+        null_idx: Option<usize>,
+    ) -> PrimitiveArray<T> {
+        let nulls = null_idx.map(|null_idx| {
+            let mut buffer = NullBufferBuilder::new(values.len());
+            buffer.append_n_non_nulls(null_idx);
+            buffer.append_null();
+            buffer.append_n_non_nulls(values.len() - null_idx - 1);
+            // NOTE: The inner builder must be constructed as there is at least one null
+            buffer.finish().unwrap()
+        });
+        PrimitiveArray::<T>::new(values.into(), nulls)
+    }
+
+    let array: PrimitiveArray<T> = match emit_to {
+        EmitTo::All => {
+            map.clear();
+            build_primitive(std::mem::take(values), null_group.take())
+        }
+        EmitTo::First(n) => {
+            map.retain(|entry| {
+                // Decrement group index by n
+                let group_idx = entry.0;
+                match group_idx.checked_sub(n) {
+                    // Group index was >= n, shift value down
+                    Some(sub) => {
+                        entry.0 = sub;
+                        true
+                    }
+                    // Group index was < n, so remove from table
+                    None => false,
+                }
+            });
+            let null_group = match null_group {
+                Some(v) if *v >= n => {
+                    *v -= n;
+                    None
+                }
+                Some(_) => null_group.take(),
+                None => None,
+            };
+            let mut split = values.split_off(n);
+            std::mem::swap(values, &mut split);
+            build_primitive(split, null_group)
+        }
+    };
+
+    Ok(vec![Arc::new(array.with_data_type(data_type))])
+}
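
Note (illustrative sketch, not part of the patch): the core change for <= 64-bit native types is that each hash table entry now carries the group value itself, so equality probing compares against the entry and resize-time rehashing recomputes the hash from the entry, with no cached hash and no indirection through `values`. Below is a minimal, self-contained sketch of that interning scheme, assuming only the `hashbrown` and `ahash` crates; the `Interner` name and the bare `u64` key type are made up for the example, standing in for the generic `T::Native`.

```rust
// Sketch of the "(group_index, group_value)" interning layout.
use std::hash::BuildHasher;

use hashbrown::hash_table::{Entry, HashTable};

struct Interner {
    /// Each entry stores (group_index, group_value), so probing and
    /// rehashing never need to look at `values`.
    map: HashTable<(usize, u64)>,
    /// group_index -> value, used when emitting the grouping column
    values: Vec<u64>,
    random_state: ahash::RandomState,
}

impl Interner {
    fn new() -> Self {
        Self {
            map: HashTable::with_capacity(128),
            values: Vec::new(),
            random_state: ahash::RandomState::new(),
        }
    }

    /// Returns the group index for `key`, allocating a new group if needed.
    fn intern(&mut self, key: u64) -> usize {
        let state = &self.random_state;
        let hash = state.hash_one(key);
        match self.map.entry(
            hash,
            |&(_, v)| v == key,          // probe: compare the stored value
            |&(_, v)| state.hash_one(v), // resize: recompute hash from the stored value
        ) {
            Entry::Occupied(o) => o.get().0,
            Entry::Vacant(slot) => {
                let group_index = self.values.len();
                slot.insert((group_index, key));
                self.values.push(key);
                group_index
            }
        }
    }
}

fn main() {
    let mut interner = Interner::new();
    assert_eq!(interner.intern(42), 0);
    assert_eq!(interner.intern(7), 1);
    assert_eq!(interner.intern(42), 0); // existing group is reused
    println!("group values = {:?}", interner.values);
}
```

For 128-bit values such as `Decimal128`, `GroupValuesLargePrimitive` instead keeps the previous `(group_index, hash)` entry layout and still compares against `values` when probing, presumably because copying an `i128` into every table entry would grow entries well beyond 16 bytes, while the 8-byte cached hash keeps them compact.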