diff --git a/wren-core/core/src/mdl/context.rs b/wren-core/core/src/mdl/context.rs index 220d7c689..2798d859f 100644 --- a/wren-core/core/src/mdl/context.rs +++ b/wren-core/core/src/mdl/context.rs @@ -36,7 +36,6 @@ use datafusion::optimizer::filter_null_join_keys::FilterNullJoinKeys; use datafusion::optimizer::propagate_empty_relation::PropagateEmptyRelation; use datafusion::optimizer::replace_distinct_aggregate::ReplaceDistinctWithAggregate; use datafusion::optimizer::scalar_subquery_to_join::ScalarSubqueryToJoin; -use datafusion::optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy; use datafusion::optimizer::unwrap_cast_in_comparison::UnwrapCastInComparison; use datafusion::optimizer::{AnalyzerRule, OptimizerRule}; use datafusion::physical_plan::ExecutionPlan; @@ -181,7 +180,8 @@ fn optimize_rule_for_unparsing() -> Vec> { // Arc::new(PushDownLimit::new()), // Disable PushDownFilter to avoid the casting for bigquery (datetime/timestamp) column be removed // Arc::new(PushDownFilter::new()), - Arc::new(SingleDistinctToGroupBy::new()), + // Disable SingleDistinctToGroupBy to avoid generating an invalid aggregation plan + // Arc::new(SingleDistinctToGroupBy::new()), // Disable SimplifyExpressions to avoid apply some function locally // Arc::new(SimplifyExpressions::new()), Arc::new(UnwrapCastInComparison::new()), diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs index fa9b8eca4..2fd5a639c 100644 --- a/wren-core/core/src/mdl/mod.rs +++ b/wren-core/core/src/mdl/mod.rs @@ -1308,6 +1308,33 @@ mod test { Ok(()) } + #[tokio::test] + async fn test_disable_single_distinct_to_group_by() -> Result<()> { + let ctx = SessionContext::new(); + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("customer") + .table_reference("customer") + .column(ColumnBuilder::new("c_custkey", "int").build()) + .column(ColumnBuilder::new("c_name", "string").build()) + .build(), + ) + .build(); + let sql = r#"SELECT 
c_custkey, count(distinct c_name) FROM customer GROUP BY c_custkey"#; + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let result = + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], sql).await?; + assert_eq!( + result, + "SELECT customer.c_custkey, count(DISTINCT customer.c_name) FROM (SELECT customer.c_custkey, customer.c_name \ + FROM (SELECT customer.c_custkey AS c_custkey, customer.c_name AS c_name \ + FROM customer) AS customer) AS customer GROUP BY customer.c_custkey" + ); + Ok(()) + } + /// Return a RecordBatch with made up data about customer fn customer() -> RecordBatch { let custkey: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));