Skip to content

Commit

Permalink
expr rewriting and more...
Browse files Browse the repository at this point in the history
  • Loading branch information
devinjdangelo committed Mar 17, 2024
1 parent 998323a commit ed678ff
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 91 deletions.
14 changes: 7 additions & 7 deletions datafusion/sql/src/unparser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use super::Unparser;
/// let expr = col("a").gt(lit(4));
/// let sql = expr_to_sql(&expr).unwrap();
///
/// assert_eq!(format!("{}", sql), "(a > 4)")
/// assert_eq!(format!("{}", sql), "(\"a\" > 4)")
/// ```
pub fn expr_to_sql(expr: &Expr) -> Result<ast::Expr> {
let unparser = Unparser::default();
Expand Down Expand Up @@ -169,7 +169,7 @@ impl Unparser<'_> {
pub(super) fn new_ident(&self, str: String) -> ast::Ident {
ast::Ident {
value: str,
quote_style: self.dialect.identifier_quote_style(),
quote_style: Some(self.dialect.identifier_quote_style().unwrap_or('"')),
}
}

Expand Down Expand Up @@ -491,28 +491,28 @@ mod tests {
#[test]
fn expr_to_sql_ok() -> Result<()> {
let tests: Vec<(Expr, &str)> = vec![
((col("a") + col("b")).gt(lit(4)), r#"((a + b) > 4)"#),
((col("a") + col("b")).gt(lit(4)), r#"(("a" + "b") > 4)"#),
(
Expr::Column(Column {
relation: Some(TableReference::partial("a", "b")),
name: "c".to_string(),
})
.gt(lit(4)),
r#"(a.b.c > 4)"#,
r#"("a"."b"."c" > 4)"#,
),
(
Expr::Cast(Cast {
expr: Box::new(col("a")),
data_type: DataType::Date64,
}),
r#"CAST(a AS DATETIME)"#,
r#"CAST("a" AS DATETIME)"#,
),
(
Expr::Cast(Cast {
expr: Box::new(col("a")),
data_type: DataType::UInt32,
}),
r#"CAST(a AS INTEGER UNSIGNED)"#,
r#"CAST("a" AS INTEGER UNSIGNED)"#,
),
(
Expr::Literal(ScalarValue::Date64(Some(0))),
Expand Down Expand Up @@ -549,7 +549,7 @@ mod tests {
order_by: None,
null_treatment: None,
}),
"SUM(a)",
r#"SUM("a")"#,
),
(
Expr::AggregateFunction(AggregateFunction {
Expand Down
1 change: 1 addition & 0 deletions datafusion/sql/src/unparser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
mod ast;
mod expr;
mod plan;
mod utils;

pub use expr::expr_to_sql;
pub use plan::plan_to_sql;
Expand Down
66 changes: 48 additions & 18 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,25 @@
// specific language governing permissions and limitations
// under the License.

use datafusion_common::{internal_err, not_impl_err, plan_err, DataFusionError, Result};
use datafusion_expr::{expr::Alias, Expr, JoinConstraint, JoinType, LogicalPlan};
use sqlparser::ast;
use datafusion_common::{
internal_err, not_impl_err, plan_err,
tree_node::{Transformed, TreeNode, TreeNodeRecursion},
DataFusionError, Result,
};
use datafusion_expr::{
expr::Alias, utils::find_column_exprs, Expr, ExprSchemable, JoinConstraint, JoinType,
LogicalPlan,
};
use sqlparser::ast::{self, Ident, SelectItem};

use crate::unparser::utils::unproject_agg_exprs;

use super::{
ast::{
BuilderError, DerivedRelationBuilder, QueryBuilder, RelationBuilder,
SelectBuilder, TableRelationBuilder, TableWithJoinsBuilder,
},
utils::find_agg_node_within_select,
Unparser,
};

Expand All @@ -49,7 +59,7 @@ use super::{
/// .unwrap();
/// let sql = plan_to_sql(&plan).unwrap();
///
/// assert_eq!(format!("{}", sql), "SELECT table.id, table.value FROM table")
/// assert_eq!(format!("{}", sql), "SELECT \"table\".\"id\", \"table\".\"value\" FROM \"table\"")
/// ```
pub fn plan_to_sql(plan: &LogicalPlan) -> Result<ast::Statement> {
let unparser = Unparser::default();
Expand Down Expand Up @@ -131,12 +141,32 @@ impl Unparser<'_> {
LogicalPlan::Projection(p) => {
// A second projection implies a derived tablefactor
if !select.already_projected() {
let items = p
.expr
.iter()
.map(|e| self.select_item_to_sql(e))
.collect::<Result<Vec<_>>>()?;
select.projection(items);
// Special handling when projecting an agregation plan
if let Some(agg) = find_agg_node_within_select(plan, true) {
let items = p
.expr
.iter()
.map(|proj_expr| {
let unproj = unproject_agg_exprs(proj_expr, agg)?;
self.select_item_to_sql(&unproj)
})
.collect::<Result<Vec<_>>>()?;

select.projection(items);
select.group_by(ast::GroupByExpr::Expressions(
agg.group_expr
.iter()
.map(|expr| self.expr_to_sql(expr))
.collect::<Result<Vec<_>>>()?,
));
} else {
let items = p
.expr
.iter()
.map(|e| self.select_item_to_sql(e))
.collect::<Result<Vec<_>>>()?;
select.projection(items);
}
self.select_to_sql_recursively(
p.input.as_ref(),
query,
Expand All @@ -160,11 +190,16 @@ impl Unparser<'_> {
}
}
LogicalPlan::Filter(filter) => {
let filter_expr = self.expr_to_sql(&filter.predicate)?;
println!("filter {plan:?}");

if let LogicalPlan::Aggregate(_) = filter.input.as_ref() {
if let Some(agg) =
find_agg_node_within_select(plan, select.already_projected())
{
let unprojected = unproject_agg_exprs(&filter.predicate, agg)?;
let filter_expr = self.expr_to_sql(&unprojected)?;
select.having(Some(filter_expr));
} else {
let filter_expr = self.expr_to_sql(&filter.predicate)?;
select.selection(Some(filter_expr));
}

Expand Down Expand Up @@ -201,12 +236,7 @@ impl Unparser<'_> {
)
}
LogicalPlan::Aggregate(agg) => {
select.group_by(ast::GroupByExpr::Expressions(
agg.group_expr
.iter()
.map(|expr| self.expr_to_sql(expr))
.collect::<Result<Vec<_>>>()?,
));
// Aggregate nodes are handled simulatenously with Projection nodes
self.select_to_sql_recursively(
agg.input.as_ref(),
query,
Expand Down
64 changes: 64 additions & 0 deletions datafusion/sql/src/unparser/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
use datafusion_common::{
internal_err,
tree_node::{Transformed, TreeNode},
Result,
};
use datafusion_expr::{Aggregate, Expr, LogicalPlan, Projection};

/// Recursively searches children of [LogicalPlan] to find an Aggregate node if one exists
/// prior to encountering a Join, TableScan, or subquery node. If an Aggregate node is not found
/// prior to this or at all before reaching the end of the tree, None is returned.
pub(crate) fn find_agg_node_within_select(
plan: &LogicalPlan,
already_projected: bool,
) -> Option<&Aggregate> {
let input = plan.inputs();
let input = if input.len() > 1 {
return None;
} else {
input.first()?
};
if let LogicalPlan::Aggregate(agg) = input {
Some(agg)
} else if let LogicalPlan::TableScan(_) = input {
None
} else if let LogicalPlan::Projection(_) = input {
if already_projected {
None
} else {
find_agg_node_within_select(input, true)
}
} else {
find_agg_node_within_select(input, already_projected)
}
}

/// Recursively identify all Column expressions and transform them into the appropriate
/// aggregate expression contained in agg.
///
/// For example, if expr contains the column expr "COUNT(*)" it will be transformed
/// into an actual aggregate expression COUNT(*) as identified in the aggregate node.
pub(crate) fn unproject_agg_exprs(expr: &Expr, agg: &Aggregate) -> Result<Expr> {
expr.clone()
.transform(&|sub_expr| {
if let Expr::Column(c) = sub_expr {
// find the column in the agg schmea
if let Ok(n) = agg.schema.index_of_column(&c) {
let unprojected_expr = agg
.group_expr
.iter()
.chain(agg.aggr_expr.iter())
.nth(n)
.unwrap();
Ok(Transformed::yes(unprojected_expr.clone()))
} else {
internal_err!(
"Tried to unproject agg expr not found in provided Aggregate!"
)
}
} else {
Ok(Transformed::no(sub_expr))
}
})
.map(|e| e.data)
}
Loading

0 comments on commit ed678ff

Please sign in to comment.