From e3477f1039f724ca37c4dfb8b528f8ede750d8df Mon Sep 17 00:00:00 2001 From: kamille Date: Wed, 21 May 2025 16:42:09 +0800 Subject: [PATCH 1/6] unnecessary to save hash value... let's save value... --- .../aggregates/group_values/single_group_by/primitive.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs index 279caa50b0a6..20dab5c97607 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs @@ -88,7 +88,7 @@ pub struct GroupValuesPrimitive { /// More details can see: /// /// - map: HashTable<(usize, u64)>, + map: HashTable<(usize, T::Native)>, /// The group index of the null value if any null_group: Option, /// The values for each group index @@ -130,15 +130,15 @@ where let hash = key.hash(state); let insert = self.map.entry( hash, - |&(g, _)| unsafe { self.values.get_unchecked(g).is_eq(key) }, - |&(_, h)| h, + |&(_, v)| v.is_eq(key), + |&(_, v)| v.hash(state), ); match insert { hashbrown::hash_table::Entry::Occupied(o) => o.get().0, hashbrown::hash_table::Entry::Vacant(v) => { let g = self.values.len(); - v.insert((g, hash)); + v.insert((g, key)); self.values.push(key); g } From 93045f2888531baa06461fb1ab60083fa425cc24 Mon Sep 17 00:00:00 2001 From: kamille Date: Wed, 21 May 2025 19:51:27 +0800 Subject: [PATCH 2/6] specialized `GroupValues` for `primitive` and `large_primitive` for performance. --- .../src/aggregates/group_values/mod.rs | 13 +- .../primitive/large_primitive.rs | 143 ++++++++++++++ .../{primitive.rs => primitive/mod.rs} | 184 ++++++++++-------- 3 files changed, 255 insertions(+), 85 deletions(-) create mode 100644 datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/large_primitive.rs rename datafusion/physical-plan/src/aggregates/group_values/single_group_by/{primitive.rs => primitive/mod.rs} (72%) diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index ce56ca4f7dfd..e595fe4aa245 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -40,8 +40,9 @@ pub(crate) use single_group_by::primitive::HashValue; use crate::aggregates::{ group_values::single_group_by::{ - bytes::GroupValuesByes, bytes_view::GroupValuesBytesView, - primitive::GroupValuesPrimitive, + bytes::GroupValuesByes, + bytes_view::GroupValuesBytesView, + primitive::{GroupValuesLargePrimitive, GroupValuesPrimitive}, }, order::GroupOrdering, }; @@ -134,6 +135,12 @@ pub(crate) fn new_group_values( if schema.fields.len() == 1 { let d = schema.fields[0].data_type(); + macro_rules! large_downcast_helper { + ($t:ty, $d:ident) => { + return Ok(Box::new(GroupValuesLargePrimitive::<$t>::new($d.clone()))) + }; + } + macro_rules! downcast_helper { ($t:ty, $d:ident) => { return Ok(Box::new(GroupValuesPrimitive::<$t>::new($d.clone()))) @@ -169,7 +176,7 @@ pub(crate) fn new_group_values( TimeUnit::Nanosecond => downcast_helper!(TimestampNanosecondType, d), }, DataType::Decimal128(_, _) => { - downcast_helper!(Decimal128Type, d); + large_downcast_helper!(Decimal128Type, d); } DataType::Utf8 => { return Ok(Box::new(GroupValuesByes::::new(OutputType::Utf8))); diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/large_primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/large_primitive.rs new file mode 100644 index 000000000000..a0c0047a07c7 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/large_primitive.rs @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::aggregates::group_values::single_group_by::primitive::{ + emit_internal, HashValue, +}; +use crate::aggregates::group_values::GroupValues; +use ahash::RandomState; +use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; +use arrow::array::{ + cast::AsArray, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, NullBufferBuilder, + PrimitiveArray, +}; +use arrow::datatypes::{i256, DataType}; +use arrow::record_batch::RecordBatch; +use datafusion_common::Result; +use datafusion_execution::memory_pool::proxy::VecAllocExt; +use datafusion_expr::EmitTo; +use half::f16; +use hashbrown::hash_table::HashTable; +use std::mem::size_of; +use std::sync::Arc; + +/// A [`GroupValues`] storing a single column of primitive values +/// +/// This specialization is significantly faster than using the more general +/// purpose `Row`s format +pub struct GroupValuesLargePrimitive { + /// The data type of the output array + data_type: DataType, + /// Stores the `(group_index, hash)` based on the hash of its value + /// + /// We also store `hash` is for reducing cost of rehashing. Such cost + /// is obvious in high cardinality group by situation. + /// More details can see: + /// + /// + map: HashTable<(usize, u64)>, + /// The group index of the null value if any + null_group: Option, + /// The values for each group index + values: Vec, + /// The random state used to generate hashes + random_state: RandomState, +} + +impl GroupValuesLargePrimitive { + pub fn new(data_type: DataType) -> Self { + assert!(PrimitiveArray::::is_compatible(&data_type)); + Self { + data_type, + map: HashTable::with_capacity(128), + values: Vec::with_capacity(128), + null_group: None, + random_state: Default::default(), + } + } +} + +impl GroupValues for GroupValuesLargePrimitive +where + T::Native: HashValue, +{ + fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + assert_eq!(cols.len(), 1); + groups.clear(); + + for v in cols[0].as_primitive::() { + let group_id = match v { + None => *self.null_group.get_or_insert_with(|| { + let group_id = self.values.len(); + self.values.push(Default::default()); + group_id + }), + Some(key) => { + let state = &self.random_state; + let hash = key.hash(state); + let insert = self.map.entry( + hash, + |&(g, _)| unsafe { self.values.get_unchecked(g).is_eq(key) }, + |&(_, h)| h, + ); + + match insert { + hashbrown::hash_table::Entry::Occupied(o) => o.get().0, + hashbrown::hash_table::Entry::Vacant(v) => { + let g = self.values.len(); + v.insert((g, hash)); + self.values.push(key); + g + } + } + } + }; + groups.push(group_id) + } + Ok(()) + } + + fn size(&self) -> usize { + self.map.capacity() * size_of::<(usize, u64)>() + self.values.allocated_size() + } + + fn is_empty(&self) -> bool { + self.values.is_empty() + } + + fn len(&self) -> usize { + self.values.len() + } + + fn emit(&mut self, emit_to: EmitTo) -> Result> { + emit_internal::( + emit_to, + &mut self.values, + &mut self.null_group, + &mut self.map, + self.data_type.clone(), + ) + } + + fn clear_shrink(&mut self, batch: &RecordBatch) { + let count = batch.num_rows(); + self.values.clear(); + self.values.shrink_to(count); + self.map.clear(); + self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared + } +} diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/mod.rs similarity index 72% rename from datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs rename to datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/mod.rs index 20dab5c97607..6dd0accbc6a8 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/mod.rs @@ -1,36 +1,4 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::aggregates::group_values::GroupValues; -use ahash::RandomState; -use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; -use arrow::array::{ - cast::AsArray, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, NullBufferBuilder, - PrimitiveArray, -}; -use arrow::datatypes::{i256, DataType}; -use arrow::record_batch::RecordBatch; -use datafusion_common::Result; -use datafusion_execution::memory_pool::proxy::VecAllocExt; -use datafusion_expr::EmitTo; -use half::f16; -use hashbrown::hash_table::HashTable; -use std::mem::size_of; -use std::sync::Arc; +mod large_primitive; /// A trait to allow hashing of floating point numbers pub(crate) trait HashValue { @@ -74,6 +42,42 @@ macro_rules! hash_float { hash_float!(f16, f32, f64); +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::aggregates::group_values::GroupValues; +use ahash::RandomState; +use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; +use arrow::array::{ + cast::AsArray, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, NullBufferBuilder, + PrimitiveArray, +}; +use arrow::datatypes::{i256, DataType}; +use arrow::record_batch::RecordBatch; +use datafusion_common::Result; +use datafusion_execution::memory_pool::proxy::VecAllocExt; +use datafusion_expr::EmitTo; +use half::f16; +use hashbrown::hash_table::HashTable; +use std::mem::size_of; +use std::sync::Arc; + +pub use large_primitive::GroupValuesLargePrimitive; + /// A [`GroupValues`] storing a single column of primitive values /// /// This specialization is significantly faster than using the more general @@ -163,55 +167,13 @@ where } fn emit(&mut self, emit_to: EmitTo) -> Result> { - fn build_primitive( - values: Vec, - null_idx: Option, - ) -> PrimitiveArray { - let nulls = null_idx.map(|null_idx| { - let mut buffer = NullBufferBuilder::new(values.len()); - buffer.append_n_non_nulls(null_idx); - buffer.append_null(); - buffer.append_n_non_nulls(values.len() - null_idx - 1); - // NOTE: The inner builder must be constructed as there is at least one null - buffer.finish().unwrap() - }); - PrimitiveArray::::new(values.into(), nulls) - } - - let array: PrimitiveArray = match emit_to { - EmitTo::All => { - self.map.clear(); - build_primitive(std::mem::take(&mut self.values), self.null_group.take()) - } - EmitTo::First(n) => { - self.map.retain(|entry| { - // Decrement group index by n - let group_idx = entry.0; - match group_idx.checked_sub(n) { - // Group index was >= n, shift value down - Some(sub) => { - entry.0 = sub; - true - } - // Group index was < n, so remove from table - None => false, - } - }); - let null_group = match &mut self.null_group { - Some(v) if *v >= n => { - *v -= n; - None - } - Some(_) => self.null_group.take(), - None => None, - }; - let mut split = self.values.split_off(n); - std::mem::swap(&mut self.values, &mut split); - build_primitive(split, null_group) - } - }; - - Ok(vec![Arc::new(array.with_data_type(self.data_type.clone()))]) + emit_internal::( + emit_to, + &mut self.values, + &mut self.null_group, + &mut self.map, + self.data_type.clone(), + ) } fn clear_shrink(&mut self, batch: &RecordBatch) { @@ -222,3 +184,61 @@ where self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared } } + +pub(crate) fn emit_internal( + emit_to: EmitTo, + values: &mut Vec, + null_group: &mut Option, + map: &mut HashTable<(usize, K)>, + data_type: DataType, +) -> Result> { + fn build_primitive( + values: Vec, + null_idx: Option, + ) -> PrimitiveArray { + let nulls = null_idx.map(|null_idx| { + let mut buffer = NullBufferBuilder::new(values.len()); + buffer.append_n_non_nulls(null_idx); + buffer.append_null(); + buffer.append_n_non_nulls(values.len() - null_idx - 1); + // NOTE: The inner builder must be constructed as there is at least one null + buffer.finish().unwrap() + }); + PrimitiveArray::::new(values.into(), nulls) + } + + let array: PrimitiveArray = match emit_to { + EmitTo::All => { + map.clear(); + build_primitive(std::mem::take(values), null_group.take()) + } + EmitTo::First(n) => { + map.retain(|entry| { + // Decrement group index by n + let group_idx = entry.0; + match group_idx.checked_sub(n) { + // Group index was >= n, shift value down + Some(sub) => { + entry.0 = sub; + true + } + // Group index was < n, so remove from table + None => false, + } + }); + let null_group = match null_group { + Some(v) if *v >= n => { + *v -= n; + None + } + Some(_) => null_group.take(), + None => None, + }; + let mut split = values.split_off(n); + std::mem::swap(values, &mut split); + build_primitive(split, null_group) + } + }; + + Ok(vec![Arc::new(array.with_data_type(data_type))]) +} From 6ce4857381e401b5e52e370c563a54f121422a7d Mon Sep 17 00:00:00 2001 From: kamille Date: Wed, 21 May 2025 20:10:04 +0800 Subject: [PATCH 3/6] fix comments and clippy. --- .../single_group_by/primitive/large_primitive.rs | 11 ++++------- .../group_values/single_group_by/primitive/mod.rs | 10 +++++----- 2 files changed, 9 insertions(+), 12 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/large_primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/large_primitive.rs index a0c0047a07c7..750b2deb3b8a 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/large_primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/large_primitive.rs @@ -20,22 +20,19 @@ use crate::aggregates::group_values::single_group_by::primitive::{ }; use crate::aggregates::group_values::GroupValues; use ahash::RandomState; -use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; +use arrow::array::types; use arrow::array::{ - cast::AsArray, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, NullBufferBuilder, - PrimitiveArray, + cast::AsArray, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray, }; -use arrow::datatypes::{i256, DataType}; +use arrow::datatypes::DataType; use arrow::record_batch::RecordBatch; use datafusion_common::Result; use datafusion_execution::memory_pool::proxy::VecAllocExt; use datafusion_expr::EmitTo; -use half::f16; use hashbrown::hash_table::HashTable; use std::mem::size_of; -use std::sync::Arc; -/// A [`GroupValues`] storing a single column of primitive values +/// A [`GroupValues`] storing a single column of large primitive values (bits > 64) /// /// This specialization is significantly faster than using the more general /// purpose `Row`s format diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/mod.rs index 6dd0accbc6a8..4a575039a186 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/mod.rs @@ -78,19 +78,19 @@ use std::sync::Arc; pub use large_primitive::GroupValuesLargePrimitive; -/// A [`GroupValues`] storing a single column of primitive values +/// A [`GroupValues`] storing a single column of normal primitive values (bits <= 64) /// /// This specialization is significantly faster than using the more general /// purpose `Row`s format pub struct GroupValuesPrimitive { /// The data type of the output array data_type: DataType, - /// Stores the `(group_index, hash)` based on the hash of its value + /// Stores the `(group_index, group_value)` /// - /// We also store `hash` is for reducing cost of rehashing. Such cost - /// is obvious in high cardinality group by situation. + /// We directly store copy of `group_value` for not only efficient + /// rehashing, but also efficient probing. /// More details can see: - /// + /// /// map: HashTable<(usize, T::Native)>, /// The group index of the null value if any From 1b5cde90f9d2206baede1dbe8815e806aca03662 Mon Sep 17 00:00:00 2001 From: kamille Date: Wed, 21 May 2025 20:16:50 +0800 Subject: [PATCH 4/6] fix `size` of `GroupValuesPrimitive`. --- .../single_group_by/primitive/mod.rs | 74 +++++++++---------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/mod.rs index 4a575039a186..693cc997fa3f 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/mod.rs @@ -1,4 +1,39 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::aggregates::group_values::GroupValues; +use ahash::RandomState; +use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; +use arrow::array::{ + cast::AsArray, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, NullBufferBuilder, + PrimitiveArray, +}; +use arrow::datatypes::{i256, DataType}; +use arrow::record_batch::RecordBatch; +use datafusion_common::Result; +use datafusion_execution::memory_pool::proxy::VecAllocExt; +use datafusion_expr::EmitTo; +use half::f16; +use hashbrown::hash_table::HashTable; +use std::mem::size_of; +use std::sync::Arc; + mod large_primitive; +pub use large_primitive::GroupValuesLargePrimitive; /// A trait to allow hashing of floating point numbers pub(crate) trait HashValue { @@ -42,42 +77,6 @@ macro_rules! hash_float { hash_float!(f16, f32, f64); -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::aggregates::group_values::GroupValues; -use ahash::RandomState; -use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; -use arrow::array::{ - cast::AsArray, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, NullBufferBuilder, - PrimitiveArray, -}; -use arrow::datatypes::{i256, DataType}; -use arrow::record_batch::RecordBatch; -use datafusion_common::Result; -use datafusion_execution::memory_pool::proxy::VecAllocExt; -use datafusion_expr::EmitTo; -use half::f16; -use hashbrown::hash_table::HashTable; -use std::mem::size_of; -use std::sync::Arc; - -pub use large_primitive::GroupValuesLargePrimitive; - /// A [`GroupValues`] storing a single column of normal primitive values (bits <= 64) /// /// This specialization is significantly faster than using the more general @@ -155,7 +154,8 @@ where } fn size(&self) -> usize { - self.map.capacity() * size_of::<(usize, u64)>() + self.values.allocated_size() + self.map.capacity() * size_of::<(usize, T::Native)>() + + self.values.allocated_size() } fn is_empty(&self) -> bool { From 8c05f694d6849a364bfb78400f06b5b4ca14c02c Mon Sep 17 00:00:00 2001 From: kamille Date: Wed, 21 May 2025 20:45:17 +0800 Subject: [PATCH 5/6] fix clippy. --- .../group_values/single_group_by/primitive/large_primitive.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/large_primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/large_primitive.rs index 750b2deb3b8a..6e8cc23ee240 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/large_primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/large_primitive.rs @@ -20,7 +20,6 @@ use crate::aggregates::group_values::single_group_by::primitive::{ }; use crate::aggregates::group_values::GroupValues; use ahash::RandomState; -use arrow::array::types; use arrow::array::{ cast::AsArray, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray, }; From cf053cb1bdaa85e3caa528e4b33f5bc0a7cbc1f0 Mon Sep 17 00:00:00 2001 From: kamille Date: Thu, 22 May 2025 17:47:10 +0800 Subject: [PATCH 6/6] try extend. --- .../single_group_by/primitive/mod.rs | 30 +++++++++++++++---- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/mod.rs index 693cc997fa3f..7fa4638d12e2 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive/mod.rs @@ -98,6 +98,8 @@ pub struct GroupValuesPrimitive { values: Vec, /// The random state used to generate hashes random_state: RandomState, + + append_row_indices: Vec, } impl GroupValuesPrimitive { @@ -109,6 +111,7 @@ impl GroupValuesPrimitive { values: Vec::with_capacity(128), null_group: None, random_state: Default::default(), + append_row_indices: Vec::new(), } } } @@ -119,13 +122,18 @@ where { fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { assert_eq!(cols.len(), 1); + let col = cols[0].as_primitive::(); + groups.clear(); + self.append_row_indices.clear(); - for v in cols[0].as_primitive::() { + let mut num_total_groups = self.values.len(); + for (row_index, v) in col.iter().enumerate() { let group_id = match v { None => *self.null_group.get_or_insert_with(|| { - let group_id = self.values.len(); - self.values.push(Default::default()); + let group_id = num_total_groups; + self.append_row_indices.push(row_index as u32); + num_total_groups += 1; group_id }), Some(key) => { @@ -140,9 +148,10 @@ where match insert { hashbrown::hash_table::Entry::Occupied(o) => o.get().0, hashbrown::hash_table::Entry::Vacant(v) => { - let g = self.values.len(); + let g = num_total_groups; v.insert((g, key)); - self.values.push(key); + self.append_row_indices.push(row_index as u32); + num_total_groups += 1; g } } @@ -150,6 +159,17 @@ where }; groups.push(group_id) } + + // If all are new groups, we just extend it + if self.append_row_indices.len() == col.len() { + self.values.extend_from_slice(col.values()); + } else { + let col_values = col.values(); + for &row_index in self.append_row_indices.iter() { + self.values.push(col_values[row_index as usize]); + } + } + Ok(()) }