Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix][plan] relax the check for distinct, order by for dataframe #5258

Merged
merged 1 commit into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions datafusion/core/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1072,6 +1072,48 @@ mod tests {
Ok(())
}

/// Builds `select(c1).distinct()` through the DataFrame API and checks, via
/// `assert_same_plan`, that it matches the plan produced for the SQL query
/// `select distinct c1 from aggregate_test_100`.
#[tokio::test]
async fn test_distinct() -> Result<()> {
    // Propagate planning errors with `?`: the test already returns `Result`,
    // so a failure surfaces as the underlying error rather than a panic.
    let df = test_table().await?.select(vec![col("c1")])?.distinct()?;

    let sql_plan = create_plan("select distinct c1 from aggregate_test_100").await?;

    // The comparison only needs a reference, so borrow the plan instead of cloning it.
    assert_same_plan(&df.plan, &sql_plan);
    Ok(())
}

/// Sorts a `select(c1).distinct()` DataFrame by `c2`, a column that is not part
/// of the select list, and checks that the query runs and returns only `c1`.
///
/// NOTE(review): each `c1` value is expected five times, presumably because the
/// unprojected sort column `c2` takes part in the distinct step before being
/// projected away again — confirm against the sort plan construction.
#[tokio::test]
async fn test_distinct_sort_by() -> Result<()> {
    // Propagate errors with `?` instead of panicking; the test returns `Result`.
    let df = test_table()
        .await?
        .select(vec![col("c1")])?
        .distinct()?
        .sort(vec![col("c2").sort(true, true)])?;
    // `df` is not used after this point, so it can be consumed without a clone.
    let df_results = df.collect().await?;

    // Cells are padded to the column width (2, set by the header "c1"), so every
    // data row must be exactly as wide as the border and header lines.
    #[rustfmt::skip]
    let expected = vec![
        "+----+",
        "| c1 |",
        "+----+",
        "| a  |", "| a  |", "| a  |", "| a  |", "| a  |",
        "| b  |", "| b  |", "| b  |", "| b  |", "| b  |",
        "| c  |", "| c  |", "| c  |", "| c  |", "| c  |",
        "| d  |", "| d  |", "| d  |", "| d  |", "| d  |",
        "| e  |", "| e  |", "| e  |", "| e  |", "| e  |",
        "+----+",
    ];
    assert_batches_sorted_eq!(expected, &df_results);

    Ok(())
}

#[tokio::test]
async fn join() -> Result<()> {
let left = test_table().await?.select_columns(&["c1", "c2"])?;
Expand Down
34 changes: 22 additions & 12 deletions datafusion/core/tests/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use arrow::{
record_batch::RecordBatch,
};
use datafusion::from_slice::FromSlice;
use datafusion_common::DataFusionError;
use std::sync::Arc;

use datafusion::dataframe::DataFrame;
Expand Down Expand Up @@ -146,18 +145,29 @@ async fn sort_on_distinct_unprojected_columns() -> Result<()> {

let ctx = SessionContext::new();
ctx.register_batch("t", batch).unwrap();
let df = ctx
.table("t")
.await
.unwrap()
.select(vec![col("a")])
.unwrap()
.distinct()
.unwrap()
.sort(vec![Expr::Sort(Sort::new(Box::new(col("b")), false, true))])
.unwrap();
let results = df.collect().await.unwrap();

assert!(matches!(
ctx.table("t")
.await
.unwrap()
.select(vec![col("a")])
.unwrap()
.distinct()
.unwrap()
.sort(vec![Expr::Sort(Sort::new(Box::new(col("b")), false, true))]),
Err(DataFusionError::Plan(_))
));
#[rustfmt::skip]
let expected = vec![
"+-----+",
"| a |",
"+-----+",
"| 100 |",
"| 10 |",
"| 1 |",
"+-----+",
];
assert_batches_eq!(expected, &results);
Ok(())
}

Expand Down
28 changes: 8 additions & 20 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,27 +406,15 @@ impl LogicalPlanBuilder {
Ok(())
})?;

// if current plan is distinct or current plan is repartition and its child plan is distinct,
// then this plan is a select distinct plan
let is_select_distinct = match self.plan {
LogicalPlan::Distinct(_) => true,
LogicalPlan::Repartition(Repartition { ref input, .. }) => {
matches!(input.as_ref(), &LogicalPlan::Distinct(_))
}
_ => false,
};
self.create_sort_plan(exprs, missing_cols)
}

// for select distinct, order by expressions must exist in select list
if is_select_distinct && !missing_cols.is_empty() {
let missing_col_names = missing_cols
.iter()
.map(|col| col.flat_name())
.collect::<String>();
let error_msg = format!(
"For SELECT DISTINCT, ORDER BY expressions {missing_col_names} must appear in select list",
);
return Err(DataFusionError::Plan(error_msg));
}
pub fn create_sort_plan(
self,
exprs: impl IntoIterator<Item = impl Into<Expr>> + Clone,
missing_cols: Vec<Column>,
) -> Result<Self> {
let schema = self.plan.schema();

if missing_cols.is_empty() {
return Ok(Self::from(LogicalPlan::Sort(Sort {
Expand Down
53 changes: 49 additions & 4 deletions datafusion/sql/src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@

use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
use crate::utils::normalize_ident;
use datafusion_common::{DFSchema, DataFusionError, Result, ScalarValue};
use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder};
use datafusion_common::{Column, DFSchema, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr_rewriter::rewrite_sort_cols_by_aggs;
use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder, Repartition};
use sqlparser::ast::{Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query};

use sqlparser::parser::ParserError::ParserError;
Expand Down Expand Up @@ -150,11 +151,55 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
return Ok(plan);
}

let order_by_rex = order_by
let mut order_by_rex = order_by
.into_iter()
.map(|e| self.order_by_to_sort_expr(e, plan.schema()))
.collect::<Result<Vec<_>>>()?;

LogicalPlanBuilder::from(plan).sort(order_by_rex)?.build()
order_by_rex = rewrite_sort_cols_by_aggs(order_by_rex, &plan)?;
let schema = plan.schema();

// if current plan is distinct or current plan is repartition and its child plan is distinct,
// then this plan is a select distinct plan
let is_select_distinct = match plan {
LogicalPlan::Distinct(_) => true,
LogicalPlan::Repartition(Repartition { ref input, .. }) => {
matches!(input.as_ref(), &LogicalPlan::Distinct(_))
}
_ => false,
};

let mut missing_cols: Vec<Column> = vec![];
// Collect sort columns that are missing in the input plan's schema
order_by_rex
.clone()
.into_iter()
.try_for_each::<_, Result<()>>(|expr| {
let columns = expr.to_columns()?;

columns.into_iter().for_each(|c| {
if schema.field_from_column(&c).is_err() {
missing_cols.push(c);
}
});

Ok(())
})?;

// for select distinct, order by expressions must exist in select list
if is_select_distinct && !missing_cols.is_empty() {
let missing_col_names = missing_cols
.iter()
.map(|col| col.flat_name())
.collect::<String>();
let error_msg = format!(
"For SELECT DISTINCT, ORDER BY expressions {missing_col_names} must appear in select list",
);
return Err(DataFusionError::Plan(error_msg));
}

LogicalPlanBuilder::from(plan)
.create_sort_plan(order_by_rex, missing_cols)?
.build()
}
}