diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index ae528daad53c..58bc7bb90a88 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -164,7 +164,7 @@ pub(crate) fn new_group_values( TimeUnit::Nanosecond => downcast_helper!(Time64NanosecondType, d), _ => {} }, - DataType::Timestamp(t, _) => match t { + DataType::Timestamp(t, _tz) => match t { TimeUnit::Second => downcast_helper!(TimestampSecondType, d), TimeUnit::Millisecond => downcast_helper!(TimestampMillisecondType, d), TimeUnit::Microsecond => downcast_helper!(TimestampMicrosecondType, d), diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs index 10b00cf74fdb..89041eb0f04e 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs @@ -880,12 +880,12 @@ impl GroupValuesColumn { /// `$t`: the primitive type of the builder /// macro_rules! instantiate_primitive { - ($v:expr, $nullable:expr, $t:ty) => { + ($v:expr, $nullable:expr, $t:ty, $data_type:ident) => { if $nullable { - let b = PrimitiveGroupValueBuilder::<$t, true>::new(); + let b = PrimitiveGroupValueBuilder::<$t, true>::new($data_type.to_owned()); $v.push(Box::new(b) as _) } else { - let b = PrimitiveGroupValueBuilder::<$t, false>::new(); + let b = PrimitiveGroupValueBuilder::<$t, false>::new($data_type.to_owned()); $v.push(Box::new(b) as _) } }; @@ -898,53 +898,114 @@ impl GroupValues for GroupValuesColumn { for f in self.schema.fields().iter() { let nullable = f.is_nullable(); - match f.data_type() { - &DataType::Int8 => instantiate_primitive!(v, nullable, Int8Type), - &DataType::Int16 => instantiate_primitive!(v, nullable, Int16Type), - &DataType::Int32 => instantiate_primitive!(v, nullable, Int32Type), - &DataType::Int64 => instantiate_primitive!(v, nullable, Int64Type), - &DataType::UInt8 => instantiate_primitive!(v, nullable, UInt8Type), - &DataType::UInt16 => instantiate_primitive!(v, nullable, UInt16Type), - &DataType::UInt32 => instantiate_primitive!(v, nullable, UInt32Type), - &DataType::UInt64 => instantiate_primitive!(v, nullable, UInt64Type), + let data_type = f.data_type(); + match data_type { + &DataType::Int8 => { + instantiate_primitive!(v, nullable, Int8Type, data_type) + } + &DataType::Int16 => { + instantiate_primitive!(v, nullable, Int16Type, data_type) + } + &DataType::Int32 => { + instantiate_primitive!(v, nullable, Int32Type, data_type) + } + &DataType::Int64 => { + instantiate_primitive!(v, nullable, Int64Type, data_type) + } + &DataType::UInt8 => { + instantiate_primitive!(v, nullable, UInt8Type, data_type) + } + &DataType::UInt16 => { + instantiate_primitive!(v, nullable, UInt16Type, data_type) + } + &DataType::UInt32 => { + instantiate_primitive!(v, nullable, UInt32Type, data_type) + } + &DataType::UInt64 => { + instantiate_primitive!(v, nullable, UInt64Type, data_type) + } &DataType::Float32 => { - instantiate_primitive!(v, nullable, Float32Type) + instantiate_primitive!(v, nullable, Float32Type, data_type) } &DataType::Float64 => { - instantiate_primitive!(v, nullable, Float64Type) + instantiate_primitive!(v, nullable, Float64Type, data_type) + } + &DataType::Date32 => { + instantiate_primitive!(v, nullable, Date32Type, data_type) + } + &DataType::Date64 => { + instantiate_primitive!(v, nullable, Date64Type, data_type) } - &DataType::Date32 => instantiate_primitive!(v, nullable, Date32Type), - &DataType::Date64 => instantiate_primitive!(v, nullable, Date64Type), &DataType::Time32(t) => match t { TimeUnit::Second => { - instantiate_primitive!(v, nullable, Time32SecondType) + instantiate_primitive!( + v, + nullable, + Time32SecondType, + data_type + ) } TimeUnit::Millisecond => { - instantiate_primitive!(v, nullable, Time32MillisecondType) + instantiate_primitive!( + v, + nullable, + Time32MillisecondType, + data_type + ) } _ => {} }, &DataType::Time64(t) => match t { TimeUnit::Microsecond => { - instantiate_primitive!(v, nullable, Time64MicrosecondType) + instantiate_primitive!( + v, + nullable, + Time64MicrosecondType, + data_type + ) } TimeUnit::Nanosecond => { - instantiate_primitive!(v, nullable, Time64NanosecondType) + instantiate_primitive!( + v, + nullable, + Time64NanosecondType, + data_type + ) } _ => {} }, &DataType::Timestamp(t, _) => match t { TimeUnit::Second => { - instantiate_primitive!(v, nullable, TimestampSecondType) + instantiate_primitive!( + v, + nullable, + TimestampSecondType, + data_type + ) } TimeUnit::Millisecond => { - instantiate_primitive!(v, nullable, TimestampMillisecondType) + instantiate_primitive!( + v, + nullable, + TimestampMillisecondType, + data_type + ) } TimeUnit::Microsecond => { - instantiate_primitive!(v, nullable, TimestampMicrosecondType) + instantiate_primitive!( + v, + nullable, + TimestampMicrosecondType, + data_type + ) } TimeUnit::Nanosecond => { - instantiate_primitive!(v, nullable, TimestampNanosecondType) + instantiate_primitive!( + v, + nullable, + TimestampNanosecondType, + data_type + ) } }, &DataType::Utf8 => { diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs index 4da482247458..4686a78f24b0 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs @@ -20,6 +20,7 @@ use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; use arrow::buffer::ScalarBuffer; use arrow_array::cast::AsArray; use arrow_array::{Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray}; +use arrow_schema::DataType; use datafusion_execution::memory_pool::proxy::VecAllocExt; use itertools::izip; use std::iter; @@ -35,6 +36,7 @@ use std::sync::Arc; /// `NULLABLE`: if the data can contain any nulls #[derive(Debug)] pub struct PrimitiveGroupValueBuilder { + data_type: DataType, group_values: Vec, nulls: MaybeNullBufferBuilder, } @@ -44,8 +46,9 @@ where T: ArrowPrimitiveType, { /// Create a new `PrimitiveGroupValueBuilder` - pub fn new() -> Self { + pub fn new(data_type: DataType) -> Self { Self { + data_type, group_values: vec![], nulls: MaybeNullBufferBuilder::new(), } @@ -177,6 +180,7 @@ impl GroupColumn fn build(self: Box) -> ArrayRef { let Self { + data_type, group_values, nulls, } = *self; @@ -186,10 +190,9 @@ impl GroupColumn assert!(nulls.is_none(), "unexpected nulls in non nullable input"); } - Arc::new(PrimitiveArray::::new( - ScalarBuffer::from(group_values), - nulls, - )) + let arr = PrimitiveArray::::new(ScalarBuffer::from(group_values), nulls); + // Set timezone information for timestamp + Arc::new(arr.with_data_type(data_type)) } fn take_n(&mut self, n: usize) -> ArrayRef { @@ -212,6 +215,7 @@ mod tests { use arrow::datatypes::Int64Type; use arrow_array::{ArrayRef, Int64Array}; use arrow_buffer::{BooleanBufferBuilder, NullBuffer}; + use arrow_schema::DataType; use super::GroupColumn; @@ -283,7 +287,8 @@ mod tests { // - exist not null, input not null; values equal // Define PrimitiveGroupValueBuilder - let mut builder = PrimitiveGroupValueBuilder::::new(); + let mut builder = + PrimitiveGroupValueBuilder::::new(DataType::Int64); let builder_array = Arc::new(Int64Array::from(vec![ None, None, @@ -392,7 +397,8 @@ mod tests { // - values not equal // Define PrimitiveGroupValueBuilder - let mut builder = PrimitiveGroupValueBuilder::::new(); + let mut builder = + PrimitiveGroupValueBuilder::::new(DataType::Int64); let builder_array = Arc::new(Int64Array::from(vec![Some(0), Some(1)])) as ArrayRef; append(&mut builder, &builder_array, &[0, 1]); @@ -419,7 +425,8 @@ mod tests { // Test the special `all nulls` or `not nulls` input array case // for vectorized append and equal to - let mut builder = PrimitiveGroupValueBuilder::::new(); + let mut builder = + PrimitiveGroupValueBuilder::::new(DataType::Int64); // All nulls input array let all_nulls_input_array = Arc::new(Int64Array::from(vec![ diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index 5d8c4dfd05b4..4acf519c5de4 100644 --- a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -5483,3 +5483,19 @@ SELECT max(input_table.x), min(input_table.x) from input_table GROUP BY input_ta ---- NaN NaN +# Group by timestamp +query TP +SELECT + 'foo' AS text, + arrow_cast('2024-01-01T00:00:00Z'::timestamptz, 'Timestamp(Microsecond, Some("UTC"))') AS ts +GROUP BY ts, text +---- +foo 2024-01-01T00:00:00Z + +query TP +SELECT + 'foo' AS text, + arrow_cast('2024-01-01T00:00:00Z'::timestamptz, 'Timestamp(Second, Some("+08:00"))') AS ts +GROUP BY ts, text +---- +foo 2024-01-01T08:00:00+08:00