From d58714b70f77a5f8b89639b180d89d64f504c42e Mon Sep 17 00:00:00 2001 From: Xiangpeng Hao Date: Sat, 3 Aug 2024 07:45:57 -0400 Subject: [PATCH] multi-col agg --- .../src/binary_view_map.rs | 66 +++++- .../src/aggregates/group_values/row.rs | 211 +++++++++++++----- 2 files changed, 209 insertions(+), 68 deletions(-) diff --git a/datafusion/physical-expr-common/src/binary_view_map.rs b/datafusion/physical-expr-common/src/binary_view_map.rs index 18bc6801aa60..66ab64220827 100644 --- a/datafusion/physical-expr-common/src/binary_view_map.rs +++ b/datafusion/physical-expr-common/src/binary_view_map.rs @@ -24,7 +24,7 @@ use arrow::array::cast::AsArray; use arrow::array::{Array, ArrayBuilder, ArrayRef, GenericByteViewBuilder}; use arrow::datatypes::{BinaryViewType, ByteViewType, DataType, StringViewType}; use datafusion_common::hash_utils::create_hashes; -use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt}; +use datafusion_common::utils::proxy::RawTableAllocExt; use std::fmt::Debug; use std::sync::Arc; @@ -207,6 +207,7 @@ where values, make_payload_fn, observe_payload_fn, + None, ) } OutputType::Utf8View => { @@ -215,6 +216,43 @@ where values, make_payload_fn, observe_payload_fn, + None, + ) + } + _ => unreachable!("Utf8/Binary should use `ArrowBytesSet`"), + }; + } + + /// Similar to [`Self::insert_if_new`] but allows the caller to provide the + /// hash values for the values in `values` instead of computing them + pub fn insert_if_new_with_hash( + &mut self, + values: &ArrayRef, + make_payload_fn: MP, + observe_payload_fn: OP, + provided_hash: &Vec, + ) where + MP: FnMut(Option<&[u8]>) -> V, + OP: FnMut(V), + { + // Sanity check array type + match self.output_type { + OutputType::BinaryView => { + assert!(matches!(values.data_type(), DataType::BinaryView)); + self.insert_if_new_inner::( + values, + make_payload_fn, + observe_payload_fn, + Some(provided_hash), + ) + } + OutputType::Utf8View => { + assert!(matches!(values.data_type(), DataType::Utf8View)); + self.insert_if_new_inner::( + values, + make_payload_fn, + observe_payload_fn, + Some(provided_hash), ) } _ => unreachable!("Utf8/Binary should use `ArrowBytesSet`"), @@ -234,19 +272,26 @@ where values: &ArrayRef, mut make_payload_fn: MP, mut observe_payload_fn: OP, + provided_hash: Option<&Vec>, ) where MP: FnMut(Option<&[u8]>) -> V, OP: FnMut(V), B: ByteViewType, { // step 1: compute hashes - let batch_hashes = &mut self.hashes_buffer; - batch_hashes.clear(); - batch_hashes.resize(values.len(), 0); - create_hashes(&[values.clone()], &self.random_state, batch_hashes) - // hash is supported for all types and create_hashes only - // returns errors for unsupported types - .unwrap(); + let batch_hashes = match provided_hash { + Some(h) => h, + None => { + let batch_hashes = &mut self.hashes_buffer; + batch_hashes.clear(); + batch_hashes.resize(values.len(), 0); + create_hashes(&[values.clone()], &self.random_state, batch_hashes) + // hash is supported for all types and create_hashes only + // returns errors for unsupported types + .unwrap(); + batch_hashes + } + }; // step 2: insert each value into the set, if not already present let values = values.as_byte_view::(); @@ -353,9 +398,7 @@ where /// Return the total size, in bytes, of memory used to store the data in /// this set, not including `self` pub fn size(&self) -> usize { - self.map_size - + self.builder.allocated_size() - + self.hashes_buffer.allocated_size() + self.map_size + self.builder.allocated_size() } } @@ -369,7 +412,6 @@ where .field("map_size", &self.map_size) .field("view_builder", &self.builder) .field("random_state", &self.random_state) - .field("hashes_buffer", &self.hashes_buffer) .finish() } } diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index dc948e28bb2d..a0f75716d162 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -15,18 +15,40 @@ // specific language governing permissions and limitations // under the License. +use std::sync::Arc; + use crate::aggregates::group_values::GroupValues; use ahash::RandomState; +use arrow::array::AsArray as _; use arrow::compute::cast; +use arrow::datatypes::UInt32Type; use arrow::record_batch::RecordBatch; use arrow::row::{RowConverter, Rows, SortField}; -use arrow_array::{Array, ArrayRef}; +use arrow_array::{Array, ArrayRef, StringViewArray}; use arrow_schema::{DataType, SchemaRef}; -use datafusion_common::hash_utils::create_hashes; +use datafusion_common::hash_utils::{combine_hashes, create_hashes}; use datafusion_common::{DataFusionError, Result}; -use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; +use datafusion_execution::memory_pool::proxy::RawTableAllocExt; use datafusion_expr::EmitTo; +use datafusion_physical_expr_common::binary_view_map::ArrowBytesViewMap; use hashbrown::raw::RawTable; +use itertools::Itertools; + +struct VarLenGroupValues { + map: ArrowBytesViewMap, + num_groups: u32, +} + +impl VarLenGroupValues { + fn new() -> Self { + Self { + map: ArrowBytesViewMap::new( + datafusion_physical_expr::binary_map::OutputType::Utf8View, + ), + num_groups: 0, + } + } +} /// A [`GroupValues`] making use of [`Rows`] pub struct GroupValuesRows { @@ -60,22 +82,35 @@ pub struct GroupValuesRows { group_values: Option, /// reused buffer to store hashes - hashes_buffer: Vec, + final_hash_buffer: Vec, + + tmp_hash_buffer: Vec, /// reused buffer to store rows rows_buffer: Rows, + // variable length column map + var_len_map: Vec, + /// Random state for creating hashes random_state: RandomState, } impl GroupValuesRows { pub fn try_new(schema: SchemaRef) -> Result { + let mut var_len_map = Vec::new(); let row_converter = RowConverter::new( schema .fields() .iter() - .map(|f| SortField::new(f.data_type().clone())) + .map(|f| { + if f.data_type() == &DataType::Utf8View { + var_len_map.push(VarLenGroupValues::new()); + SortField::new(DataType::UInt32) + } else { + SortField::new(f.data_type().clone()) + } + }) .collect(), )?; @@ -91,20 +126,120 @@ impl GroupValuesRows { map, map_size: 0, group_values: None, - hashes_buffer: Default::default(), + final_hash_buffer: Default::default(), + tmp_hash_buffer: Default::default(), rows_buffer, + var_len_map, random_state: Default::default(), }) } + + fn transform_col_to_fixed_len(&mut self, input: &[ArrayRef]) -> Vec { + let n_rows = input[0].len(); + // 1.1 Calculate the group keys for the group values + let final_hash_buffer = &mut self.final_hash_buffer; + final_hash_buffer.clear(); + final_hash_buffer.resize(n_rows, 0); + let tmp_hash_buffer = &mut self.tmp_hash_buffer; + tmp_hash_buffer.clear(); + tmp_hash_buffer.resize(n_rows, 0); + + let mut cur_var_len_idx = 0; + let transformed_cols: Vec = input + .iter() + .map(|c| { + if let DataType::Utf8View = c.data_type() { + create_hashes(&[Arc::clone(c)], &self.random_state, tmp_hash_buffer) + .unwrap(); + let mut var_groups = Vec::with_capacity(c.len()); + let group_values = &mut self.var_len_map[cur_var_len_idx]; + group_values.map.insert_if_new_with_hash( + c, + |_value| { + let group_idx = group_values.num_groups; + group_values.num_groups += 1; + group_idx + }, + |group_idx| { + var_groups.push(group_idx); + }, + tmp_hash_buffer, + ); + cur_var_len_idx += 1; + final_hash_buffer + .iter_mut() + .zip(tmp_hash_buffer.iter()) + .for_each(|(result, tmp)| { + *result = combine_hashes(*result, *tmp); + }); + std::sync::Arc::new(arrow_array::UInt32Array::from(var_groups)) + as ArrayRef + } else { + create_hashes(&[Arc::clone(c)], &self.random_state, tmp_hash_buffer) + .unwrap(); + final_hash_buffer + .iter_mut() + .zip(tmp_hash_buffer.iter()) + .for_each(|(result, tmp)| { + *result = combine_hashes(*result, *tmp); + }); + Arc::clone(c) + } + }) + .collect(); + transformed_cols + } + + fn transform_col_to_var_len(&mut self, output: Vec) -> Vec { + let mut cur_var_len_idx = 0; + let output = output + .into_iter() + .enumerate() + .map(|(i, array)| { + let data_type = self.schema.field(i).data_type(); + if data_type == &DataType::Utf8View { + let arr = array.as_primitive::(); + let mut views = Vec::with_capacity(arr.len()); + + let map_content = + &mut self.var_len_map[cur_var_len_idx].map.take().into_state(); + let map_content = map_content.as_string_view(); + + for v in arr.iter() { + if let Some(index) = v { + let value = unsafe { + map_content.views().get_unchecked(index as usize) + }; + views.push(*value); + } else { + views.push(0); + } + } + let output_str = unsafe { + StringViewArray::new_unchecked( + views.into(), + map_content.data_buffers().to_vec(), + map_content.nulls().cloned(), + ) + }; + cur_var_len_idx += 1; + Arc::new(output_str) as ArrayRef + } else { + array + } + }) + .collect_vec(); + output + } } impl GroupValues for GroupValuesRows { fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { - // Convert the group keys into the row format + let transformed_cols: Vec = self.transform_col_to_fixed_len(cols); + let group_rows = &mut self.rows_buffer; group_rows.clear(); - self.row_converter.append(group_rows, cols)?; - let n_rows = group_rows.num_rows(); + self.row_converter.append(group_rows, &transformed_cols)?; let mut group_values = match self.group_values.take() { Some(group_values) => group_values, @@ -114,25 +249,13 @@ impl GroupValues for GroupValuesRows { // tracks to which group each of the input rows belongs groups.clear(); - // 1.1 Calculate the group keys for the group values - let batch_hashes = &mut self.hashes_buffer; - batch_hashes.clear(); - batch_hashes.resize(n_rows, 0); - create_hashes(cols, &self.random_state, batch_hashes)?; - - for (row, &target_hash) in batch_hashes.iter().enumerate() { - let entry = self.map.get_mut(target_hash, |(exist_hash, group_idx)| { - // Somewhat surprisingly, this closure can be called even if the - // hash doesn't match, so check the hash first with an integer - // comparison first avoid the more expensive comparison with - // group value. https://github.com/apache/datafusion/pull/11718 - target_hash == *exist_hash - // verify that the group that we are inserting with hash is - // actually the same key value as the group in - // existing_idx (aka group_values @ row) - && group_rows.row(row) == group_values.row(*group_idx) + for (row, hash) in group_rows.iter().zip(self.final_hash_buffer.iter()) { + let entry = self.map.get_mut(*hash, |(_hash, group_idx)| { + // verify that a group that we are inserting with hash is + // actually the same key value as the group in + // existing_idx (aka group_values @ row) + row == group_values.row(*group_idx) }); - let group_idx = match entry { // Existing group_index for this group value Some((_hash, group_idx)) => *group_idx, @@ -140,11 +263,11 @@ impl GroupValues for GroupValuesRows { None => { // Add new entry to aggr_state and save newly created index let group_idx = group_values.num_rows(); - group_values.push(group_rows.row(row)); + group_values.push(row); // for hasher function, use precomputed hash value self.map.insert_accounted( - (target_hash, group_idx), + (*hash, group_idx), |(hash, _group_index)| *hash, &mut self.map_size, ); @@ -165,7 +288,6 @@ impl GroupValues for GroupValuesRows { + group_values_size + self.map_size + self.rows_buffer.size() - + self.hashes_buffer.allocated_size() } fn is_empty(&self) -> bool { @@ -188,33 +310,12 @@ impl GroupValues for GroupValuesRows { let mut output = match emit_to { EmitTo::All => { let output = self.row_converter.convert_rows(&group_values)?; + let output = self.transform_col_to_var_len(output); group_values.clear(); output } - EmitTo::First(n) => { - let groups_rows = group_values.iter().take(n); - let output = self.row_converter.convert_rows(groups_rows)?; - // Clear out first n group keys by copying them to a new Rows. - // TODO file some ticket in arrow-rs to make this more efficient? - let mut new_group_values = self.row_converter.empty_rows(0, 0); - for row in group_values.iter().skip(n) { - new_group_values.push(row); - } - std::mem::swap(&mut new_group_values, &mut group_values); - - // SAFETY: self.map outlives iterator and is not modified concurrently - unsafe { - for bucket in self.map.iter() { - // Decrement group index by n - match bucket.as_ref().1.checked_sub(n) { - // Group index was >= n, shift value down - Some(sub) => bucket.as_mut().1 = sub, - // Group index was < n, so remove from table - None => self.map.erase(bucket), - } - } - } - output + EmitTo::First(_n) => { + unimplemented!("Not supported yet!") } }; @@ -245,7 +346,5 @@ impl GroupValues for GroupValuesRows { self.map.clear(); self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared self.map_size = self.map.capacity() * std::mem::size_of::<(u64, usize)>(); - self.hashes_buffer.clear(); - self.hashes_buffer.shrink_to(count); } }