Skip to content

Commit 9dd162f

Browse files
committed
fix error on Count(Expr::Wildcard) with DataFrame API
1 parent 36fe974 commit 9dd162f

File tree

6 files changed

+134
-5
lines changed

6 files changed

+134
-5
lines changed

datafusion/core/tests/dataframe.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// under the License.
1717

1818
use arrow::datatypes::{DataType, Field, Schema};
19+
use arrow::util::pretty::pretty_format_batches;
1920
use arrow::{
2021
array::{
2122
ArrayRef, Int32Array, Int32Builder, ListBuilder, StringArray, StringBuilder,
@@ -35,6 +36,58 @@ use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
3536
use datafusion_expr::expr::{GroupingSet, Sort};
3637
use datafusion_expr::{avg, col, count, lit, max, sum, Expr, ExprSchemable};
3738

39+
#[tokio::test]
40+
async fn count_wildcard() -> Result<()> {
41+
let ctx = SessionContext::new();
42+
let testdata = datafusion::test_util::parquet_test_data();
43+
44+
ctx.register_parquet(
45+
"alltypes_tiny_pages",
46+
&format!("{testdata}/alltypes_tiny_pages.parquet"),
47+
ParquetReadOptions::default(),
48+
)
49+
.await?;
50+
51+
let sql_results = ctx
52+
.sql("select count(*) from alltypes_tiny_pages")
53+
.await?
54+
.explain(false, false)?
55+
.collect()
56+
.await?;
57+
58+
let df_results = ctx
59+
.table("alltypes_tiny_pages")
60+
.await?
61+
.aggregate(vec![], vec![count(Expr::Wildcard)])?
62+
.explain(false, false)
63+
.unwrap()
64+
.collect()
65+
.await?;
66+
67+
//make sure sql plan same with df plan
68+
assert_eq!(
69+
pretty_format_batches(&sql_results)?.to_string(),
70+
pretty_format_batches(&df_results)?.to_string()
71+
);
72+
73+
let results = ctx
74+
.table("alltypes_tiny_pages")
75+
.await?
76+
.aggregate(vec![], vec![count(Expr::Wildcard)])?
77+
.collect()
78+
.await?;
79+
80+
let expected = vec![
81+
"+-----------------+",
82+
"| COUNT(UInt8(1)) |",
83+
"+-----------------+",
84+
"| 7300 |",
85+
"+-----------------+",
86+
];
87+
assert_batches_sorted_eq!(expected, &results);
88+
89+
Ok(())
90+
}
3891
#[tokio::test]
3992
async fn describe() -> Result<()> {
4093
let ctx = SessionContext::new();

datafusion/expr/src/expr_schema.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,10 @@ impl ExprSchemable for Expr {
136136
Expr::Placeholder { data_type, .. } => data_type.clone().ok_or_else(|| {
137137
DataFusionError::Plan("Placeholder type could not be resolved".to_owned())
138138
}),
139-
Expr::Wildcard => Err(DataFusionError::Internal(
140-
"Wildcard expressions are not valid in a logical query plan".to_owned(),
141-
)),
139+
Expr::Wildcard => {
140+
// Wildcard do not really have a type and do not appear in projections
141+
Ok(DataType::Null)
142+
}
142143
Expr::QualifiedWildcard { .. } => Err(DataFusionError::Internal(
143144
"QualifiedWildcard expressions are not valid in a logical query plan"
144145
.to_owned(),

datafusion/expr/src/logical_plan/builder.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1371,7 +1371,7 @@ pub fn unnest(input: LogicalPlan, column: Column) -> Result<LogicalPlan> {
13711371

13721372
#[cfg(test)]
13731373
mod tests {
1374-
use crate::{expr, expr_fn::exists};
1374+
use crate::{count, expr, expr_fn::exists};
13751375
use arrow::datatypes::{DataType, Field};
13761376
use datafusion_common::{OwnedTableReference, SchemaError, TableReference};
13771377

datafusion/optimizer/src/analyzer.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use crate::count_wildcard_rule::CountWildcardRule;
1819
use crate::rewrite::TreeNodeRewritable;
1920
use datafusion_common::config::ConfigOptions;
2021
use datafusion_common::{DataFusionError, Result};
@@ -49,7 +50,8 @@ impl Default for Analyzer {
4950
impl Analyzer {
5051
/// Create a new analyzer using the recommended list of rules
5152
pub fn new() -> Self {
52-
let rules = vec![];
53+
let rules: Vec<Arc<dyn AnalyzerRule + Send + Sync>> =
54+
vec![Arc::new(CountWildcardRule::new())];
5355
Self::with_rules(rules)
5456
}
5557

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
use crate::analyzer::AnalyzerRule;
2+
use datafusion_common::config::ConfigOptions;
3+
use datafusion_common::Result;
4+
use datafusion_expr::expr::AggregateFunction;
5+
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
6+
use datafusion_expr::{aggregate_function, lit, Aggregate, Expr, LogicalPlan};
7+
use std::sync::Arc;
8+
9+
/// Analyzer rule that rewrites `COUNT(Expr::Wildcard)` aggregate expressions
/// into `COUNT(<COUNT_STAR_EXPANSION literal>)`.
///
/// The rule carries no state; `Default` is derived so the type can be built
/// with either `CountWildcardRule::new()` or `CountWildcardRule::default()`.
#[derive(Debug, Default)]
pub struct CountWildcardRule {}

impl CountWildcardRule {
    /// Creates a new instance of the rule.
    pub fn new() -> Self {
        Self {}
    }
}
16+
17+
impl AnalyzerRule for CountWildcardRule {
18+
fn analyze(
19+
&self,
20+
plan: &LogicalPlan,
21+
_: &ConfigOptions,
22+
) -> datafusion_common::Result<LogicalPlan> {
23+
let new_plan = match plan {
24+
LogicalPlan::Window(_window) => plan.clone(),
25+
LogicalPlan::Aggregate(aggregate) => {
26+
let aggr_expr = aggregate.clone().aggr_expr;
27+
let aggr_expr = handle_wildcard(aggr_expr).unwrap();
28+
29+
LogicalPlan::Aggregate(
30+
Aggregate::try_new_with_schema(
31+
Arc::new(plan.inputs().get(0).unwrap().clone().clone()),
32+
aggregate.clone().group_expr,
33+
aggr_expr,
34+
plan.schema().clone(),
35+
)
36+
.unwrap(),
37+
)
38+
}
39+
_ => plan.clone(),
40+
};
41+
Ok(new_plan)
42+
}
43+
44+
fn name(&self) -> &str {
45+
"count_wildcard_rule"
46+
}
47+
}
48+
49+
//handle Count(Expr:Wildcard) with DataFrame API
50+
pub fn handle_wildcard(exprs: Vec<Expr>) -> Result<Vec<Expr>> {
51+
let exprs: Vec<Expr> = exprs
52+
.iter()
53+
.map(|expr| match expr {
54+
Expr::AggregateFunction(AggregateFunction {
55+
fun: aggregate_function::AggregateFunction::Count,
56+
args,
57+
distinct,
58+
filter,
59+
}) if args.len() == 1 => match args[0] {
60+
Expr::Wildcard => Expr::AggregateFunction(AggregateFunction {
61+
fun: aggregate_function::AggregateFunction::Count,
62+
args: vec![lit(COUNT_STAR_EXPANSION)],
63+
distinct: *distinct,
64+
filter: filter.clone(),
65+
}),
66+
_ => expr.clone(),
67+
},
68+
_ => expr.clone(),
69+
})
70+
.collect();
71+
Ok(exprs)
72+
}

datafusion/optimizer/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ pub mod type_coercion;
4545
pub mod unwrap_cast_in_comparison;
4646
pub mod utils;
4747

48+
pub mod count_wildcard_rule;
4849
#[cfg(test)]
4950
pub mod test;
5051

0 commit comments

Comments
 (0)