Skip to content

Commit f4c0edb

Browse files
authored
fix erro on Count(Expr:Wildcard) with DataFrame API (#5627)
1 parent a73d4fc commit f4c0edb

File tree

5 files changed

+162
-4
lines changed

5 files changed

+162
-4
lines changed

datafusion/core/tests/dataframe.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// under the License.
1717

1818
use arrow::datatypes::{DataType, Field, Schema};
19+
use arrow::util::pretty::pretty_format_batches;
1920
use arrow::{
2021
array::{
2122
ArrayRef, Int32Array, Int32Builder, ListBuilder, StringArray, StringBuilder,
@@ -35,6 +36,58 @@ use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
3536
use datafusion_expr::expr::{GroupingSet, Sort};
3637
use datafusion_expr::{avg, col, count, lit, max, sum, Expr, ExprSchemable};
3738

39+
#[tokio::test]
40+
async fn count_wildcard() -> Result<()> {
41+
let ctx = SessionContext::new();
42+
let testdata = datafusion::test_util::parquet_test_data();
43+
44+
ctx.register_parquet(
45+
"alltypes_tiny_pages",
46+
&format!("{testdata}/alltypes_tiny_pages.parquet"),
47+
ParquetReadOptions::default(),
48+
)
49+
.await?;
50+
51+
let sql_results = ctx
52+
.sql("select count(*) from alltypes_tiny_pages")
53+
.await?
54+
.explain(false, false)?
55+
.collect()
56+
.await?;
57+
58+
let df_results = ctx
59+
.table("alltypes_tiny_pages")
60+
.await?
61+
.aggregate(vec![], vec![count(Expr::Wildcard)])?
62+
.explain(false, false)
63+
.unwrap()
64+
.collect()
65+
.await?;
66+
67+
//make sure sql plan same with df plan
68+
assert_eq!(
69+
pretty_format_batches(&sql_results)?.to_string(),
70+
pretty_format_batches(&df_results)?.to_string()
71+
);
72+
73+
let results = ctx
74+
.table("alltypes_tiny_pages")
75+
.await?
76+
.aggregate(vec![], vec![count(Expr::Wildcard)])?
77+
.collect()
78+
.await?;
79+
80+
let expected = vec![
81+
"+-----------------+",
82+
"| COUNT(UInt8(1)) |",
83+
"+-----------------+",
84+
"| 7300 |",
85+
"+-----------------+",
86+
];
87+
assert_batches_sorted_eq!(expected, &results);
88+
89+
Ok(())
90+
}
3891
#[tokio::test]
3992
async fn describe() -> Result<()> {
4093
let ctx = SessionContext::new();

datafusion/expr/src/expr_schema.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,10 @@ impl ExprSchemable for Expr {
136136
Expr::Placeholder { data_type, .. } => data_type.clone().ok_or_else(|| {
137137
DataFusionError::Plan("Placeholder type could not be resolved".to_owned())
138138
}),
139-
Expr::Wildcard => Err(DataFusionError::Internal(
140-
"Wildcard expressions are not valid in a logical query plan".to_owned(),
141-
)),
139+
Expr::Wildcard => {
140+
// Wildcard do not really have a type and do not appear in projections
141+
Ok(DataType::Null)
142+
}
142143
Expr::QualifiedWildcard { .. } => Err(DataFusionError::Internal(
143144
"QualifiedWildcard expressions are not valid in a logical query plan"
144145
.to_owned(),

datafusion/optimizer/src/analyzer.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use crate::count_wildcard_rule::CountWildcardRule;
1819
use crate::rewrite::TreeNodeRewritable;
1920
use datafusion_common::config::ConfigOptions;
2021
use datafusion_common::{DataFusionError, Result};
@@ -49,7 +50,8 @@ impl Default for Analyzer {
4950
impl Analyzer {
5051
/// Create a new analyzer using the recommended list of rules
5152
pub fn new() -> Self {
52-
let rules = vec![];
53+
let rules: Vec<Arc<dyn AnalyzerRule + Send + Sync>> =
54+
vec![Arc::new(CountWildcardRule::new())];
5355
Self::with_rules(rules)
5456
}
5557

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use crate::analyzer::AnalyzerRule;
19+
use datafusion_common::config::ConfigOptions;
20+
use datafusion_common::Result;
21+
use datafusion_expr::expr::AggregateFunction;
22+
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
23+
use datafusion_expr::{aggregate_function, lit, Aggregate, Expr, LogicalPlan, Window};
24+
use std::ops::Deref;
25+
use std::sync::Arc;
26+
27+
pub struct CountWildcardRule {}
28+
29+
impl Default for CountWildcardRule {
30+
fn default() -> Self {
31+
CountWildcardRule::new()
32+
}
33+
}
34+
35+
impl CountWildcardRule {
36+
pub fn new() -> Self {
37+
CountWildcardRule {}
38+
}
39+
}
40+
impl AnalyzerRule for CountWildcardRule {
41+
fn analyze(&self, plan: &LogicalPlan, _: &ConfigOptions) -> Result<LogicalPlan> {
42+
let new_plan = match plan {
43+
LogicalPlan::Window(window) => {
44+
let inputs = plan.inputs();
45+
let window_expr = window.clone().window_expr;
46+
let window_expr = handle_wildcard(window_expr).unwrap();
47+
LogicalPlan::Window(Window {
48+
input: Arc::new(inputs.get(0).unwrap().deref().clone()),
49+
window_expr,
50+
schema: plan.schema().clone(),
51+
})
52+
}
53+
54+
LogicalPlan::Aggregate(aggregate) => {
55+
let inputs = plan.inputs();
56+
let aggr_expr = aggregate.clone().aggr_expr;
57+
let aggr_expr = handle_wildcard(aggr_expr).unwrap();
58+
LogicalPlan::Aggregate(
59+
Aggregate::try_new_with_schema(
60+
Arc::new(inputs.get(0).unwrap().deref().clone()),
61+
aggregate.clone().group_expr,
62+
aggr_expr,
63+
plan.schema().clone(),
64+
)
65+
.unwrap(),
66+
)
67+
}
68+
_ => plan.clone(),
69+
};
70+
Ok(new_plan)
71+
}
72+
73+
fn name(&self) -> &str {
74+
"count_wildcard_rule"
75+
}
76+
}
77+
78+
//handle Count(Expr:Wildcard) with DataFrame API
79+
pub fn handle_wildcard(exprs: Vec<Expr>) -> Result<Vec<Expr>> {
80+
let exprs: Vec<Expr> = exprs
81+
.iter()
82+
.map(|expr| match expr {
83+
Expr::AggregateFunction(AggregateFunction {
84+
fun: aggregate_function::AggregateFunction::Count,
85+
args,
86+
distinct,
87+
filter,
88+
}) if args.len() == 1 => match args[0] {
89+
Expr::Wildcard => Expr::AggregateFunction(AggregateFunction {
90+
fun: aggregate_function::AggregateFunction::Count,
91+
args: vec![lit(COUNT_STAR_EXPANSION)],
92+
distinct: *distinct,
93+
filter: filter.clone(),
94+
}),
95+
_ => expr.clone(),
96+
},
97+
_ => expr.clone(),
98+
})
99+
.collect();
100+
Ok(exprs)
101+
}

datafusion/optimizer/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ pub mod type_coercion;
4545
pub mod unwrap_cast_in_comparison;
4646
pub mod utils;
4747

48+
pub mod count_wildcard_rule;
4849
#[cfg(test)]
4950
pub mod test;
5051

0 commit comments

Comments
 (0)