diff --git a/datafusion/physical-optimizer/src/topk_aggregation.rs b/datafusion/physical-optimizer/src/topk_aggregation.rs
index 0e5fb82d9e93..faedea55ca15 100644
--- a/datafusion/physical-optimizer/src/topk_aggregation.rs
+++ b/datafusion/physical-optimizer/src/topk_aggregation.rs
@@ -56,7 +56,11 @@ impl TopKAggregation {
         }
         let group_key = aggr.group_expr().expr().iter().exactly_one().ok()?;
         let kt = group_key.0.data_type(&aggr.input().schema()).ok()?;
-        if !kt.is_primitive() && kt != DataType::Utf8 {
+        if !kt.is_primitive()
+            && kt != DataType::Utf8
+            && kt != DataType::Utf8View
+            && kt != DataType::LargeUtf8
+        {
             return None;
         }
         if aggr.filter_expr().iter().any(|e| e.is_some()) {
diff --git a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs
index c818b4608de7..ae44eb35e6d0 100644
--- a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs
+++ b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs
@@ -23,7 +23,7 @@ use ahash::RandomState;
 use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano};
 use arrow::array::{
     builder::PrimitiveBuilder, cast::AsArray, downcast_primitive, Array, ArrayRef,
-    ArrowPrimitiveType, PrimitiveArray, StringArray,
+    ArrowPrimitiveType, LargeStringArray, PrimitiveArray, StringArray, StringViewArray,
 };
 use arrow::datatypes::{i256, DataType};
 use datafusion_common::DataFusionError;
@@ -88,6 +88,7 @@ pub struct StringHashTable {
     owned: ArrayRef,
     map: TopKHashTable<Option<String>>,
     rnd: RandomState,
+    data_type: DataType,
 }
 
 // An implementation of ArrowHashTable for any `ArrowPrimitiveType` key
@@ -101,13 +102,20 @@ where
 }
 
 impl StringHashTable {
-    pub fn new(limit: usize) -> Self {
+    pub fn new(limit: usize, data_type: DataType) -> Self {
         let vals: Vec<&str> = Vec::new();
-        let owned = Arc::new(StringArray::from(vals));
+        let owned: ArrayRef = match data_type {
+            DataType::Utf8 => Arc::new(StringArray::from(vals)),
+            DataType::Utf8View => Arc::new(StringViewArray::from(vals)),
+            DataType::LargeUtf8 => Arc::new(LargeStringArray::from(vals)),
+            _ => panic!("Unsupported data type"),
+        };
+
         Self {
             owned,
             map: TopKHashTable::new(limit, limit * 10),
             rnd: RandomState::default(),
+            data_type,
         }
     }
 }
@@ -131,7 +139,12 @@ impl ArrowHashTable for StringHashTable {
 
     unsafe fn take_all(&mut self, indexes: Vec<usize>) -> ArrayRef {
         let ids = self.map.take_all(indexes);
-        Arc::new(StringArray::from(ids))
+        match self.data_type {
+            DataType::Utf8 => Arc::new(StringArray::from(ids)),
+            DataType::LargeUtf8 => Arc::new(LargeStringArray::from(ids)),
+            DataType::Utf8View => Arc::new(StringViewArray::from(ids)),
+            _ => unreachable!(),
+        }
     }
 
     unsafe fn find_or_insert(
@@ -140,15 +153,44 @@
         &mut self,
         row_idx: usize,
         replace_idx: usize,
         mapper: &mut Vec<(usize, usize)>,
     ) -> (usize, bool) {
-        let ids = self
-            .owned
-            .as_any()
-            .downcast_ref::<StringArray>()
-            .expect("StringArray required");
-        let id = if ids.is_null(row_idx) {
-            None
-        } else {
-            Some(ids.value(row_idx))
+        let id = match self.data_type {
+            DataType::Utf8 => {
+                let ids = self
+                    .owned
+                    .as_any()
+                    .downcast_ref::<StringArray>()
+                    .expect("Expected StringArray for DataType::Utf8");
+                if ids.is_null(row_idx) {
+                    None
+                } else {
+                    Some(ids.value(row_idx))
+                }
+            }
+            DataType::LargeUtf8 => {
+                let ids = self
+                    .owned
+                    .as_any()
+                    .downcast_ref::<LargeStringArray>()
+                    .expect("Expected LargeStringArray for DataType::LargeUtf8");
+                if ids.is_null(row_idx) {
+                    None
+                } else {
+                    Some(ids.value(row_idx))
+                }
+            }
+            DataType::Utf8View => {
+                let ids = self
+                    .owned
+                    .as_any()
+                    .downcast_ref::<StringViewArray>()
+                    .expect("Expected StringViewArray for DataType::Utf8View");
+                if ids.is_null(row_idx) {
+                    None
+                } else {
+                    Some(ids.value(row_idx))
+                }
+            }
+            _ => panic!("Unsupported data type"),
         };
 
         let hash = self.rnd.hash_one(id);
@@ -377,7 +419,9 @@ pub fn new_hash_table(
 
     downcast_primitive! {
         kt => (downcast_helper, kt),
-        DataType::Utf8 => return Ok(Box::new(StringHashTable::new(limit))),
+        DataType::Utf8 => return Ok(Box::new(StringHashTable::new(limit, DataType::Utf8))),
+        DataType::LargeUtf8 => return Ok(Box::new(StringHashTable::new(limit, DataType::LargeUtf8))),
+        DataType::Utf8View => return Ok(Box::new(StringHashTable::new(limit, DataType::Utf8View))),
         _ => {}
     }
 
diff --git a/datafusion/physical-plan/src/aggregates/topk/priority_map.rs b/datafusion/physical-plan/src/aggregates/topk/priority_map.rs
index 3b954c4c72d3..25cf4251888d 100644
--- a/datafusion/physical-plan/src/aggregates/topk/priority_map.rs
+++ b/datafusion/physical-plan/src/aggregates/topk/priority_map.rs
@@ -108,11 +108,67 @@ impl PriorityMap {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use arrow::array::{Int64Array, RecordBatch, StringArray};
+    use arrow::array::{
+        Int64Array, LargeStringArray, RecordBatch, StringArray, StringViewArray,
+    };
     use arrow::datatypes::{Field, Schema, SchemaRef};
     use arrow::util::pretty::pretty_format_batches;
     use std::sync::Arc;
 
+    #[test]
+    fn should_append_with_utf8view() -> Result<()> {
+        let ids: ArrayRef = Arc::new(StringViewArray::from(vec!["1"]));
+        let vals: ArrayRef = Arc::new(Int64Array::from(vec![1]));
+        let mut agg = PriorityMap::new(DataType::Utf8View, DataType::Int64, 1, false)?;
+        agg.set_batch(ids, vals);
+        agg.insert(0)?;
+
+        let cols = agg.emit()?;
+        let batch = RecordBatch::try_new(test_schema_utf8view(), cols)?;
+        let batch_schema = batch.schema();
+        assert_eq!(batch_schema.fields[0].data_type(), &DataType::Utf8View);
+
+        let actual = format!("{}", pretty_format_batches(&[batch])?);
+        let expected = r#"
++----------+--------------+
+| trace_id | timestamp_ms |
++----------+--------------+
+| 1        | 1            |
++----------+--------------+
+        "#
+        .trim();
+        assert_eq!(actual, expected);
+
+        Ok(())
+    }
+
+    #[test]
+    fn should_append_with_large_utf8() -> Result<()> {
+        let ids: ArrayRef = Arc::new(LargeStringArray::from(vec!["1"]));
+        let vals: ArrayRef = Arc::new(Int64Array::from(vec![1]));
+        let mut agg = PriorityMap::new(DataType::LargeUtf8, DataType::Int64, 1, false)?;
+        agg.set_batch(ids, vals);
+        agg.insert(0)?;
+
+        let cols = agg.emit()?;
+        let batch = RecordBatch::try_new(test_large_schema(), cols)?;
+        let batch_schema = batch.schema();
+        assert_eq!(batch_schema.fields[0].data_type(), &DataType::LargeUtf8);
+
+        let actual = format!("{}", pretty_format_batches(&[batch])?);
+        let expected = r#"
++----------+--------------+
+| trace_id | timestamp_ms |
++----------+--------------+
+| 1        | 1            |
++----------+--------------+
+        "#
+        .trim();
+        assert_eq!(actual, expected);
+
+        Ok(())
+    }
+
     #[test]
     fn should_append() -> Result<()> {
         let ids: ArrayRef = Arc::new(StringArray::from(vec!["1"]));
@@ -370,4 +426,18 @@ mod tests {
             Field::new("timestamp_ms", DataType::Int64, true),
         ]))
     }
+
+    fn test_schema_utf8view() -> SchemaRef {
+        Arc::new(Schema::new(vec![
+            Field::new("trace_id", DataType::Utf8View, true),
+            Field::new("timestamp_ms", DataType::Int64, true),
+        ]))
+    }
+
+    fn test_large_schema() -> SchemaRef {
+        Arc::new(Schema::new(vec![
+            Field::new("trace_id", DataType::LargeUtf8, true),
+            Field::new("timestamp_ms", DataType::Int64, true),
+        ]))
+    }
 }
diff --git a/datafusion/sqllogictest/test_files/aggregates_topk.slt b/datafusion/sqllogictest/test_files/aggregates_topk.slt
index 5fa0845cd2d5..cc1693843848 100644
--- a/datafusion/sqllogictest/test_files/aggregates_topk.slt
+++ b/datafusion/sqllogictest/test_files/aggregates_topk.slt
@@ -18,7 +18,6 @@
 #######
 # Setup test data table
 #######
-
 # TopK aggregation
 statement ok
 CREATE TABLE traces(trace_id varchar, timestamp bigint, other bigint) AS VALUES
@@ -214,5 +213,62 @@
 a -1 -1
 NULL 0 0
 c 1 2
+
+# Map varchar to Utf8View to test PR https://github.com/apache/datafusion/pull/15152
+# Before that PR, this case would not work because Utf8View was not supported by the TopK aggregation
+statement ok
+CREATE TABLE traces_utf8view
+AS SELECT
+    arrow_cast(trace_id, 'Utf8View') as trace_id,
+    timestamp,
+    other
+FROM traces;
+
+query TT
+explain select trace_id, MAX(timestamp) from traces_utf8view group by trace_id order by MAX(timestamp) desc limit 4;
+----
+logical_plan
+01)Sort: max(traces_utf8view.timestamp) DESC NULLS FIRST, fetch=4
+02)--Aggregate: groupBy=[[traces_utf8view.trace_id]], aggr=[[max(traces_utf8view.timestamp)]]
+03)----TableScan: traces_utf8view projection=[trace_id, timestamp]
+physical_plan
+01)SortPreservingMergeExec: [max(traces_utf8view.timestamp)@1 DESC], fetch=4
+02)--SortExec: TopK(fetch=4), expr=[max(traces_utf8view.timestamp)@1 DESC], preserve_partitioning=[true]
+03)----AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[max(traces_utf8view.timestamp)], lim=[4]
+04)------CoalesceBatchesExec: target_batch_size=8192
+05)--------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4
+06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+07)------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[max(traces_utf8view.timestamp)], lim=[4]
+08)--------------DataSourceExec: partitions=1, partition_sizes=[1]
+
+
+# Also cover LargeUtf8 to test PR https://github.com/apache/datafusion/pull/15152
+# Before that PR, this case would not work because LargeUtf8 was not supported by the TopK aggregation
+statement ok
+CREATE TABLE traces_largeutf8
+AS SELECT
+    arrow_cast(trace_id, 'LargeUtf8') as trace_id,
+    timestamp,
+    other
+FROM traces;
+
+query TT
+explain select trace_id, MAX(timestamp) from traces_largeutf8 group by trace_id order by MAX(timestamp) desc limit 4;
+----
+logical_plan
+01)Sort: max(traces_largeutf8.timestamp) DESC NULLS FIRST, fetch=4
+02)--Aggregate: groupBy=[[traces_largeutf8.trace_id]], aggr=[[max(traces_largeutf8.timestamp)]]
+03)----TableScan: traces_largeutf8 projection=[trace_id, timestamp]
+physical_plan
+01)SortPreservingMergeExec: [max(traces_largeutf8.timestamp)@1 DESC], fetch=4
+02)--SortExec: TopK(fetch=4), expr=[max(traces_largeutf8.timestamp)@1 DESC], preserve_partitioning=[true]
+03)----AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[max(traces_largeutf8.timestamp)], lim=[4]
+04)------CoalesceBatchesExec: target_batch_size=8192
+05)--------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4
+06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+07)------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[max(traces_largeutf8.timestamp)], lim=[4]
+08)--------------DataSourceExec: partitions=1, partition_sizes=[1]
+
+
 statement ok
 drop table traces;