diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs index 82bd0d8443f9..031dd72d4627 100644 --- a/datafusion/core/tests/dataframe.rs +++ b/datafusion/core/tests/dataframe.rs @@ -16,6 +16,7 @@ // under the License. use arrow::datatypes::{DataType, Field, Schema}; +use arrow::util::pretty::pretty_format_batches; use arrow::{ array::{ ArrayRef, Int32Array, Int32Builder, ListBuilder, StringArray, StringBuilder, @@ -35,6 +36,58 @@ use datafusion::{assert_batches_eq, assert_batches_sorted_eq}; use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::{avg, col, count, lit, max, sum, Expr, ExprSchemable}; +#[tokio::test] +async fn count_wildcard() -> Result<()> { + let ctx = SessionContext::new(); + let testdata = datafusion::test_util::parquet_test_data(); + + ctx.register_parquet( + "alltypes_tiny_pages", + &format!("{testdata}/alltypes_tiny_pages.parquet"), + ParquetReadOptions::default(), + ) + .await?; + + let sql_results = ctx + .sql("select count(*) from alltypes_tiny_pages") + .await? + .explain(false, false)? + .collect() + .await?; + + let df_results = ctx + .table("alltypes_tiny_pages") + .await? + .aggregate(vec![], vec![count(Expr::Wildcard)])? + .explain(false, false) + .unwrap() + .collect() + .await?; + + //make sure sql plan same with df plan + assert_eq!( + pretty_format_batches(&sql_results)?.to_string(), + pretty_format_batches(&df_results)?.to_string() + ); + + let results = ctx + .table("alltypes_tiny_pages") + .await? + .aggregate(vec![], vec![count(Expr::Wildcard)])? + .collect() + .await?; + + let expected = vec![ + "+-----------------+", + "| COUNT(UInt8(1)) |", + "+-----------------+", + "| 7300 |", + "+-----------------+", + ]; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} #[tokio::test] async fn describe() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 5ceaf1668ada..32e831b43056 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -136,9 +136,10 @@ impl ExprSchemable for Expr { Expr::Placeholder { data_type, .. } => data_type.clone().ok_or_else(|| { DataFusionError::Plan("Placeholder type could not be resolved".to_owned()) }), - Expr::Wildcard => Err(DataFusionError::Internal( - "Wildcard expressions are not valid in a logical query plan".to_owned(), - )), + Expr::Wildcard => { + // Wildcard do not really have a type and do not appear in projections + Ok(DataType::Null) + } Expr::QualifiedWildcard { .. } => Err(DataFusionError::Internal( "QualifiedWildcard expressions are not valid in a logical query plan" .to_owned(), diff --git a/datafusion/optimizer/src/analyzer.rs b/datafusion/optimizer/src/analyzer.rs index f2a1ba9d64bb..e999eb2419d0 100644 --- a/datafusion/optimizer/src/analyzer.rs +++ b/datafusion/optimizer/src/analyzer.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::count_wildcard_rule::CountWildcardRule; use crate::rewrite::TreeNodeRewritable; use datafusion_common::config::ConfigOptions; use datafusion_common::{DataFusionError, Result}; @@ -49,7 +50,8 @@ impl Default for Analyzer { impl Analyzer { /// Create a new analyzer using the recommended list of rules pub fn new() -> Self { - let rules = vec![]; + let rules: Vec> = + vec![Arc::new(CountWildcardRule::new())]; Self::with_rules(rules) } diff --git a/datafusion/optimizer/src/count_wildcard_rule.rs b/datafusion/optimizer/src/count_wildcard_rule.rs new file mode 100644 index 000000000000..416bd0337a4d --- /dev/null +++ b/datafusion/optimizer/src/count_wildcard_rule.rs @@ -0,0 +1,101 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::analyzer::AnalyzerRule; +use datafusion_common::config::ConfigOptions; +use datafusion_common::Result; +use datafusion_expr::expr::AggregateFunction; +use datafusion_expr::utils::COUNT_STAR_EXPANSION; +use datafusion_expr::{aggregate_function, lit, Aggregate, Expr, LogicalPlan, Window}; +use std::ops::Deref; +use std::sync::Arc; + +pub struct CountWildcardRule {} + +impl Default for CountWildcardRule { + fn default() -> Self { + CountWildcardRule::new() + } +} + +impl CountWildcardRule { + pub fn new() -> Self { + CountWildcardRule {} + } +} +impl AnalyzerRule for CountWildcardRule { + fn analyze(&self, plan: &LogicalPlan, _: &ConfigOptions) -> Result { + let new_plan = match plan { + LogicalPlan::Window(window) => { + let inputs = plan.inputs(); + let window_expr = window.clone().window_expr; + let window_expr = handle_wildcard(window_expr).unwrap(); + LogicalPlan::Window(Window { + input: Arc::new(inputs.get(0).unwrap().deref().clone()), + window_expr, + schema: plan.schema().clone(), + }) + } + + LogicalPlan::Aggregate(aggregate) => { + let inputs = plan.inputs(); + let aggr_expr = aggregate.clone().aggr_expr; + let aggr_expr = handle_wildcard(aggr_expr).unwrap(); + LogicalPlan::Aggregate( + Aggregate::try_new_with_schema( + Arc::new(inputs.get(0).unwrap().deref().clone()), + aggregate.clone().group_expr, + aggr_expr, + plan.schema().clone(), + ) + .unwrap(), + ) + } + _ => plan.clone(), + }; + Ok(new_plan) + } + + fn name(&self) -> &str { + "count_wildcard_rule" + } +} + +//handle Count(Expr:Wildcard) with DataFrame API +pub fn handle_wildcard(exprs: Vec) -> Result> { + let exprs: Vec = exprs + .iter() + .map(|expr| match expr { + Expr::AggregateFunction(AggregateFunction { + fun: aggregate_function::AggregateFunction::Count, + args, + distinct, + filter, + }) if args.len() == 1 => match args[0] { + Expr::Wildcard => Expr::AggregateFunction(AggregateFunction { + fun: aggregate_function::AggregateFunction::Count, + args: vec![lit(COUNT_STAR_EXPANSION)], + distinct: *distinct, + filter: filter.clone(), + }), + _ => expr.clone(), + }, + _ => expr.clone(), + }) + .collect(); + Ok(exprs) +} diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 7f930ae3a8d0..3fa1995271dc 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -45,6 +45,7 @@ pub mod type_coercion; pub mod unwrap_cast_in_comparison; pub mod utils; +pub mod count_wildcard_rule; #[cfg(test)] pub mod test;