From f1fefeb445a11e690de7fa85ff13411f3a66a3f3 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Fri, 28 Feb 2025 13:33:40 +0800 Subject: [PATCH 01/10] fix alias --- .../tests/dataframe/dataframe_functions.rs | 6 +- datafusion/core/tests/dataframe/mod.rs | 111 +++++++++++++++--- datafusion/functions-aggregate/src/count.rs | 9 +- .../sqllogictest/test_files/subquery.slt | 34 ++++++ 4 files changed, 139 insertions(+), 21 deletions(-) diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index 28c0740ca76b..fec3ab786fce 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -1145,9 +1145,9 @@ async fn test_count_wildcard() -> Result<()> { .build() .unwrap(); - let expected = "Sort: count(Int64(1)) ASC NULLS LAST [count(Int64(1)):Int64]\ - \n Projection: count(Int64(1)) [count(Int64(1)):Int64]\ - \n Aggregate: groupBy=[[test.b]], aggr=[[count(Int64(1))]] [b:UInt32, count(Int64(1)):Int64]\ + let expected = "Sort: count(*) ASC NULLS LAST [count(*):Int64]\ + \n Projection: count(*) [count(*):Int64]\ + \n Aggregate: groupBy=[[test.b]], aggr=[[count(Int64(1)) AS count(*)]] [b:UInt32, count(*):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; let formatted_plan = plan.display_indent_schema().to_string(); diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index b134ec54b13d..65cf95765dc2 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -33,7 +33,7 @@ use arrow::datatypes::{ use arrow::error::ArrowError; use arrow::util::pretty::pretty_format_batches; use datafusion_expr::utils::COUNT_STAR_EXPANSION; -use datafusion_functions_aggregate::count::{count_all, count_udaf}; +use datafusion_functions_aggregate::count::{count_all, count_all_column, count_udaf}; use datafusion_functions_aggregate::expr_fn::{ array_agg, avg, count, count_distinct, max, median, min, sum, }; @@ -2455,7 +2455,7 @@ async fn test_count_wildcard_on_sort() -> Result<()> { let ctx = create_join_context()?; let sql_results = ctx - .sql("select b,count(1) from t1 group by b order by count(1)") + .sql("select b, count(*) from t1 group by b order by count(*)") .await? .explain(false, false)? .collect() @@ -2469,9 +2469,52 @@ async fn test_count_wildcard_on_sort() -> Result<()> { .explain(false, false)? 
.collect() .await?; - //make sure sql plan same with df plan + + let expected_sql_result = "+---------------+------------------------------------------------------------------------------------------------------------+\ + \n| plan_type | plan |\ + \n+---------------+------------------------------------------------------------------------------------------------------------+\ + \n| logical_plan | Projection: t1.b, count(*) |\ + \n| | Sort: count(Int64(1)) AS count(*) AS count(*) ASC NULLS LAST |\ + \n| | Projection: t1.b, count(Int64(1)) AS count(*), count(Int64(1)) |\ + \n| | Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1))]] |\ + \n| | TableScan: t1 projection=[b] |\ + \n| physical_plan | ProjectionExec: expr=[b@0 as b, count(*)@1 as count(*)] |\ + \n| | SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST] |\ + \n| | SortExec: expr=[count(Int64(1))@2 ASC NULLS LAST], preserve_partitioning=[true] |\ + \n| | ProjectionExec: expr=[b@0 as b, count(Int64(1))@1 as count(*), count(Int64(1))@1 as count(Int64(1))] |\ + \n| | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(Int64(1))] |\ + \n| | CoalesceBatchesExec: target_batch_size=8192 |\ + \n| | RepartitionExec: partitioning=Hash([b@0], 12), input_partitions=12 |\ + \n| | RepartitionExec: partitioning=RoundRobinBatch(12), input_partitions=1 |\ + \n| | AggregateExec: mode=Partial, gby=[b@0 as b], aggr=[count(Int64(1))] |\ + \n| | DataSourceExec: partitions=1, partition_sizes=[1] |\ + \n| | |\ + \n+---------------+------------------------------------------------------------------------------------------------------------+"; + assert_eq!( - pretty_format_batches(&sql_results)?.to_string(), + expected_sql_result, + pretty_format_batches(&sql_results)?.to_string() + ); + + let expected_df_result = "+---------------+---------------------------------------------------------------------------------+\ + \n| plan_type | plan |\ + \n+---------------+---------------------------------------------------------------------------------+\ + \n| logical_plan | Sort: count(*) ASC NULLS LAST |\ + \n| | Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1)) AS count(*)]] |\ + \n| | TableScan: t1 projection=[b] |\ + \n| physical_plan | SortPreservingMergeExec: [count(*)@1 ASC NULLS LAST] |\ + \n| | SortExec: expr=[count(*)@1 ASC NULLS LAST], preserve_partitioning=[true] |\ + \n| | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(*)] |\ + \n| | CoalesceBatchesExec: target_batch_size=8192 |\ + \n| | RepartitionExec: partitioning=Hash([b@0], 12), input_partitions=12 |\ + \n| | RepartitionExec: partitioning=RoundRobinBatch(12), input_partitions=1 |\ + \n| | AggregateExec: mode=Partial, gby=[b@0 as b], aggr=[count(*)] |\ + \n| | DataSourceExec: partitions=1, partition_sizes=[1] |\ + \n| | |\ + \n+---------------+---------------------------------------------------------------------------------+"; + + assert_eq!( + expected_df_result, pretty_format_batches(&df_results)?.to_string() ); Ok(()) @@ -2481,12 +2524,19 @@ async fn test_count_wildcard_on_sort() -> Result<()> { async fn test_count_wildcard_on_where_in() -> Result<()> { let ctx = create_join_context()?; let sql_results = ctx - .sql("SELECT a,b FROM t1 WHERE a in (SELECT count(1) FROM t2)") + .sql("SELECT a, b FROM t1 WHERE a in (SELECT count(*) FROM t2)") .await? .explain(false, false)? 
.collect() .await?; + let expected_sql_result = "+---------------+------------------------------------------------------------------------------------------------------------------------+\n| plan_type | plan |\n+---------------+------------------------------------------------------------------------------------------------------------------------+\n| logical_plan | LeftSemi Join: CAST(t1.a AS Int64) = __correlated_sq_1.count(*) |\n| | TableScan: t1 projection=[a, b] |\n| | SubqueryAlias: __correlated_sq_1 |\n| | Projection: count(Int64(1)) AS count(*) |\n| | Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] |\n| | TableScan: t2 projection=[] |\n| physical_plan | CoalesceBatchesExec: target_batch_size=8192 |\n| | HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(count(*)@0, CAST(t1.a AS Int64)@2)], projection=[a@0, b@1] |\n| | ProjectionExec: expr=[4 as count(*)] |\n| | PlaceholderRowExec |\n| | ProjectionExec: expr=[a@0 as a, b@1 as b, CAST(a@0 AS Int64) as CAST(t1.a AS Int64)] |\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\n| | |\n+---------------+------------------------------------------------------------------------------------------------------------------------+"; + + assert_eq!( + expected_sql_result, + pretty_format_batches(&sql_results)?.to_string() + ); + // In the same SessionContext, AliasGenerator will increase subquery_alias id by 1 // https://github.com/apache/datafusion/blame/cf45eb9020092943b96653d70fafb143cc362e19/datafusion/optimizer/src/alias.rs#L40-L43 // for compare difference between sql and df logical plan, we need to create a new SessionContext here @@ -2509,9 +2559,11 @@ async fn test_count_wildcard_on_where_in() -> Result<()> { .collect() .await?; + let actual_df_result= "+---------------+------------------------------------------------------------------------------------------------------------------------+\n| plan_type | plan |\n+---------------+------------------------------------------------------------------------------------------------------------------------+\n| logical_plan | LeftSemi Join: CAST(t1.a AS Int64) = __correlated_sq_1.count(*) |\n| | TableScan: t1 projection=[a, b] |\n| | SubqueryAlias: __correlated_sq_1 |\n| | Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] |\n| | TableScan: t2 projection=[] |\n| physical_plan | CoalesceBatchesExec: target_batch_size=8192 |\n| | HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(count(*)@0, CAST(t1.a AS Int64)@2)], projection=[a@0, b@1] |\n| | ProjectionExec: expr=[4 as count(*)] |\n| | PlaceholderRowExec |\n| | ProjectionExec: expr=[a@0 as a, b@1 as b, CAST(a@0 AS Int64) as CAST(t1.a AS Int64)] |\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\n| | |\n+---------------+------------------------------------------------------------------------------------------------------------------------+"; + // make sure sql plan same with df plan assert_eq!( - pretty_format_batches(&sql_results)?.to_string(), + actual_df_result, pretty_format_batches(&df_results)?.to_string() ); @@ -2522,11 +2574,19 @@ async fn test_count_wildcard_on_where_in() -> Result<()> { async fn test_count_wildcard_on_where_exist() -> Result<()> { let ctx = create_join_context()?; let sql_results = ctx - .sql("SELECT a, b FROM t1 WHERE EXISTS (SELECT count(1) FROM t2)") + .sql("SELECT a, b FROM t1 WHERE EXISTS (SELECT count(*) FROM t2)") .await? .explain(false, false)? 
.collect() .await?; + + let actual_sql_result = "+---------------+---------------------------------------------------------+\n| plan_type | plan |\n+---------------+---------------------------------------------------------+\n| logical_plan | LeftSemi Join: |\n| | TableScan: t1 projection=[a, b] |\n| | SubqueryAlias: __correlated_sq_1 |\n| | Projection: |\n| | Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] |\n| | TableScan: t2 projection=[] |\n| physical_plan | NestedLoopJoinExec: join_type=RightSemi |\n| | ProjectionExec: expr=[] |\n| | PlaceholderRowExec |\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\n| | |\n+---------------+---------------------------------------------------------+"; + + assert_eq!( + actual_sql_result, + pretty_format_batches(&sql_results)?.to_string() + ); + let df_results = ctx .table("t1") .await? @@ -2545,9 +2605,10 @@ async fn test_count_wildcard_on_where_exist() -> Result<()> { .collect() .await?; - //make sure sql plan same with df plan + let actual_df_result = "+---------------+---------------------------------------------------------------------+\n| plan_type | plan |\n+---------------+---------------------------------------------------------------------+\n| logical_plan | LeftSemi Join: |\n| | TableScan: t1 projection=[a, b] |\n| | SubqueryAlias: __correlated_sq_1 |\n| | Projection: |\n| | Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] |\n| | TableScan: t2 projection=[] |\n| physical_plan | NestedLoopJoinExec: join_type=RightSemi |\n| | ProjectionExec: expr=[] |\n| | PlaceholderRowExec |\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\n| | |\n+---------------+---------------------------------------------------------------------+"; + assert_eq!( - pretty_format_batches(&sql_results)?.to_string(), + actual_df_result, pretty_format_batches(&df_results)?.to_string() ); @@ -2598,12 +2659,18 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> { register_alltypes_tiny_pages_parquet(&ctx).await?; let sql_results = ctx - .sql("select count(1) from t1") + .sql("select count(*) from t1") .await? .explain(false, false)? .collect() .await?; + let actual_sql_result = "+---------------+-----------------------------------------------------+\n| plan_type | plan |\n+---------------+-----------------------------------------------------+\n| logical_plan | Projection: count(Int64(1)) AS count(*) |\n| | Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] |\n| | TableScan: t1 projection=[] |\n| physical_plan | ProjectionExec: expr=[4 as count(*)] |\n| | PlaceholderRowExec |\n| | |\n+---------------+-----------------------------------------------------+"; + assert_eq!( + actual_sql_result, + pretty_format_batches(&sql_results)?.to_string() + ); + // add `.select(vec![count_wildcard()])?` to make sure we can analyze all node instead of just top node. 
let df_results = ctx
 .table("t1")
 .await?
@@ -2614,26 +2681,38 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> {
 .collect()
 .await?;

- //make sure sql plan same with df plan
+ let actual_df_result = "+---------------+---------------------------------------------------------------+\n| plan_type | plan |\n+---------------+---------------------------------------------------------------+\n| logical_plan | Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] |\n| | TableScan: t1 projection=[] |\n| physical_plan | ProjectionExec: expr=[4 as count(*)] |\n| | PlaceholderRowExec |\n| | |\n+---------------+---------------------------------------------------------------+";
 assert_eq!(
- pretty_format_batches(&sql_results)?.to_string(),
+ actual_df_result,
 pretty_format_batches(&df_results)?.to_string()
 );

 Ok(())
 }

+#[tokio::test]
+async fn test_count_wildcard_schema_name() {
+ assert_eq!(count_all().schema_name().to_string(), "count(*)");
+ assert_eq!(count_all_column(), col("count(*)"));
+}
+
 #[tokio::test]
 async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
 let ctx = create_join_context()?;

 let sql_results = ctx
- .sql("select a,b from t1 where (select count(1) from t2 where t1.a = t2.a)>0;")
+ .sql("select a,b from t1 where (select count(*) from t2 where t1.a = t2.a)>0;")
 .await?
 .explain(false, false)?
 .collect()
 .await?;

+ let actual_sql_result = "+---------------+---------------------------------------------------------------------------------------------------------------------------+\n| plan_type | plan |\n+---------------+---------------------------------------------------------------------------------------------------------------------------+\n| logical_plan | Projection: t1.a, t1.b |\n| | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0) |\n| | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true |\n| | Left Join: t1.a = __scalar_sq_1.a |\n| | TableScan: t1 projection=[a, b] |\n| | SubqueryAlias: __scalar_sq_1 |\n| | Projection: count(Int64(1)) AS count(*), t2.a, Boolean(true) AS __always_true |\n| | Aggregate: groupBy=[[t2.a]], aggr=[[count(Int64(1))]] |\n| | TableScan: t2 projection=[a] |\n| physical_plan | CoalesceBatchesExec: target_batch_size=8192 |\n| | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] |\n| | CoalesceBatchesExec: target_batch_size=8192 |\n| | HashJoinExec: mode=Partitioned, join_type=Left, on=[(a@0, a@1)], projection=[a@0, b@1, count(*)@2, __always_true@4] |\n| | CoalesceBatchesExec: target_batch_size=8192 |\n| | RepartitionExec: partitioning=Hash([a@0], 12), input_partitions=1 |\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\n| | ProjectionExec: expr=[count(Int64(1))@1 as count(*), a@0 as a, true as __always_true] |\n| | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(Int64(1))] |\n| | CoalesceBatchesExec: target_batch_size=8192 |\n| | RepartitionExec: partitioning=Hash([a@0], 12), input_partitions=12 |\n| | RepartitionExec: partitioning=RoundRobinBatch(12), input_partitions=1 |\n| | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(Int64(1))] |\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\n| | |\n+---------------+---------------------------------------------------------------------------------------------------------------------------+";
+ assert_eq!(
+ actual_sql_result,
+ pretty_format_batches(&sql_results)?.to_string()
+ );
+
 // In the same
SessionContext, AliasGenerator will increase subquery_alias id by 1 // https://github.com/apache/datafusion/blame/cf45eb9020092943b96653d70fafb143cc362e19/datafusion/optimizer/src/alias.rs#L40-L43 // for compare difference between sql and df logical plan, we need to create a new SessionContext here @@ -2647,7 +2726,7 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { .await? .filter(out_ref_col(DataType::UInt32, "t1.a").eq(col("t2.a")))? .aggregate(vec![], vec![count_all()])? - .select(vec![col(count_all().to_string())])? + .select(vec![count_all_column()])? .into_unoptimized_plan(), )) .gt(lit(ScalarValue::UInt8(Some(0)))), @@ -2657,9 +2736,9 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { .collect() .await?; - //make sure sql plan same with df plan + let actual_df_result = "+---------------+---------------------------------------------------------------------------------------------------------------------------+\n| plan_type | plan |\n+---------------+---------------------------------------------------------------------------------------------------------------------------+\n| logical_plan | Projection: t1.a, t1.b |\n| | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0) |\n| | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true |\n| | Left Join: t1.a = __scalar_sq_1.a |\n| | TableScan: t1 projection=[a, b] |\n| | SubqueryAlias: __scalar_sq_1 |\n| | Projection: count(*), t2.a, Boolean(true) AS __always_true |\n| | Aggregate: groupBy=[[t2.a]], aggr=[[count(Int64(1)) AS count(*)]] |\n| | TableScan: t2 projection=[a] |\n| physical_plan | CoalesceBatchesExec: target_batch_size=8192 |\n| | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] |\n| | CoalesceBatchesExec: target_batch_size=8192 |\n| | HashJoinExec: mode=Partitioned, join_type=Left, on=[(a@0, a@1)], projection=[a@0, b@1, count(*)@2, __always_true@4] |\n| | CoalesceBatchesExec: target_batch_size=8192 |\n| | RepartitionExec: partitioning=Hash([a@0], 12), input_partitions=1 |\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\n| | ProjectionExec: expr=[count(*)@1 as count(*), a@0 as a, true as __always_true] |\n| | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(*)] |\n| | CoalesceBatchesExec: target_batch_size=8192 |\n| | RepartitionExec: partitioning=Hash([a@0], 12), input_partitions=12 |\n| | RepartitionExec: partitioning=RoundRobinBatch(12), input_partitions=1 |\n| | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(*)] |\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\n| | |\n+---------------+---------------------------------------------------------------------------------------------------------------------------+"; assert_eq!( - pretty_format_batches(&sql_results)?.to_string(), + actual_df_result, pretty_format_batches(&df_results)?.to_string() ); diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index a3339f0fceb9..17426b9323d0 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -47,11 +47,11 @@ use datafusion_common::{ downcast_value, internal_err, not_impl_err, Result, ScalarValue, }; use datafusion_expr::function::StateFieldsArgs; +use datafusion_expr::{col, Expr, ReversedUDAF, StatisticsArgs, TypeSignature}; use datafusion_expr::{ function::AccumulatorArgs, utils::format_state_name, 
Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator, SetMonotonicity, Signature, Volatility, }; -use datafusion_expr::{Expr, ReversedUDAF, StatisticsArgs, TypeSignature}; use datafusion_functions_aggregate_common::aggregate::count_distinct::{ BytesDistinctCountAccumulator, FloatDistinctCountAccumulator, PrimitiveDistinctCountAccumulator, @@ -81,7 +81,12 @@ pub fn count_distinct(expr: Expr) -> Expr { /// Creates aggregation to count all rows, equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)` pub fn count_all() -> Expr { - count(Expr::Literal(COUNT_STAR_EXPANSION)) + count(Expr::Literal(COUNT_STAR_EXPANSION)).alias("count(*)") +} + +/// Create count wildcard of Expr::Column +pub fn count_all_column() -> Expr { + col(count_all().schema_name().to_string()) } #[user_doc( diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 94c9eaf810fb..030e215d1cb3 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -1393,3 +1393,37 @@ item1 1970-01-01T00:00:03 75 statement ok drop table source_table; + +# test count wildcard +statement count 0 +create table t1(a int) as values (1); + +statement count 0 +create table t2(b int) as values (1); + +query I +SELECT a FROM t1 WHERE EXISTS (SELECT count(*) FROM t2) +---- +1 + +query TT +explain SELECT a FROM t1 WHERE EXISTS (SELECT count(*) FROM t2) +---- +logical_plan +01)LeftSemi Join: +02)--TableScan: t1 projection=[a] +03)--SubqueryAlias: __correlated_sq_1 +04)----Projection: +05)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +06)--------TableScan: t2 projection=[] +physical_plan +01)NestedLoopJoinExec: join_type=RightSemi +02)--ProjectionExec: expr=[] +03)----PlaceholderRowExec +04)--DataSourceExec: partitions=1, partition_sizes=[1] + +query +drop table t1; + +query +drop table t2; \ No newline at end of file From de527e133136de4948b3854e9d7843dd6858568d Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Fri, 28 Feb 2025 13:43:27 +0800 Subject: [PATCH 02/10] append the string --- datafusion/core/tests/dataframe/mod.rs | 152 +++++++++++++++++++++++-- 1 file changed, 144 insertions(+), 8 deletions(-) diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 65cf95765dc2..bb3907280235 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -2470,6 +2470,7 @@ async fn test_count_wildcard_on_sort() -> Result<()> { .collect() .await?; + // TODO: Remove duplicated alias in Sort let expected_sql_result = "+---------------+------------------------------------------------------------------------------------------------------------+\ \n| plan_type | plan |\ \n+---------------+------------------------------------------------------------------------------------------------------------+\ @@ -2530,7 +2531,23 @@ async fn test_count_wildcard_on_where_in() -> Result<()> { .collect() .await?; - let expected_sql_result = "+---------------+------------------------------------------------------------------------------------------------------------------------+\n| plan_type | plan |\n+---------------+------------------------------------------------------------------------------------------------------------------------+\n| logical_plan | LeftSemi Join: CAST(t1.a AS Int64) = __correlated_sq_1.count(*) |\n| | TableScan: t1 projection=[a, b] |\n| | SubqueryAlias: __correlated_sq_1 |\n| | Projection: count(Int64(1)) AS count(*) |\n| | Aggregate: 
groupBy=[[]], aggr=[[count(Int64(1))]] |\n| | TableScan: t2 projection=[] |\n| physical_plan | CoalesceBatchesExec: target_batch_size=8192 |\n| | HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(count(*)@0, CAST(t1.a AS Int64)@2)], projection=[a@0, b@1] |\n| | ProjectionExec: expr=[4 as count(*)] |\n| | PlaceholderRowExec |\n| | ProjectionExec: expr=[a@0 as a, b@1 as b, CAST(a@0 AS Int64) as CAST(t1.a AS Int64)] |\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\n| | |\n+---------------+------------------------------------------------------------------------------------------------------------------------+"; + let expected_sql_result = "+---------------+------------------------------------------------------------------------------------------------------------------------+\ +\n| plan_type | plan |\ +\n+---------------+------------------------------------------------------------------------------------------------------------------------+\ +\n| logical_plan | LeftSemi Join: CAST(t1.a AS Int64) = __correlated_sq_1.count(*) |\ +\n| | TableScan: t1 projection=[a, b] |\ +\n| | SubqueryAlias: __correlated_sq_1 |\ +\n| | Projection: count(Int64(1)) AS count(*) |\ +\n| | Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] |\ +\n| | TableScan: t2 projection=[] |\ +\n| physical_plan | CoalesceBatchesExec: target_batch_size=8192 |\ +\n| | HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(count(*)@0, CAST(t1.a AS Int64)@2)], projection=[a@0, b@1] |\ +\n| | ProjectionExec: expr=[4 as count(*)] |\ +\n| | PlaceholderRowExec |\ +\n| | ProjectionExec: expr=[a@0 as a, b@1 as b, CAST(a@0 AS Int64) as CAST(t1.a AS Int64)] |\ +\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\ +\n| | |\ +\n+---------------+------------------------------------------------------------------------------------------------------------------------+"; assert_eq!( expected_sql_result, @@ -2559,7 +2576,22 @@ async fn test_count_wildcard_on_where_in() -> Result<()> { .collect() .await?; - let actual_df_result= "+---------------+------------------------------------------------------------------------------------------------------------------------+\n| plan_type | plan |\n+---------------+------------------------------------------------------------------------------------------------------------------------+\n| logical_plan | LeftSemi Join: CAST(t1.a AS Int64) = __correlated_sq_1.count(*) |\n| | TableScan: t1 projection=[a, b] |\n| | SubqueryAlias: __correlated_sq_1 |\n| | Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] |\n| | TableScan: t2 projection=[] |\n| physical_plan | CoalesceBatchesExec: target_batch_size=8192 |\n| | HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(count(*)@0, CAST(t1.a AS Int64)@2)], projection=[a@0, b@1] |\n| | ProjectionExec: expr=[4 as count(*)] |\n| | PlaceholderRowExec |\n| | ProjectionExec: expr=[a@0 as a, b@1 as b, CAST(a@0 AS Int64) as CAST(t1.a AS Int64)] |\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\n| | |\n+---------------+------------------------------------------------------------------------------------------------------------------------+"; + let actual_df_result= "+---------------+------------------------------------------------------------------------------------------------------------------------+\ +\n| plan_type | plan |\ +\n+---------------+------------------------------------------------------------------------------------------------------------------------+\ +\n| logical_plan | LeftSemi Join: CAST(t1.a AS Int64) = 
__correlated_sq_1.count(*) |\ +\n| | TableScan: t1 projection=[a, b] |\ +\n| | SubqueryAlias: __correlated_sq_1 |\ +\n| | Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] |\ +\n| | TableScan: t2 projection=[] |\ +\n| physical_plan | CoalesceBatchesExec: target_batch_size=8192 |\ +\n| | HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(count(*)@0, CAST(t1.a AS Int64)@2)], projection=[a@0, b@1] |\ +\n| | ProjectionExec: expr=[4 as count(*)] |\ +\n| | PlaceholderRowExec |\ +\n| | ProjectionExec: expr=[a@0 as a, b@1 as b, CAST(a@0 AS Int64) as CAST(t1.a AS Int64)] |\ +\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\ +\n| | |\ +\n+---------------+------------------------------------------------------------------------------------------------------------------------+"; // make sure sql plan same with df plan assert_eq!( @@ -2580,7 +2612,25 @@ async fn test_count_wildcard_on_where_exist() -> Result<()> { .collect() .await?; - let actual_sql_result = "+---------------+---------------------------------------------------------+\n| plan_type | plan |\n+---------------+---------------------------------------------------------+\n| logical_plan | LeftSemi Join: |\n| | TableScan: t1 projection=[a, b] |\n| | SubqueryAlias: __correlated_sq_1 |\n| | Projection: |\n| | Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] |\n| | TableScan: t2 projection=[] |\n| physical_plan | NestedLoopJoinExec: join_type=RightSemi |\n| | ProjectionExec: expr=[] |\n| | PlaceholderRowExec |\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\n| | |\n+---------------+---------------------------------------------------------+"; + // TODO: + // 1) remove empty Projection + // 2) why count(*) alias is not shown + let actual_sql_result = + "+---------------+---------------------------------------------------------+\ + \n| plan_type | plan |\ + \n+---------------+---------------------------------------------------------+\ + \n| logical_plan | LeftSemi Join: |\ + \n| | TableScan: t1 projection=[a, b] |\ + \n| | SubqueryAlias: __correlated_sq_1 |\ + \n| | Projection: |\ + \n| | Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] |\ + \n| | TableScan: t2 projection=[] |\ + \n| physical_plan | NestedLoopJoinExec: join_type=RightSemi |\ + \n| | ProjectionExec: expr=[] |\ + \n| | PlaceholderRowExec |\ + \n| | DataSourceExec: partitions=1, partition_sizes=[1] |\ + \n| | |\ + \n+---------------+---------------------------------------------------------+"; assert_eq!( actual_sql_result, @@ -2605,7 +2655,21 @@ async fn test_count_wildcard_on_where_exist() -> Result<()> { .collect() .await?; - let actual_df_result = "+---------------+---------------------------------------------------------------------+\n| plan_type | plan |\n+---------------+---------------------------------------------------------------------+\n| logical_plan | LeftSemi Join: |\n| | TableScan: t1 projection=[a, b] |\n| | SubqueryAlias: __correlated_sq_1 |\n| | Projection: |\n| | Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] |\n| | TableScan: t2 projection=[] |\n| physical_plan | NestedLoopJoinExec: join_type=RightSemi |\n| | ProjectionExec: expr=[] |\n| | PlaceholderRowExec |\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\n| | |\n+---------------+---------------------------------------------------------------------+"; + let actual_df_result = "+---------------+---------------------------------------------------------------------+\ + \n| plan_type | plan |\ + 
\n+---------------+---------------------------------------------------------------------+\ + \n| logical_plan | LeftSemi Join: |\ + \n| | TableScan: t1 projection=[a, b] |\ + \n| | SubqueryAlias: __correlated_sq_1 |\ + \n| | Projection: |\ + \n| | Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] |\ + \n| | TableScan: t2 projection=[] |\ + \n| physical_plan | NestedLoopJoinExec: join_type=RightSemi |\ + \n| | ProjectionExec: expr=[] |\ + \n| | PlaceholderRowExec |\ + \n| | DataSourceExec: partitions=1, partition_sizes=[1] |\ + \n| | |\ + \n+---------------+---------------------------------------------------------------------+"; assert_eq!( actual_df_result, @@ -2665,7 +2729,17 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> { .collect() .await?; - let actual_sql_result = "+---------------+-----------------------------------------------------+\n| plan_type | plan |\n+---------------+-----------------------------------------------------+\n| logical_plan | Projection: count(Int64(1)) AS count(*) |\n| | Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] |\n| | TableScan: t1 projection=[] |\n| physical_plan | ProjectionExec: expr=[4 as count(*)] |\n| | PlaceholderRowExec |\n| | |\n+---------------+-----------------------------------------------------+"; + let actual_sql_result = + "+---------------+-----------------------------------------------------+\ +\n| plan_type | plan |\ +\n+---------------+-----------------------------------------------------+\ +\n| logical_plan | Projection: count(Int64(1)) AS count(*) |\ +\n| | Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] |\ +\n| | TableScan: t1 projection=[] |\ +\n| physical_plan | ProjectionExec: expr=[4 as count(*)] |\ +\n| | PlaceholderRowExec |\ +\n| | |\ +\n+---------------+-----------------------------------------------------+"; assert_eq!( actual_sql_result, pretty_format_batches(&sql_results)?.to_string() @@ -2681,7 +2755,15 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> { .collect() .await?; - let actual_df_result = "+---------------+---------------------------------------------------------------+\n| plan_type | plan |\n+---------------+---------------------------------------------------------------+\n| logical_plan | Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] |\n| | TableScan: t1 projection=[] |\n| physical_plan | ProjectionExec: expr=[4 as count(*)] |\n| | PlaceholderRowExec |\n| | |\n+---------------+---------------------------------------------------------------+"; + let actual_df_result = "+---------------+---------------------------------------------------------------+\ +\n| plan_type | plan |\ +\n+---------------+---------------------------------------------------------------+\ +\n| logical_plan | Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] |\ +\n| | TableScan: t1 projection=[] |\ +\n| physical_plan | ProjectionExec: expr=[4 as count(*)] |\ +\n| | PlaceholderRowExec |\ +\n| | |\ +\n+---------------+---------------------------------------------------------------+"; assert_eq!( actual_df_result, pretty_format_batches(&df_results)?.to_string() @@ -2707,7 +2789,34 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { .collect() .await?; - let actual_sql_result = "+---------------+---------------------------------------------------------------------------------------------------------------------------+\n| plan_type | plan 
|\n+---------------+---------------------------------------------------------------------------------------------------------------------------+\n| logical_plan | Projection: t1.a, t1.b |\n| | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0) |\n| | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true |\n| | Left Join: t1.a = __scalar_sq_1.a |\n| | TableScan: t1 projection=[a, b] |\n| | SubqueryAlias: __scalar_sq_1 |\n| | Projection: count(Int64(1)) AS count(*), t2.a, Boolean(true) AS __always_true |\n| | Aggregate: groupBy=[[t2.a]], aggr=[[count(Int64(1))]] |\n| | TableScan: t2 projection=[a] |\n| physical_plan | CoalesceBatchesExec: target_batch_size=8192 |\n| | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] |\n| | CoalesceBatchesExec: target_batch_size=8192 |\n| | HashJoinExec: mode=Partitioned, join_type=Left, on=[(a@0, a@1)], projection=[a@0, b@1, count(*)@2, __always_true@4] |\n| | CoalesceBatchesExec: target_batch_size=8192 |\n| | RepartitionExec: partitioning=Hash([a@0], 12), input_partitions=1 |\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\n| | ProjectionExec: expr=[count(Int64(1))@1 as count(*), a@0 as a, true as __always_true] |\n| | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(Int64(1))] |\n| | CoalesceBatchesExec: target_batch_size=8192 |\n| | RepartitionExec: partitioning=Hash([a@0], 12), input_partitions=12 |\n| | RepartitionExec: partitioning=RoundRobinBatch(12), input_partitions=1 |\n| | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(Int64(1))] |\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\n| | |\n+---------------+---------------------------------------------------------------------------------------------------------------------------+"; + let actual_sql_result = "+---------------+---------------------------------------------------------------------------------------------------------------------------+\ +\n| plan_type | plan |\ +\n+---------------+---------------------------------------------------------------------------------------------------------------------------+\ +\n| logical_plan | Projection: t1.a, t1.b |\ +\n| | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0) |\ +\n| | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true |\ +\n| | Left Join: t1.a = __scalar_sq_1.a |\ +\n| | TableScan: t1 projection=[a, b] |\ +\n| | SubqueryAlias: __scalar_sq_1 |\ +\n| | Projection: count(Int64(1)) AS count(*), t2.a, Boolean(true) AS __always_true |\ +\n| | Aggregate: groupBy=[[t2.a]], aggr=[[count(Int64(1))]] |\ +\n| | TableScan: t2 projection=[a] |\ +\n| physical_plan | CoalesceBatchesExec: target_batch_size=8192 |\ +\n| | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] |\ +\n| | CoalesceBatchesExec: target_batch_size=8192 |\ +\n| | HashJoinExec: mode=Partitioned, join_type=Left, on=[(a@0, a@1)], projection=[a@0, b@1, count(*)@2, __always_true@4] |\ +\n| | CoalesceBatchesExec: target_batch_size=8192 |\ +\n| | RepartitionExec: partitioning=Hash([a@0], 12), input_partitions=1 |\ +\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\ +\n| | ProjectionExec: expr=[count(Int64(1))@1 as count(*), a@0 as a, true as __always_true] |\ +\n| | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(Int64(1))] |\ +\n| | CoalesceBatchesExec: 
target_batch_size=8192 |\ +\n| | RepartitionExec: partitioning=Hash([a@0], 12), input_partitions=12 |\ +\n| | RepartitionExec: partitioning=RoundRobinBatch(12), input_partitions=1 |\ +\n| | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(Int64(1))] |\ +\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\ +\n| | |\ +\n+---------------+---------------------------------------------------------------------------------------------------------------------------+"; assert_eq!( actual_sql_result, pretty_format_batches(&sql_results)?.to_string() @@ -2736,7 +2845,34 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { .collect() .await?; - let actual_df_result = "+---------------+---------------------------------------------------------------------------------------------------------------------------+\n| plan_type | plan |\n+---------------+---------------------------------------------------------------------------------------------------------------------------+\n| logical_plan | Projection: t1.a, t1.b |\n| | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0) |\n| | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true |\n| | Left Join: t1.a = __scalar_sq_1.a |\n| | TableScan: t1 projection=[a, b] |\n| | SubqueryAlias: __scalar_sq_1 |\n| | Projection: count(*), t2.a, Boolean(true) AS __always_true |\n| | Aggregate: groupBy=[[t2.a]], aggr=[[count(Int64(1)) AS count(*)]] |\n| | TableScan: t2 projection=[a] |\n| physical_plan | CoalesceBatchesExec: target_batch_size=8192 |\n| | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] |\n| | CoalesceBatchesExec: target_batch_size=8192 |\n| | HashJoinExec: mode=Partitioned, join_type=Left, on=[(a@0, a@1)], projection=[a@0, b@1, count(*)@2, __always_true@4] |\n| | CoalesceBatchesExec: target_batch_size=8192 |\n| | RepartitionExec: partitioning=Hash([a@0], 12), input_partitions=1 |\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\n| | ProjectionExec: expr=[count(*)@1 as count(*), a@0 as a, true as __always_true] |\n| | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(*)] |\n| | CoalesceBatchesExec: target_batch_size=8192 |\n| | RepartitionExec: partitioning=Hash([a@0], 12), input_partitions=12 |\n| | RepartitionExec: partitioning=RoundRobinBatch(12), input_partitions=1 |\n| | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(*)] |\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\n| | |\n+---------------+---------------------------------------------------------------------------------------------------------------------------+"; + let actual_df_result = "+---------------+---------------------------------------------------------------------------------------------------------------------------+\ +\n| plan_type | plan |\ +\n+---------------+---------------------------------------------------------------------------------------------------------------------------+\ +\n| logical_plan | Projection: t1.a, t1.b |\ +\n| | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0) |\ +\n| | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true |\ +\n| | Left Join: t1.a = __scalar_sq_1.a |\ +\n| | TableScan: t1 projection=[a, b] |\ +\n| | SubqueryAlias: __scalar_sq_1 |\ +\n| | Projection: count(*), t2.a, Boolean(true) AS __always_true |\ +\n| | Aggregate: groupBy=[[t2.a]], 
aggr=[[count(Int64(1)) AS count(*)]] |\ +\n| | TableScan: t2 projection=[a] |\ +\n| physical_plan | CoalesceBatchesExec: target_batch_size=8192 |\ +\n| | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] |\ +\n| | CoalesceBatchesExec: target_batch_size=8192 |\ +\n| | HashJoinExec: mode=Partitioned, join_type=Left, on=[(a@0, a@1)], projection=[a@0, b@1, count(*)@2, __always_true@4] |\ +\n| | CoalesceBatchesExec: target_batch_size=8192 |\ +\n| | RepartitionExec: partitioning=Hash([a@0], 12), input_partitions=1 |\ +\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\ +\n| | ProjectionExec: expr=[count(*)@1 as count(*), a@0 as a, true as __always_true] |\ +\n| | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(*)] |\ +\n| | CoalesceBatchesExec: target_batch_size=8192 |\ +\n| | RepartitionExec: partitioning=Hash([a@0], 12), input_partitions=12 |\ +\n| | RepartitionExec: partitioning=RoundRobinBatch(12), input_partitions=1 |\ +\n| | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(*)] |\ +\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\ +\n| | |\ +\n+---------------+---------------------------------------------------------------------------------------------------------------------------+"; assert_eq!( actual_df_result, pretty_format_batches(&df_results)?.to_string() From af1a6d300c32a87d6f8489959ed005ace6a254a0 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Fri, 28 Feb 2025 13:56:37 +0800 Subject: [PATCH 03/10] window count --- datafusion/core/tests/dataframe/mod.rs | 47 ++++++++++++++++----- datafusion/functions-aggregate/src/count.rs | 11 ++++- 2 files changed, 47 insertions(+), 11 deletions(-) diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index bb3907280235..80acb48c37f7 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -32,8 +32,7 @@ use arrow::datatypes::{ }; use arrow::error::ArrowError; use arrow::util::pretty::pretty_format_batches; -use datafusion_expr::utils::COUNT_STAR_EXPANSION; -use datafusion_functions_aggregate::count::{count_all, count_all_column, count_udaf}; +use datafusion_functions_aggregate::count::{count_all, count_all_column, count_all_window}; use datafusion_functions_aggregate::expr_fn::{ array_agg, avg, count, count_distinct, max, median, min, sum, }; @@ -2684,18 +2683,34 @@ async fn test_count_wildcard_on_window() -> Result<()> { let ctx = create_join_context()?; let sql_results = ctx - .sql("select count(1) OVER(ORDER BY a DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from t1") + .sql("select count(*) OVER(ORDER BY a DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from t1") .await? .explain(false, false)? 
.collect() .await?; + + let actual_sql_result = "+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\ +\n| plan_type | plan |\ +\n+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\ +\n| logical_plan | Projection: count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS count(*) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING |\ +\n| | WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] |\ +\n| | TableScan: t1 projection=[a] |\ +\n| physical_plan | ProjectionExec: expr=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@1 as count(*) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING] |\ +\n| | BoundedWindowAggExec: wdw=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING: Ok(Field { name: \"count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(6)), end_bound: Following(UInt32(2)), is_causal: false }], mode=[Sorted] |\ +\n| | SortExec: expr=[a@0 DESC], preserve_partitioning=[false] |\ +\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\ +\n| | |\ +\n+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+"; + + assert_eq!( + actual_sql_result, + pretty_format_batches(&sql_results)?.to_string() + ); + let df_results = ctx .table("t1") .await? 
- .select(vec![Expr::WindowFunction(WindowFunction::new( - WindowFunctionDefinition::AggregateUDF(count_udaf()), - vec![Expr::Literal(COUNT_STAR_EXPANSION)], - )) + .select(vec![count_all_window() .order_by(vec![Sort::new(col("a"), false, true)]) .window_frame(WindowFrame::new_bounds( WindowFrameUnits::Range, @@ -2708,10 +2723,22 @@ async fn test_count_wildcard_on_window() -> Result<()> { .collect() .await?; - //make sure sql plan same with df plan + let actual_df_result = "+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\ +\n| plan_type | plan |\ +\n+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\ +\n| logical_plan | Projection: count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING |\ +\n| | WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] |\ +\n| | TableScan: t1 projection=[a] |\ +\n| physical_plan | ProjectionExec: expr=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@1 as count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING] |\ +\n| | BoundedWindowAggExec: wdw=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING: Ok(Field { name: \"count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(6)), end_bound: Following(UInt32(2)), is_causal: false }], mode=[Sorted] |\ +\n| | SortExec: expr=[a@0 DESC], preserve_partitioning=[false] |\ +\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\ +\n| | |\ +\n+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+"; + assert_eq!( - pretty_format_batches(&df_results)?.to_string(), - pretty_format_batches(&sql_results)?.to_string() + actual_df_result, + pretty_format_batches(&df_results)?.to_string() ); Ok(()) diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 17426b9323d0..bf12406c5600 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -17,6 +17,7 @@ use 
ahash::RandomState;
 use datafusion_common::stats::Precision;
+use datafusion_expr::expr::WindowFunction;
 use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
 use datafusion_macros::user_doc;
 use datafusion_physical_expr::expressions;
@@ -47,7 +48,7 @@ use datafusion_common::{
 downcast_value, internal_err, not_impl_err, Result, ScalarValue,
 };
 use datafusion_expr::function::StateFieldsArgs;
-use datafusion_expr::{col, Expr, ReversedUDAF, StatisticsArgs, TypeSignature};
+use datafusion_expr::{col, Expr, ReversedUDAF, StatisticsArgs, TypeSignature, WindowFunctionDefinition};
 use datafusion_expr::{
 function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
 Documentation, EmitTo, GroupsAccumulator, SetMonotonicity, Signature, Volatility,
@@ -84,6 +85,14 @@ pub fn count_all() -> Expr {
 count(Expr::Literal(COUNT_STAR_EXPANSION)).alias("count(*)")
 }

+/// Creates window aggregation to count all rows, equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`
+pub fn count_all_window() -> Expr {
+ Expr::WindowFunction(WindowFunction::new(
+ WindowFunctionDefinition::AggregateUDF(count_udaf()),
+ vec![Expr::Literal(COUNT_STAR_EXPANSION)],
+ ))
+}
+
 /// Create count wildcard of Expr::Column
 pub fn count_all_column() -> Expr {
 col(count_all().schema_name().to_string())

From dff35903d2efb347980ac790c3838a5649c8d8d5 Mon Sep 17 00:00:00 2001
From: Jay Zhan
Date: Fri, 28 Feb 2025 13:59:41 +0800
Subject: [PATCH 04/10] add column

---
 datafusion/core/tests/dataframe/mod.rs | 4 +++-
 datafusion/functions-aggregate/src/count.rs | 9 +++++++++
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs
index 80acb48c37f7..28c330775fc3 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -32,7 +32,7 @@ use arrow::datatypes::{
 };
 use arrow::error::ArrowError;
 use arrow::util::pretty::pretty_format_batches;
-use datafusion_functions_aggregate::count::{count_all, count_all_column, count_all_window};
+use datafusion_functions_aggregate::count::{count_all, count_all_column, count_all_window, count_all_window_column};
 use datafusion_functions_aggregate::expr_fn::{
 array_agg, avg, count, count_distinct, max, median, min, sum,
 };
@@ -2803,6 +2803,8 @@ async fn test_count_wildcard_schema_name() {
 assert_eq!(count_all().schema_name().to_string(), "count(*)");
 assert_eq!(count_all_column(), col("count(*)"));
+ assert_eq!(count_all_window_column(), col("count(Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING"));
+
 }

 #[tokio::test]
diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs
index bf12406c5600..f79ce421cbe2 100644
--- a/datafusion/functions-aggregate/src/count.rs
+++ b/datafusion/functions-aggregate/src/count.rs
@@ -81,6 +81,7 @@ pub fn count_distinct(expr: Expr) -> Expr {
 }

 /// Creates aggregation to count all rows, equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`
+/// Alias to count(*) for backward compatibility
 pub fn count_all() -> Expr {
 count(Expr::Literal(COUNT_STAR_EXPANSION)).alias("count(*)")
 }
@@ -93,6 +94,14 @@ pub fn count_all_window() -> Expr {
 ))
 }

+/// Create count wildcard window func of Expr::Column
+pub fn count_all_window_column() -> Expr {
+ col(Expr::WindowFunction(WindowFunction::new(
+ WindowFunctionDefinition::AggregateUDF(count_udaf()),
+ 
vec![Expr::Literal(COUNT_STAR_EXPANSION)], + )).schema_name().to_string()) +} + /// Create count wildcard of Expr::Column pub fn count_all_column() -> Expr { col(count_all().schema_name().to_string()) From 1a07b82872875adfedf32607c2413b227fe93bae Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Fri, 28 Feb 2025 14:04:28 +0800 Subject: [PATCH 05/10] fmt --- datafusion/core/tests/dataframe/mod.rs | 26 ++++++++++++--------- datafusion/functions-aggregate/src/count.rs | 8 +++++-- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 28c330775fc3..894877a98ea7 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -32,7 +32,9 @@ use arrow::datatypes::{ }; use arrow::error::ArrowError; use arrow::util::pretty::pretty_format_batches; -use datafusion_functions_aggregate::count::{count_all, count_all_column, count_all_window, count_all_window_column}; +use datafusion_functions_aggregate::count::{ + count_all, count_all_column, count_all_window, count_all_window_column, +}; use datafusion_functions_aggregate::expr_fn::{ array_agg, avg, count, count_distinct, max, median, min, sum, }; @@ -2711,14 +2713,14 @@ async fn test_count_wildcard_on_window() -> Result<()> { .table("t1") .await? .select(vec![count_all_window() - .order_by(vec![Sort::new(col("a"), false, true)]) - .window_frame(WindowFrame::new_bounds( - WindowFrameUnits::Range, - WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), - WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), - )) - .build() - .unwrap()])? + .order_by(vec![Sort::new(col("a"), false, true)]) + .window_frame(WindowFrame::new_bounds( + WindowFrameUnits::Range, + WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), + WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), + )) + .build() + .unwrap()])? .explain(false, false)? 
.collect()
 .await?;
@@ -2803,8 +2805,10 @@ async fn test_count_wildcard_schema_name() {
 assert_eq!(count_all().schema_name().to_string(), "count(*)");
 assert_eq!(count_all_column(), col("count(*)"));
- assert_eq!(count_all_window_column(), col("count(Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING"));
-
+ assert_eq!(
+ count_all_window_column(),
+ col("count(Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING")
+ );
 }
diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs
index f79ce421cbe2..bd66dbb42771 100644
--- a/datafusion/functions-aggregate/src/count.rs
+++ b/datafusion/functions-aggregate/src/count.rs
@@ -48,7 +48,9 @@ use datafusion_common::{
 downcast_value, internal_err, not_impl_err, Result, ScalarValue,
 };
 use datafusion_expr::function::StateFieldsArgs;
-use datafusion_expr::{col, Expr, ReversedUDAF, StatisticsArgs, TypeSignature, WindowFunctionDefinition};
+use datafusion_expr::{
+ col, Expr, ReversedUDAF, StatisticsArgs, TypeSignature, WindowFunctionDefinition,
+};
 use datafusion_expr::{
 function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
 Documentation, EmitTo, GroupsAccumulator, SetMonotonicity, Signature, Volatility,
@@ -99,7 +101,9 @@ pub fn count_all_window_column() -> Expr {
 col(Expr::WindowFunction(WindowFunction::new(
 WindowFunctionDefinition::AggregateUDF(count_udaf()),
 vec![Expr::Literal(COUNT_STAR_EXPANSION)],
- )).schema_name().to_string())
+ ))
+ .schema_name()
+ .to_string())
 }

 /// Create count wildcard of Expr::Column

From 52fa4b0e9c6998ee5e6724dd90198b4f1c84ebd6 Mon Sep 17 00:00:00 2001
From: Jay Zhan
Date: Fri, 28 Feb 2025 14:06:17 +0800
Subject: [PATCH 06/10] rm todo

---
 datafusion/core/tests/dataframe/mod.rs | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs
index 894877a98ea7..f76bb6ee6e55 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -2471,7 +2471,6 @@ async fn test_count_wildcard_on_sort() -> Result<()> {
 .collect()
 .await?;

- // TODO: Remove duplicated alias in Sort
 let expected_sql_result = "+---------------+------------------------------------------------------------------------------------------------------------+\
 \n| plan_type | plan |\
 \n+---------------+------------------------------------------------------------------------------------------------------------+\
@@ -2613,9 +2612,6 @@ async fn test_count_wildcard_on_where_exist() -> Result<()> {
 .collect()
 .await?;

- // TODO:
- // 1) remove empty Projection
- // 2) why count(*) alias is not shown
 let actual_sql_result =
 "+---------------+---------------------------------------------------------+\
 \n| plan_type | plan |\
 \n+---------------+---------------------------------------------------------+\

From a945a6fdbb7b54e0168505b2d67a6346d31a30af Mon Sep 17 00:00:00 2001
From: Jay Zhan
Date: Sat, 1 Mar 2025 09:26:32 +0800
Subject: [PATCH 07/10] fixed partitioned

---
 datafusion/core/tests/dataframe/mod.rs | 52 +++++++++++++-------------
 1 file changed, 27 insertions(+), 25 deletions(-)

diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs
index f76bb6ee6e55..3010144224d2 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -2485,8 +2485,8 @@ async fn test_count_wildcard_on_sort() -> Result<()> {
 \n| | ProjectionExec: expr=[b@0 as b, count(Int64(1))@1 as count(*), 
count(Int64(1))@1 as count(Int64(1))] |\ \n| | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(Int64(1))] |\ \n| | CoalesceBatchesExec: target_batch_size=8192 |\ - \n| | RepartitionExec: partitioning=Hash([b@0], 12), input_partitions=12 |\ - \n| | RepartitionExec: partitioning=RoundRobinBatch(12), input_partitions=1 |\ + \n| | RepartitionExec: partitioning=Hash([b@0], 4), input_partitions=4 |\ + \n| | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |\ \n| | AggregateExec: mode=Partial, gby=[b@0 as b], aggr=[count(Int64(1))] |\ \n| | DataSourceExec: partitions=1, partition_sizes=[1] |\ \n| | |\ @@ -2497,22 +2497,22 @@ async fn test_count_wildcard_on_sort() -> Result<()> { pretty_format_batches(&sql_results)?.to_string() ); - let expected_df_result = "+---------------+---------------------------------------------------------------------------------+\ - \n| plan_type | plan |\ - \n+---------------+---------------------------------------------------------------------------------+\ - \n| logical_plan | Sort: count(*) ASC NULLS LAST |\ - \n| | Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1)) AS count(*)]] |\ - \n| | TableScan: t1 projection=[b] |\ - \n| physical_plan | SortPreservingMergeExec: [count(*)@1 ASC NULLS LAST] |\ - \n| | SortExec: expr=[count(*)@1 ASC NULLS LAST], preserve_partitioning=[true] |\ - \n| | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(*)] |\ - \n| | CoalesceBatchesExec: target_batch_size=8192 |\ - \n| | RepartitionExec: partitioning=Hash([b@0], 12), input_partitions=12 |\ - \n| | RepartitionExec: partitioning=RoundRobinBatch(12), input_partitions=1 |\ - \n| | AggregateExec: mode=Partial, gby=[b@0 as b], aggr=[count(*)] |\ - \n| | DataSourceExec: partitions=1, partition_sizes=[1] |\ - \n| | |\ - \n+---------------+---------------------------------------------------------------------------------+"; + let expected_df_result = "+---------------+--------------------------------------------------------------------------------+\ +\n| plan_type | plan |\ +\n+---------------+--------------------------------------------------------------------------------+\ +\n| logical_plan | Sort: count(*) ASC NULLS LAST |\ +\n| | Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1)) AS count(*)]] |\ +\n| | TableScan: t1 projection=[b] |\ +\n| physical_plan | SortPreservingMergeExec: [count(*)@1 ASC NULLS LAST] |\ +\n| | SortExec: expr=[count(*)@1 ASC NULLS LAST], preserve_partitioning=[true] |\ +\n| | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(*)] |\ +\n| | CoalesceBatchesExec: target_batch_size=8192 |\ +\n| | RepartitionExec: partitioning=Hash([b@0], 4), input_partitions=4 |\ +\n| | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |\ +\n| | AggregateExec: mode=Partial, gby=[b@0 as b], aggr=[count(*)] |\ +\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\ +\n| | |\ +\n+---------------+--------------------------------------------------------------------------------+"; assert_eq!( expected_df_result, @@ -2835,13 +2835,13 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { \n| | CoalesceBatchesExec: target_batch_size=8192 |\ \n| | HashJoinExec: mode=Partitioned, join_type=Left, on=[(a@0, a@1)], projection=[a@0, b@1, count(*)@2, __always_true@4] |\ \n| | CoalesceBatchesExec: target_batch_size=8192 |\ -\n| | RepartitionExec: partitioning=Hash([a@0], 12), input_partitions=1 |\ +\n| | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 |\ \n| | 
DataSourceExec: partitions=1, partition_sizes=[1] |\
\n| | ProjectionExec: expr=[count(Int64(1))@1 as count(*), a@0 as a, true as __always_true] |\
\n| | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(Int64(1))] |\
\n| | CoalesceBatchesExec: target_batch_size=8192 |\
\n| | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 |\
\n| | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |\
\n| | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(Int64(1))] |\
\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\
\n| | |\
@@ -2891,13 +2891,13 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
\n| | CoalesceBatchesExec: target_batch_size=8192 |\
\n| | HashJoinExec: mode=Partitioned, join_type=Left, on=[(a@0, a@1)], projection=[a@0, b@1, count(*)@2, __always_true@4] |\
\n| | CoalesceBatchesExec: target_batch_size=8192 |\
-\n| | RepartitionExec: partitioning=Hash([a@0], 12), input_partitions=1 |\
+\n| | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 |\
\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\
\n| | ProjectionExec: expr=[count(*)@1 as count(*), a@0 as a, true as __always_true] |\
\n| | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(*)] |\
\n| | CoalesceBatchesExec: target_batch_size=8192 |\
-\n| | RepartitionExec: partitioning=Hash([a@0], 12), input_partitions=12 |\
-\n| | RepartitionExec: partitioning=RoundRobinBatch(12), input_partitions=1 |\
+\n| | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 |\
+\n| | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |\
\n| | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(*)] |\
\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\
\n| | |\
@@ -4472,7 +4472,8 @@ fn create_join_context() -> Result<SessionContext> {
         ],
     )?;
 
-    let ctx = SessionContext::new();
+    let config = SessionConfig::new().with_target_partitions(4);
+    let ctx = SessionContext::new_with_config(config);
 
     ctx.register_batch("t1", batch1)?;
     ctx.register_batch("t2", batch2)?;

From e150a317ee0c30b94f46bee45296d981753a363b Mon Sep 17 00:00:00 2001
From: Jay Zhan
Date: Sat, 1 Mar 2025 10:26:15 +0800
Subject: [PATCH 08/10] fix test

---
 .../sqllogictest/test_files/subquery.slt | 20 +++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt
index 030e215d1cb3..207bb72fd549 100644
--- a/datafusion/sqllogictest/test_files/subquery.slt
+++ b/datafusion/sqllogictest/test_files/subquery.slt
@@ -1394,6 +1394,15 @@ item1 1970-01-01T00:00:03 75
 statement ok
 drop table source_table;
 
+statement count 0
+drop table t1;
+
+statement count 0
+drop table t2;
+
+statement count 0
+drop table t3;
+
 # test count wildcard
 statement count 0
 create table t1(a int) as values (1);
@@ -1416,14 +1425,9 @@
 logical_plan
 04)----Projection:
 05)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
 06)--------TableScan: t2 projection=[]
-physical_plan
-01)NestedLoopJoinExec: join_type=RightSemi
-02)--ProjectionExec: expr=[]
-03)----PlaceholderRowExec
-04)--DataSourceExec: partitions=1, partition_sizes=[1]
 
-query
+statement count 0
 drop table t1;
 
-query
-drop table t2; \ No newline at end of file
+statement count 0
+drop table t2;

From 
d19fafa73af61e2c3703a7099a0ae0bfdc54c164 Mon Sep 17 00:00:00 2001
From: Jay Zhan
Date: Tue, 4 Mar 2025 08:01:43 +0800
Subject: [PATCH 09/10] update doc

---
 datafusion/functions-aggregate/src/count.rs | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs
index bd66dbb42771..5afe2b0584eb 100644
--- a/datafusion/functions-aggregate/src/count.rs
+++ b/datafusion/functions-aggregate/src/count.rs
@@ -96,7 +96,8 @@ pub fn count_all_window() -> Expr {
     ))
 }
 
-/// Create count wildcard window func of Expr::Column
+/// Expr::Column(Count Wildcard Window Function)
+/// Can be used in the DataFrame API where you need an `Expr::Column` of the count wildcard
 pub fn count_all_window_column() -> Expr {
     col(Expr::WindowFunction(WindowFunction::new(
         WindowFunctionDefinition::AggregateUDF(count_udaf()),
@@ -106,7 +107,8 @@ pub fn count_all_window_column() -> Expr {
     .to_string())
 }
 
-/// Create count wildcard of Expr::Column
+/// Expr::Column(Count Wildcard Aggregate Function)
+/// Can be used in the DataFrame API where you need an `Expr::Column` of the count wildcard
 pub fn count_all_column() -> Expr {
     col(count_all().schema_name().to_string())
 }

From 595750d4124f4140a35e6e42d345034253ca7901 Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Tue, 4 Mar 2025 16:50:31 -0500
Subject: [PATCH 10/10] Suggestion to reduce API surface area

---
 datafusion/core/tests/dataframe/mod.rs | 20 ++-----
 datafusion/functions-aggregate/src/count.rs | 62 +++++++++++++--------
 2 files changed, 44 insertions(+), 38 deletions(-)

diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs
index 3010144224d2..107e09bd81a9 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -32,9 +32,7 @@ use arrow::datatypes::{
 };
 use arrow::error::ArrowError;
 use arrow::util::pretty::pretty_format_batches;
-use datafusion_functions_aggregate::count::{
-    count_all, count_all_column, count_all_window, count_all_window_column,
-};
+use datafusion_functions_aggregate::count::{count_all, count_all_window};
 use datafusion_functions_aggregate::expr_fn::{
     array_agg, avg, count, count_distinct, max, median, min, sum,
 };
@@ -2797,16 +2795,6 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> {
     Ok(())
 }
 
-#[tokio::test]
-async fn test_count_wildcard_shema_name() {
-    assert_eq!(count_all().schema_name().to_string(), "count(*)");
-    assert_eq!(count_all_column(), col("count(*)"));
-    assert_eq!(
-        count_all_window_column(),
-        col("count(Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING")
-    );
-}
-
 #[tokio::test]
 async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
     let ctx = create_join_context()?;
@@ -2855,6 +2843,8 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
     // https://github.com/apache/datafusion/blame/cf45eb9020092943b96653d70fafb143cc362e19/datafusion/optimizer/src/alias.rs#L40-L43
     // for compare difference between sql and df logical plan, we need to create a new SessionContext here
     let ctx = create_join_context()?;
+    let agg_expr = count_all();
+    let agg_expr_col = col(agg_expr.schema_name().to_string());
     let df_results = ctx
         .table("t1")
         .await?
@@ -2863,8 +2853,8 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
                 ctx.table("t2")
                     .await?
                     .filter(out_ref_col(DataType::UInt32, "t1.a").eq(col("t2.a")))?
-                    .aggregate(vec![], vec![count_all()])?
-                    .select(vec![count_all_column()])?
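// `agg_expr_col` resolves to `col("count(*)")`: the aggregate's output column is
// looked up by its schema name, so the `select` below can re-reference the count
// produced by the `aggregate` call.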
+ .aggregate(vec![], vec![agg_expr])? + .select(vec![agg_expr_col])? .into_unoptimized_plan(), )) .gt(lit(ScalarValue::UInt8(Some(0)))), diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 5afe2b0584eb..2d995b4a4179 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -48,13 +48,13 @@ use datafusion_common::{ downcast_value, internal_err, not_impl_err, Result, ScalarValue, }; use datafusion_expr::function::StateFieldsArgs; -use datafusion_expr::{ - col, Expr, ReversedUDAF, StatisticsArgs, TypeSignature, WindowFunctionDefinition, -}; use datafusion_expr::{ function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator, SetMonotonicity, Signature, Volatility, }; +use datafusion_expr::{ + Expr, ReversedUDAF, StatisticsArgs, TypeSignature, WindowFunctionDefinition, +}; use datafusion_functions_aggregate_common::aggregate::count_distinct::{ BytesDistinctCountAccumulator, FloatDistinctCountAccumulator, PrimitiveDistinctCountAccumulator, @@ -82,13 +82,46 @@ pub fn count_distinct(expr: Expr) -> Expr { )) } -/// Creates aggregation to count all rows, equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)` -/// Alias to count(*) for backward comaptibility +/// Creates aggregation to count all rows. +/// +/// In SQL this is `SELECT COUNT(*) ... ` +/// +/// The expression is equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`, and is +/// aliased to a column named `"count(*)"` for backward compatibility. +/// +/// Example +/// ``` +/// # use datafusion_functions_aggregate::count::count_all; +/// # use datafusion_expr::col; +/// // create `count(*)` expression +/// let expr = count_all(); +/// assert_eq!(expr.schema_name().to_string(), "count(*)"); +/// // if you need to refer to this column, use the `schema_name` function +/// let expr = col(expr.schema_name().to_string()); +/// ``` pub fn count_all() -> Expr { count(Expr::Literal(COUNT_STAR_EXPANSION)).alias("count(*)") } -/// Creates window aggregation to count all rows, equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)` +/// Creates window aggregation to count all rows. +/// +/// In SQL this is `SELECT COUNT(*) OVER (..) ... ` +/// +/// The expression is equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)` +/// +/// Example +/// ``` +/// # use datafusion_functions_aggregate::count::count_all_window; +/// # use datafusion_expr::col; +/// // create `count(*)` OVER ... 
window function expression
+/// let expr = count_all_window();
+/// assert_eq!(
+///   expr.schema_name().to_string(),
+///   "count(Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING"
+/// );
+/// // if you need to refer to this column, use the `schema_name` function
+/// let expr = col(expr.schema_name().to_string());
+/// ```
 pub fn count_all_window() -> Expr {
     Expr::WindowFunction(WindowFunction::new(
@@ -96,23 +129,6 @@ pub fn count_all_window() -> Expr {
     ))
 }
 
-/// Expr::Column(Count Wildcard Window Function)
-/// Can be used in the DataFrame API where you need an `Expr::Column` of the count wildcard
-pub fn count_all_window_column() -> Expr {
-    col(Expr::WindowFunction(WindowFunction::new(
-        WindowFunctionDefinition::AggregateUDF(count_udaf()),
-        vec![Expr::Literal(COUNT_STAR_EXPANSION)],
-    ))
-    .schema_name()
-    .to_string())
-}
-
-/// Expr::Column(Count Wildcard Aggregate Function)
-/// Can be used in the DataFrame API where you need an `Expr::Column` of the count wildcard
-pub fn count_all_column() -> Expr {
-    col(count_all().schema_name().to_string())
-}
-
 #[user_doc(
     doc_section(label = "General Functions"),
     description = "Returns the number of non-null values in the specified column. To include null values in the total count, use `count(*)`.",