diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index 26fe5c051204..36135bd1eb36 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -1072,6 +1072,42 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn test_distinct() -> Result<()> {
+        let t = test_table().await?;
+        let plan = t.select(vec![col("c1")])?.distinct()?.plan.clone();
+
+        let sql_plan = create_plan("select distinct c1 from aggregate_test_100").await?;
+
+        assert_same_plan(&plan, &sql_plan);
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_distinct_sort_by() -> Result<()> {
+        let t = test_table().await?;
+        let df = t
+            .select(vec![col("c1")])?
+            .distinct()?
+            .sort(vec![col("c2").sort(true, true)])?;
+        let df_results = df.collect().await?;
+
+        // The sort key `c2` is not part of the output, so the order of the
+        // visible `c1` values is not pinned down; compare order-insensitively.
+        assert_batches_sorted_eq!(
+            vec![
+                "+----+", "| c1 |", "+----+", "| a  |", "| a  |", "| a  |", "| a  |",
+                "| a  |", "| b  |", "| b  |", "| b  |", "| b  |", "| b  |", "| c  |",
+                "| c  |", "| c  |", "| c  |", "| c  |", "| d  |", "| d  |", "| d  |",
+                "| d  |", "| d  |", "| e  |", "| e  |", "| e  |", "| e  |", "| e  |",
+                "+----+",
+            ],
+            &df_results
+        );
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn join() -> Result<()> {
         let left = test_table().await?.select_columns(&["c1", "c2"])?;
diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs
index f7274fdabf80..9c71be7bb0e2 100644
--- a/datafusion/core/tests/dataframe.rs
+++ b/datafusion/core/tests/dataframe.rs
@@ -24,7 +24,6 @@ use arrow::{
     record_batch::RecordBatch,
 };
 use datafusion::from_slice::FromSlice;
-use datafusion_common::DataFusionError;
 use std::sync::Arc;
 
 use datafusion::dataframe::DataFrame;
@@ -146,18 +145,29 @@ async fn sort_on_distinct_unprojected_columns() -> Result<()> {
 
     let ctx = SessionContext::new();
     ctx.register_batch("t", batch).unwrap();
+    let df = ctx
+        .table("t")
+        .await
+        .unwrap()
+        .select(vec![col("a")])
+        .unwrap()
+        .distinct()
+        .unwrap()
+        .sort(vec![Expr::Sort(Sort::new(Box::new(col("b")), false, true))])
+        .unwrap();
+    let results = df.collect().await.unwrap();
 
-    assert!(matches!(
-        ctx.table("t")
-            .await
-            .unwrap()
-            .select(vec![col("a")])
-            .unwrap()
-            .distinct()
-            .unwrap()
-            .sort(vec![Expr::Sort(Sort::new(Box::new(col("b")), false, true))]),
-        Err(DataFusionError::Plan(_))
-    ));
+    #[rustfmt::skip]
+    let expected = vec![
+        "+-----+",
+        "| a   |",
+        "+-----+",
+        "| 100 |",
+        "| 10  |",
+        "| 1   |",
+        "+-----+",
+    ];
+    assert_batches_eq!(expected, &results);
     Ok(())
 }
 
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index e1092b96b355..f979e1f76f98 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -406,27 +406,20 @@ impl LogicalPlanBuilder {
                 Ok(())
             })?;
 
-        // if current plan is distinct or current plan is repartition and its child plan is distinct,
-        // then this plan is a select distinct plan
-        let is_select_distinct = match self.plan {
-            LogicalPlan::Distinct(_) => true,
-            LogicalPlan::Repartition(Repartition { ref input, .. }) => {
-                matches!(input.as_ref(), &LogicalPlan::Distinct(_))
-            }
-            _ => false,
-        };
+        self.create_sort_plan(exprs, missing_cols)
+    }
 
-        // for select distinct, order by expressions must exist in select list
-        if is_select_distinct && !missing_cols.is_empty() {
-            let missing_col_names = missing_cols
-                .iter()
-                .map(|col| col.flat_name())
-                .collect::<String>();
-            let error_msg = format!(
-                "For SELECT DISTINCT, ORDER BY expressions {missing_col_names} must appear in select list",
-            );
-            return Err(DataFusionError::Plan(error_msg));
-        }
+    /// Create a sort plan on top of the current plan for the given `exprs`.
+    ///
+    /// `missing_cols` lists the columns referenced by `exprs` that the current
+    /// plan's schema does not contain; callers are expected to collect them
+    /// beforehand. If it is empty, the sort is built directly over this plan.
+    pub fn create_sort_plan(
+        self,
+        exprs: impl IntoIterator<Item = impl Into<Expr>> + Clone,
+        missing_cols: Vec<Column>,
+    ) -> Result<Self> {
+        let schema = self.plan.schema();
 
         if missing_cols.is_empty() {
             return Ok(Self::from(LogicalPlan::Sort(Sort {
diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs
index eb7ece87d6b1..c59c42e93c0d 100644
--- a/datafusion/sql/src/query.rs
+++ b/datafusion/sql/src/query.rs
@@ -17,8 +17,9 @@
 
 use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
 use crate::utils::normalize_ident;
-use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue};
-use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder};
+use datafusion_common::{Column, DFSchema, DataFusionError, Result, ScalarValue};
+use datafusion_expr::expr_rewriter::rewrite_sort_cols_by_aggs;
+use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder, Repartition};
 use sqlparser::ast::{Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query};
 use sqlparser::parser::ParserError::ParserError;
 
@@ -150,11 +151,50 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
             return Ok(plan);
         }
 
         let order_by_rex = order_by
             .into_iter()
             .map(|e| self.order_by_to_sort_expr(e, plan.schema()))
             .collect::<Result<Vec<_>>>()?;
 
-        LogicalPlanBuilder::from(plan).sort(order_by_rex)?.build()
+        let order_by_rex = rewrite_sort_cols_by_aggs(order_by_rex, &plan)?;
+        let schema = plan.schema();
+
+        // A plan is a SELECT DISTINCT plan if it is a `Distinct` node, or a
+        // `Repartition` node whose direct input is a `Distinct` node.
+        let is_select_distinct = match &plan {
+            LogicalPlan::Distinct(_) => true,
+            LogicalPlan::Repartition(Repartition { input, .. }) => {
+                matches!(input.as_ref(), LogicalPlan::Distinct(_))
+            }
+            _ => false,
+        };
+
+        // Collect sort columns that are missing in the input plan's schema
+        let mut missing_cols: Vec<Column> = vec![];
+        for expr in &order_by_rex {
+            for column in expr.to_columns()? {
+                if schema.field_from_column(&column).is_err() {
+                    missing_cols.push(column);
+                }
+            }
+        }
+
+        // for select distinct, order by expressions must exist in select list
+        if is_select_distinct && !missing_cols.is_empty() {
+            // Separate the names so several missing columns stay readable.
+            let missing_col_names = missing_cols
+                .iter()
+                .map(|col| col.flat_name())
+                .collect::<Vec<_>>()
+                .join(", ");
+            let error_msg = format!(
+                "For SELECT DISTINCT, ORDER BY expressions {missing_col_names} must appear in select list",
+            );
+            return Err(DataFusionError::Plan(error_msg));
+        }
+
+        LogicalPlanBuilder::from(plan)
+            .create_sort_plan(order_by_rex, missing_cols)?
+            .build()
     }
 }