Skip to content

Commit

Permalink
Improve Robustness of Unparser Testing and Implementation (#9623)
Browse files Browse the repository at this point in the history
* support having, simplify logic

* fmt

* expr rewriting and more...

* lint

* add license

* retry windows ci

* retry windows ci 2

* subquery expr support

* make test even harder

* retry windows

* cargo fmt

* retry windows

* retry windows

* retry windows (last try)

* add comment explaining test
  • Loading branch information
devinjdangelo authored Mar 18, 2024
1 parent 40bf0ea commit 269563a
Show file tree
Hide file tree
Showing 5 changed files with 232 additions and 116 deletions.
46 changes: 38 additions & 8 deletions datafusion/sql/src/unparser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
use arrow_array::{Date32Array, Date64Array};
use arrow_schema::DataType;
use datafusion_common::{
internal_datafusion_err, not_impl_err, Column, Result, ScalarValue,
internal_datafusion_err, not_impl_err, plan_err, Column, Result, ScalarValue,
};
use datafusion_expr::{
expr::{AggregateFunctionDefinition, Alias, InList, ScalarFunction, WindowFunction},
Expand All @@ -40,7 +40,7 @@ use super::Unparser;
/// let expr = col("a").gt(lit(4));
/// let sql = expr_to_sql(&expr).unwrap();
///
/// assert_eq!(format!("{}", sql), "(a > 4)")
/// assert_eq!(format!("{}", sql), "(\"a\" > 4)")
/// ```
pub fn expr_to_sql(expr: &Expr) -> Result<ast::Expr> {
let unparser = Unparser::default();
Expand Down Expand Up @@ -151,6 +151,36 @@ impl Unparser<'_> {
order_by: vec![],
}))
}
Expr::ScalarSubquery(subq) => {
let sub_statement = self.plan_to_sql(subq.subquery.as_ref())?;
let sub_query = if let ast::Statement::Query(inner_query) = sub_statement
{
inner_query
} else {
return plan_err!(
"Subquery must be a Query, but found {sub_statement:?}"
);
};
Ok(ast::Expr::Subquery(sub_query))
}
Expr::InSubquery(insubq) => {
let inexpr = Box::new(self.expr_to_sql(insubq.expr.as_ref())?);
let sub_statement =
self.plan_to_sql(insubq.subquery.subquery.as_ref())?;
let sub_query = if let ast::Statement::Query(inner_query) = sub_statement
{
inner_query
} else {
return plan_err!(
"Subquery must be a Query, but found {sub_statement:?}"
);
};
Ok(ast::Expr::InSubquery {
expr: inexpr,
subquery: sub_query,
negated: insubq.negated,
})
}
_ => not_impl_err!("Unsupported expression: {expr:?}"),
}
}
Expand All @@ -169,7 +199,7 @@ impl Unparser<'_> {
pub(super) fn new_ident(&self, str: String) -> ast::Ident {
ast::Ident {
value: str,
quote_style: self.dialect.identifier_quote_style(),
quote_style: Some(self.dialect.identifier_quote_style().unwrap_or('"')),
}
}

Expand Down Expand Up @@ -491,28 +521,28 @@ mod tests {
#[test]
fn expr_to_sql_ok() -> Result<()> {
let tests: Vec<(Expr, &str)> = vec![
((col("a") + col("b")).gt(lit(4)), r#"((a + b) > 4)"#),
((col("a") + col("b")).gt(lit(4)), r#"(("a" + "b") > 4)"#),
(
Expr::Column(Column {
relation: Some(TableReference::partial("a", "b")),
name: "c".to_string(),
})
.gt(lit(4)),
r#"(a.b.c > 4)"#,
r#"("a"."b"."c" > 4)"#,
),
(
Expr::Cast(Cast {
expr: Box::new(col("a")),
data_type: DataType::Date64,
}),
r#"CAST(a AS DATETIME)"#,
r#"CAST("a" AS DATETIME)"#,
),
(
Expr::Cast(Cast {
expr: Box::new(col("a")),
data_type: DataType::UInt32,
}),
r#"CAST(a AS INTEGER UNSIGNED)"#,
r#"CAST("a" AS INTEGER UNSIGNED)"#,
),
(
Expr::Literal(ScalarValue::Date64(Some(0))),
Expand Down Expand Up @@ -549,7 +579,7 @@ mod tests {
order_by: None,
null_treatment: None,
}),
"SUM(a)",
r#"SUM("a")"#,
),
(
Expr::AggregateFunction(AggregateFunction {
Expand Down
1 change: 1 addition & 0 deletions datafusion/sql/src/unparser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
mod ast;
mod expr;
mod plan;
mod utils;

pub use expr::expr_to_sql;
pub use plan::plan_to_sql;
Expand Down
87 changes: 34 additions & 53 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@

use datafusion_common::{internal_err, not_impl_err, plan_err, DataFusionError, Result};
use datafusion_expr::{expr::Alias, Expr, JoinConstraint, JoinType, LogicalPlan};
use sqlparser::ast::{self, Ident, SelectItem};
use sqlparser::ast::{self};

use crate::unparser::utils::unproject_agg_exprs;

use super::{
ast::{
BuilderError, DerivedRelationBuilder, QueryBuilder, RelationBuilder,
SelectBuilder, TableRelationBuilder, TableWithJoinsBuilder,
},
utils::find_agg_node_within_select,
Unparser,
};

Expand All @@ -49,7 +52,7 @@ use super::{
/// .unwrap();
/// let sql = plan_to_sql(&plan).unwrap();
///
/// assert_eq!(format!("{}", sql), "SELECT table.id, table.value FROM table")
/// assert_eq!(format!("{}", sql), "SELECT \"table\".\"id\", \"table\".\"value\" FROM \"table\"")
/// ```
pub fn plan_to_sql(plan: &LogicalPlan) -> Result<ast::Statement> {
let unparser = Unparser::default();
Expand Down Expand Up @@ -132,70 +135,37 @@ impl Unparser<'_> {
// A second projection implies a derived tablefactor
if !select.already_projected() {
// Special handling when projecting an aggregation plan
if let LogicalPlan::Aggregate(agg) = p.input.as_ref() {
let mut items = p
.expr
.iter()
.filter(|e| !matches!(e, Expr::AggregateFunction(_)))
.map(|e| self.select_item_to_sql(e))
.collect::<Result<Vec<_>>>()?;

let proj_aggs = p
if let Some(agg) = find_agg_node_within_select(plan, true) {
let items = p
.expr
.iter()
.filter(|e| matches!(e, Expr::AggregateFunction(_)))
.zip(agg.aggr_expr.iter())
.map(|(proj, agg_exp)| {
let sql_agg_expr = self.select_item_to_sql(agg_exp)?;
let maybe_aliased =
if let Expr::Alias(Alias { name, .. }) = proj {
if let SelectItem::UnnamedExpr(aggregation_fun) =
sql_agg_expr
{
SelectItem::ExprWithAlias {
expr: aggregation_fun,
alias: Ident {
value: name.to_string(),
quote_style: None,
},
}
} else {
sql_agg_expr
}
} else {
sql_agg_expr
};
Ok(maybe_aliased)
.map(|proj_expr| {
let unproj = unproject_agg_exprs(proj_expr, agg)?;
self.select_item_to_sql(&unproj)
})
.collect::<Result<Vec<_>>>()?;
items.extend(proj_aggs);

select.projection(items);
select.group_by(ast::GroupByExpr::Expressions(
agg.group_expr
.iter()
.map(|expr| self.expr_to_sql(expr))
.collect::<Result<Vec<_>>>()?,
));
self.select_to_sql_recursively(
agg.input.as_ref(),
query,
select,
relation,
)
} else {
let items = p
.expr
.iter()
.map(|e| self.select_item_to_sql(e))
.collect::<Result<Vec<_>>>()?;
select.projection(items);
self.select_to_sql_recursively(
p.input.as_ref(),
query,
select,
relation,
)
}
self.select_to_sql_recursively(
p.input.as_ref(),
query,
select,
relation,
)
} else {
let mut derived_builder = DerivedRelationBuilder::default();
derived_builder.lateral(false).alias(None).subquery({
Expand All @@ -213,9 +183,16 @@ impl Unparser<'_> {
}
}
LogicalPlan::Filter(filter) => {
let filter_expr = self.expr_to_sql(&filter.predicate)?;

select.selection(Some(filter_expr));
if let Some(agg) =
find_agg_node_within_select(plan, select.already_projected())
{
let unprojected = unproject_agg_exprs(&filter.predicate, agg)?;
let filter_expr = self.expr_to_sql(&unprojected)?;
select.having(Some(filter_expr));
} else {
let filter_expr = self.expr_to_sql(&filter.predicate)?;
select.selection(Some(filter_expr));
}

self.select_to_sql_recursively(
filter.input.as_ref(),
Expand Down Expand Up @@ -249,9 +226,13 @@ impl Unparser<'_> {
relation,
)
}
LogicalPlan::Aggregate(_agg) => {
not_impl_err!(
"Unsupported aggregation plan not following a projection: {plan:?}"
LogicalPlan::Aggregate(agg) => {
// Aggregate nodes are handled simultaneously with Projection nodes
self.select_to_sql_recursively(
agg.input.as_ref(),
query,
select,
relation,
)
}
LogicalPlan::Distinct(_distinct) => {
Expand Down
84 changes: 84 additions & 0 deletions datafusion/sql/src/unparser/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use datafusion_common::{
internal_err,
tree_node::{Transformed, TreeNode},
Result,
};
use datafusion_expr::{Aggregate, Expr, LogicalPlan};

/// Walks down the single-input chain below `plan` looking for the nearest
/// [LogicalPlan::Aggregate] node that belongs to the same SELECT.
///
/// The search gives up and returns `None` when:
/// * the current node has no input, or has more than one input,
/// * a `TableScan` is reached, or
/// * a `Projection` is reached while `already_projected` is `true`.
///
/// The first `Projection` met while `already_projected` is `false` is stepped
/// through, and the search continues below it with the flag set to `true`.
/// Any other node kind is stepped through with the flag unchanged.
pub(crate) fn find_agg_node_within_select(
    plan: &LogicalPlan,
    already_projected: bool,
) -> Option<&Aggregate> {
    // Only nodes with exactly one input (e.g. Projection / Filter) can sit on
    // top of the aggregate we are looking for; anything else ends the search.
    let children = plan.inputs();
    let child = match children.as_slice() {
        [only] => *only,
        _ => return None,
    };

    match child {
        LogicalPlan::Aggregate(agg) => Some(agg),
        LogicalPlan::TableScan(_) => None,
        LogicalPlan::Projection(_) if already_projected => None,
        LogicalPlan::Projection(_) => find_agg_node_within_select(child, true),
        _ => find_agg_node_within_select(child, already_projected),
    }
}

/// Rewrites every [Expr::Column] found in `expr` into the expression of `agg`
/// that produces that column.
///
/// Each column is looked up in `agg.schema`; the resulting index is then used to
/// pick the matching entry from `agg.group_expr` followed by `agg.aggr_expr`.
///
/// For example, if expr contains the column expr "COUNT(*)" it will be transformed
/// into an actual aggregate expression COUNT(*) as identified in the aggregate node.
///
/// # Errors
///
/// Returns an internal error when a column is not part of the aggregate schema,
/// or when the schema index has no corresponding group/aggregate expression
/// (previously the latter case panicked).
pub(crate) fn unproject_agg_exprs(expr: &Expr, agg: &Aggregate) -> Result<Expr> {
    expr.clone()
        .transform(&|sub_expr| match sub_expr {
            Expr::Column(c) => {
                // find the column in the agg schema
                let Ok(n) = agg.schema.index_of_column(&c) else {
                    return internal_err!(
                        "Tried to unproject agg expr not found in provided Aggregate!"
                    );
                };
                // Index `n` is resolved against the group expressions followed
                // by the aggregate expressions.
                match agg.group_expr.iter().chain(agg.aggr_expr.iter()).nth(n) {
                    Some(unprojected_expr) => {
                        Ok(Transformed::yes(unprojected_expr.clone()))
                    }
                    // Report a schema/expression mismatch as an error instead
                    // of panicking inside the rewrite.
                    None => internal_err!(
                        "Aggregate schema column has no matching group or aggregate expression!"
                    ),
                }
            }
            other => Ok(Transformed::no(other)),
        })
        .map(|e| e.data)
}
Loading

0 comments on commit 269563a

Please sign in to comment.