From d5a9dad30938acbe91df08618012fe0686dd1b39 Mon Sep 17 00:00:00 2001 From: zhuqi-lucas <821684824@qq.com> Date: Tue, 11 Mar 2025 19:32:22 +0800 Subject: [PATCH 1/4] feat: topk functionality for aggregates should support utf8view --- .../src/aggregates/topk/hash_table.rs | 57 ++++++++++++++----- .../src/aggregates/topk/priority_map.rs | 36 +++++++++++- 2 files changed, 78 insertions(+), 15 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs index c818b4608de7..110f25fe166b 100644 --- a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs +++ b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs @@ -23,7 +23,7 @@ use ahash::RandomState; use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; use arrow::array::{ builder::PrimitiveBuilder, cast::AsArray, downcast_primitive, Array, ArrayRef, - ArrowPrimitiveType, PrimitiveArray, StringArray, + ArrowPrimitiveType, PrimitiveArray, StringArray, StringViewArray, }; use arrow::datatypes::{i256, DataType}; use datafusion_common::DataFusionError; @@ -88,6 +88,7 @@ pub struct StringHashTable { owned: ArrayRef, map: TopKHashTable>, rnd: RandomState, + data_type: DataType, } // An implementation of ArrowHashTable for any `ArrowPrimitiveType` key @@ -101,13 +102,19 @@ where } impl StringHashTable { - pub fn new(limit: usize) -> Self { + pub fn new(limit: usize, data_type: DataType) -> Self { let vals: Vec<&str> = Vec::new(); - let owned = Arc::new(StringArray::from(vals)); + let owned: ArrayRef = match data_type { + DataType::Utf8 => Arc::new(StringArray::from(vals)), + DataType::Utf8View => Arc::new(StringViewArray::from(vals)), + _ => panic!("Unsupported data type"), + }; + Self { owned, map: TopKHashTable::new(limit, limit * 10), rnd: RandomState::default(), + data_type, } } } @@ -131,7 +138,11 @@ impl ArrowHashTable for StringHashTable { unsafe fn take_all(&mut self, indexes: Vec) -> ArrayRef { let ids = self.map.take_all(indexes); - Arc::new(StringArray::from(ids)) + match self.data_type { + DataType::Utf8 => Arc::new(StringArray::from(ids)), + DataType::Utf8View => Arc::new(StringViewArray::from(ids)), + _ => unreachable!(), + } } unsafe fn find_or_insert( @@ -140,15 +151,32 @@ impl ArrowHashTable for StringHashTable { replace_idx: usize, mapper: &mut Vec<(usize, usize)>, ) -> (usize, bool) { - let ids = self - .owned - .as_any() - .downcast_ref::() - .expect("StringArray required"); - let id = if ids.is_null(row_idx) { - None - } else { - Some(ids.value(row_idx)) + let id = match self.data_type { + DataType::Utf8 => { + let ids = self + .owned + .as_any() + .downcast_ref::() + .expect("Expected StringArray for DataType::Utf8"); + if ids.is_null(row_idx) { + None + } else { + Some(ids.value(row_idx)) + } + } + DataType::Utf8View => { + let ids = self + .owned + .as_any() + .downcast_ref::() + .expect("Expected StringViewArray for DataType::Utf8View"); + if ids.is_null(row_idx) { + None + } else { + Some(ids.value(row_idx)) + } + } + _ => panic!("Unsupported data type"), }; let hash = self.rnd.hash_one(id); @@ -377,7 +405,8 @@ pub fn new_hash_table( downcast_primitive! { kt => (downcast_helper, kt), - DataType::Utf8 => return Ok(Box::new(StringHashTable::new(limit))), + DataType::Utf8 => return Ok(Box::new(StringHashTable::new(limit, DataType::Utf8))), + DataType::Utf8View => return Ok(Box::new(StringHashTable::new(limit, DataType::Utf8View))), _ => {} } diff --git a/datafusion/physical-plan/src/aggregates/topk/priority_map.rs b/datafusion/physical-plan/src/aggregates/topk/priority_map.rs index 3b954c4c72d3..50b2a60a735e 100644 --- a/datafusion/physical-plan/src/aggregates/topk/priority_map.rs +++ b/datafusion/physical-plan/src/aggregates/topk/priority_map.rs @@ -108,11 +108,38 @@ impl PriorityMap { #[cfg(test)] mod tests { use super::*; - use arrow::array::{Int64Array, RecordBatch, StringArray}; + use arrow::array::{Int64Array, RecordBatch, StringArray, StringViewArray}; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::util::pretty::pretty_format_batches; use std::sync::Arc; + #[test] + fn should_append_with_utf8view() -> Result<()> { + let ids: ArrayRef = Arc::new(StringViewArray::from(vec!["1"])); + let vals: ArrayRef = Arc::new(Int64Array::from(vec![1])); + let mut agg = PriorityMap::new(DataType::Utf8View, DataType::Int64, 1, false)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_schema_utf8view(), cols)?; + let batch_schema = batch.schema(); + assert_eq!(batch_schema.fields[0].data_type(), &DataType::Utf8View); + + let actual = format!("{}", pretty_format_batches(&[batch])?); + let expected = r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 1 | 1 | ++----------+--------------+ + "# + .trim(); + assert_eq!(actual, expected); + + Ok(()) + } + #[test] fn should_append() -> Result<()> { let ids: ArrayRef = Arc::new(StringArray::from(vec!["1"])); @@ -370,4 +397,11 @@ mod tests { Field::new("timestamp_ms", DataType::Int64, true), ])) } + + fn test_schema_utf8view() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("trace_id", DataType::Utf8View, true), + Field::new("timestamp_ms", DataType::Int64, true), + ])) + } } From fbe9c75b78418143147a16e2ee42ff7103c921c9 Mon Sep 17 00:00:00 2001 From: zhuqi-lucas <821684824@qq.com> Date: Wed, 12 Mar 2025 23:25:04 +0800 Subject: [PATCH 2/4] Add testing for Utf8view slt --- datafusion/physical-optimizer/src/topk_aggregation.rs | 2 +- datafusion/sqllogictest/test_files/aggregates_topk.slt | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-optimizer/src/topk_aggregation.rs b/datafusion/physical-optimizer/src/topk_aggregation.rs index 0e5fb82d9e93..751c94162b35 100644 --- a/datafusion/physical-optimizer/src/topk_aggregation.rs +++ b/datafusion/physical-optimizer/src/topk_aggregation.rs @@ -56,7 +56,7 @@ impl TopKAggregation { } let group_key = aggr.group_expr().expr().iter().exactly_one().ok()?; let kt = group_key.0.data_type(&aggr.input().schema()).ok()?; - if !kt.is_primitive() && kt != DataType::Utf8 { + if !kt.is_primitive() && kt != DataType::Utf8 && kt != DataType::Utf8View { return None; } if aggr.filter_expr().iter().any(|e| e.is_some()) { diff --git a/datafusion/sqllogictest/test_files/aggregates_topk.slt b/datafusion/sqllogictest/test_files/aggregates_topk.slt index 5fa0845cd2d5..aa8762e6149f 100644 --- a/datafusion/sqllogictest/test_files/aggregates_topk.slt +++ b/datafusion/sqllogictest/test_files/aggregates_topk.slt @@ -19,6 +19,12 @@ # Setup test data table ####### +# Setting to map varchar to utf8view, to test PR https://github.com/apache/datafusion/pull/15152 +# Before the PR, the test case would not work because the Utf8View will not be supported by the TopK aggregation + +statement ok +set datafusion.sql_parser.map_varchar_to_utf8view = true; + # TopK aggregation statement ok CREATE TABLE traces(trace_id varchar, timestamp bigint, other bigint) AS VALUES From 3ca6b1ce0485f69a65ec0655f48bea8c5739e55e Mon Sep 17 00:00:00 2001 From: zhuqi-lucas <821684824@qq.com> Date: Wed, 12 Mar 2025 23:52:21 +0800 Subject: [PATCH 3/4] Add large utf8 support --- .../src/topk_aggregation.rs | 6 ++- .../src/aggregates/topk/hash_table.rs | 17 ++++++++- .../src/aggregates/topk/priority_map.rs | 38 ++++++++++++++++++- 3 files changed, 58 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-optimizer/src/topk_aggregation.rs b/datafusion/physical-optimizer/src/topk_aggregation.rs index 751c94162b35..faedea55ca15 100644 --- a/datafusion/physical-optimizer/src/topk_aggregation.rs +++ b/datafusion/physical-optimizer/src/topk_aggregation.rs @@ -56,7 +56,11 @@ impl TopKAggregation { } let group_key = aggr.group_expr().expr().iter().exactly_one().ok()?; let kt = group_key.0.data_type(&aggr.input().schema()).ok()?; - if !kt.is_primitive() && kt != DataType::Utf8 && kt != DataType::Utf8View { + if !kt.is_primitive() + && kt != DataType::Utf8 + && kt != DataType::Utf8View + && kt != DataType::LargeUtf8 + { return None; } if aggr.filter_expr().iter().any(|e| e.is_some()) { diff --git a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs index 110f25fe166b..ae44eb35e6d0 100644 --- a/datafusion/physical-plan/src/aggregates/topk/hash_table.rs +++ b/datafusion/physical-plan/src/aggregates/topk/hash_table.rs @@ -23,7 +23,7 @@ use ahash::RandomState; use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; use arrow::array::{ builder::PrimitiveBuilder, cast::AsArray, downcast_primitive, Array, ArrayRef, - ArrowPrimitiveType, PrimitiveArray, StringArray, StringViewArray, + ArrowPrimitiveType, LargeStringArray, PrimitiveArray, StringArray, StringViewArray, }; use arrow::datatypes::{i256, DataType}; use datafusion_common::DataFusionError; @@ -107,6 +107,7 @@ impl StringHashTable { let owned: ArrayRef = match data_type { DataType::Utf8 => Arc::new(StringArray::from(vals)), DataType::Utf8View => Arc::new(StringViewArray::from(vals)), + DataType::LargeUtf8 => Arc::new(LargeStringArray::from(vals)), _ => panic!("Unsupported data type"), }; @@ -140,6 +141,7 @@ impl ArrowHashTable for StringHashTable { let ids = self.map.take_all(indexes); match self.data_type { DataType::Utf8 => Arc::new(StringArray::from(ids)), + DataType::LargeUtf8 => Arc::new(LargeStringArray::from(ids)), DataType::Utf8View => Arc::new(StringViewArray::from(ids)), _ => unreachable!(), } @@ -164,6 +166,18 @@ impl ArrowHashTable for StringHashTable { Some(ids.value(row_idx)) } } + DataType::LargeUtf8 => { + let ids = self + .owned + .as_any() + .downcast_ref::() + .expect("Expected LargeStringArray for DataType::LargeUtf8"); + if ids.is_null(row_idx) { + None + } else { + Some(ids.value(row_idx)) + } + } DataType::Utf8View => { let ids = self .owned @@ -406,6 +420,7 @@ pub fn new_hash_table( downcast_primitive! { kt => (downcast_helper, kt), DataType::Utf8 => return Ok(Box::new(StringHashTable::new(limit, DataType::Utf8))), + DataType::LargeUtf8 => return Ok(Box::new(StringHashTable::new(limit, DataType::LargeUtf8))), DataType::Utf8View => return Ok(Box::new(StringHashTable::new(limit, DataType::Utf8View))), _ => {} } diff --git a/datafusion/physical-plan/src/aggregates/topk/priority_map.rs b/datafusion/physical-plan/src/aggregates/topk/priority_map.rs index 50b2a60a735e..25cf4251888d 100644 --- a/datafusion/physical-plan/src/aggregates/topk/priority_map.rs +++ b/datafusion/physical-plan/src/aggregates/topk/priority_map.rs @@ -108,7 +108,9 @@ impl PriorityMap { #[cfg(test)] mod tests { use super::*; - use arrow::array::{Int64Array, RecordBatch, StringArray, StringViewArray}; + use arrow::array::{ + Int64Array, LargeStringArray, RecordBatch, StringArray, StringViewArray, + }; use arrow::datatypes::{Field, Schema, SchemaRef}; use arrow::util::pretty::pretty_format_batches; use std::sync::Arc; @@ -140,6 +142,33 @@ mod tests { Ok(()) } + #[test] + fn should_append_with_large_utf8() -> Result<()> { + let ids: ArrayRef = Arc::new(LargeStringArray::from(vec!["1"])); + let vals: ArrayRef = Arc::new(Int64Array::from(vec![1])); + let mut agg = PriorityMap::new(DataType::LargeUtf8, DataType::Int64, 1, false)?; + agg.set_batch(ids, vals); + agg.insert(0)?; + + let cols = agg.emit()?; + let batch = RecordBatch::try_new(test_large_schema(), cols)?; + let batch_schema = batch.schema(); + assert_eq!(batch_schema.fields[0].data_type(), &DataType::LargeUtf8); + + let actual = format!("{}", pretty_format_batches(&[batch])?); + let expected = r#" ++----------+--------------+ +| trace_id | timestamp_ms | ++----------+--------------+ +| 1 | 1 | ++----------+--------------+ + "# + .trim(); + assert_eq!(actual, expected); + + Ok(()) + } + #[test] fn should_append() -> Result<()> { let ids: ArrayRef = Arc::new(StringArray::from(vec!["1"])); @@ -404,4 +433,11 @@ mod tests { Field::new("timestamp_ms", DataType::Int64, true), ])) } + + fn test_large_schema() -> SchemaRef { + Arc::new(Schema::new(vec![ + Field::new("trace_id", DataType::LargeUtf8, true), + Field::new("timestamp_ms", DataType::Int64, true), + ])) + } } From ae172fc194f3d22b04b68b381bc5a25103ac6ea3 Mon Sep 17 00:00:00 2001 From: zhuqi-lucas <821684824@qq.com> Date: Thu, 13 Mar 2025 13:40:17 +0800 Subject: [PATCH 4/4] Address comments --- .../test_files/aggregates_topk.slt | 64 +++++++++++++++++-- 1 file changed, 57 insertions(+), 7 deletions(-) diff --git a/datafusion/sqllogictest/test_files/aggregates_topk.slt b/datafusion/sqllogictest/test_files/aggregates_topk.slt index aa8762e6149f..cc1693843848 100644 --- a/datafusion/sqllogictest/test_files/aggregates_topk.slt +++ b/datafusion/sqllogictest/test_files/aggregates_topk.slt @@ -18,13 +18,6 @@ ####### # Setup test data table ####### - -# Setting to map varchar to utf8view, to test PR https://github.com/apache/datafusion/pull/15152 -# Before the PR, the test case would not work because the Utf8View will not be supported by the TopK aggregation - -statement ok -set datafusion.sql_parser.map_varchar_to_utf8view = true; - # TopK aggregation statement ok CREATE TABLE traces(trace_id varchar, timestamp bigint, other bigint) AS VALUES @@ -220,5 +213,62 @@ a -1 -1 NULL 0 0 c 1 2 + +# Setting to map varchar to utf8view, to test PR https://github.com/apache/datafusion/pull/15152 +# Before the PR, the test case would not work because the Utf8View will not be supported by the TopK aggregation +statement ok +CREATE TABLE traces_utf8view +AS SELECT + arrow_cast(trace_id, 'Utf8View') as trace_id, + timestamp, + other +FROM traces; + +query TT +explain select trace_id, MAX(timestamp) from traces_utf8view group by trace_id order by MAX(timestamp) desc limit 4; +---- +logical_plan +01)Sort: max(traces_utf8view.timestamp) DESC NULLS FIRST, fetch=4 +02)--Aggregate: groupBy=[[traces_utf8view.trace_id]], aggr=[[max(traces_utf8view.timestamp)]] +03)----TableScan: traces_utf8view projection=[trace_id, timestamp] +physical_plan +01)SortPreservingMergeExec: [max(traces_utf8view.timestamp)@1 DESC], fetch=4 +02)--SortExec: TopK(fetch=4), expr=[max(traces_utf8view.timestamp)@1 DESC], preserve_partitioning=[true] +03)----AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[max(traces_utf8view.timestamp)], lim=[4] +04)------CoalesceBatchesExec: target_batch_size=8192 +05)--------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[max(traces_utf8view.timestamp)], lim=[4] +08)--------------DataSourceExec: partitions=1, partition_sizes=[1] + + +# Also add LargeUtf8 to test PR https://github.com/apache/datafusion/pull/15152 +# Before the PR, the test case would not work because the LargeUtf8 will not be supported by the TopK aggregation +statement ok +CREATE TABLE traces_largeutf8 +AS SELECT + arrow_cast(trace_id, 'LargeUtf8') as trace_id, + timestamp, + other +FROM traces; + +query TT +explain select trace_id, MAX(timestamp) from traces_largeutf8 group by trace_id order by MAX(timestamp) desc limit 4; +---- +logical_plan +01)Sort: max(traces_largeutf8.timestamp) DESC NULLS FIRST, fetch=4 +02)--Aggregate: groupBy=[[traces_largeutf8.trace_id]], aggr=[[max(traces_largeutf8.timestamp)]] +03)----TableScan: traces_largeutf8 projection=[trace_id, timestamp] +physical_plan +01)SortPreservingMergeExec: [max(traces_largeutf8.timestamp)@1 DESC], fetch=4 +02)--SortExec: TopK(fetch=4), expr=[max(traces_largeutf8.timestamp)@1 DESC], preserve_partitioning=[true] +03)----AggregateExec: mode=FinalPartitioned, gby=[trace_id@0 as trace_id], aggr=[max(traces_largeutf8.timestamp)], lim=[4] +04)------CoalesceBatchesExec: target_batch_size=8192 +05)--------RepartitionExec: partitioning=Hash([trace_id@0], 4), input_partitions=4 +06)----------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 +07)------------AggregateExec: mode=Partial, gby=[trace_id@0 as trace_id], aggr=[max(traces_largeutf8.timestamp)], lim=[4] +08)--------------DataSourceExec: partitions=1, partition_sizes=[1] + + statement ok drop table traces;