Skip to content

Commit

Permalink
support having, simplify logic
Browse files Browse the repository at this point in the history
  • Loading branch information
devinjdangelo committed Mar 15, 2024
1 parent 219de5f commit 3667b49
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 71 deletions.
103 changes: 32 additions & 71 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ use super::{
ast::{
BuilderError, DerivedRelationBuilder, QueryBuilder, RelationBuilder,
SelectBuilder, TableRelationBuilder, TableWithJoinsBuilder,
},
Unparser,
}, Unparser
};

/// Convert a DataFusion [`LogicalPlan`] to `sqlparser::ast::Statement`
Expand Down Expand Up @@ -131,71 +130,20 @@ impl Unparser<'_> {
LogicalPlan::Projection(p) => {
// A second projection implies a derived tablefactor
if !select.already_projected() {
// Special handling when projecting an agregation plan
if let LogicalPlan::Aggregate(agg) = p.input.as_ref() {
let mut items = p
.expr
.iter()
.filter(|e| !matches!(e, Expr::AggregateFunction(_)))
.map(|e| self.select_item_to_sql(e))
.collect::<Result<Vec<_>>>()?;

let proj_aggs = p
.expr
.iter()
.filter(|e| matches!(e, Expr::AggregateFunction(_)))
.zip(agg.aggr_expr.iter())
.map(|(proj, agg_exp)| {
let sql_agg_expr = self.select_item_to_sql(agg_exp)?;
let maybe_aliased =
if let Expr::Alias(Alias { name, .. }) = proj {
if let SelectItem::UnnamedExpr(aggregation_fun) =
sql_agg_expr
{
SelectItem::ExprWithAlias {
expr: aggregation_fun,
alias: Ident {
value: name.to_string(),
quote_style: None,
},
}
} else {
sql_agg_expr
}
} else {
sql_agg_expr
};
Ok(maybe_aliased)
})
.collect::<Result<Vec<_>>>()?;
items.extend(proj_aggs);
select.projection(items);
select.group_by(ast::GroupByExpr::Expressions(
agg.group_expr
.iter()
.map(|expr| self.expr_to_sql(expr))
.collect::<Result<Vec<_>>>()?,
));
self.select_to_sql_recursively(
agg.input.as_ref(),
query,
select,
relation,
)
} else {
let items = p
.expr
.iter()
.map(|e| self.select_item_to_sql(e))
.collect::<Result<Vec<_>>>()?;
select.projection(items);
self.select_to_sql_recursively(
p.input.as_ref(),
query,
select,
relation,
)
}

let items = p
.expr
.iter()
.map(|e| self.select_item_to_sql(e))
.collect::<Result<Vec<_>>>()?;
select.projection(items);
self.select_to_sql_recursively(
p.input.as_ref(),
query,
select,
relation,
)

} else {
let mut derived_builder = DerivedRelationBuilder::default();
derived_builder.lateral(false).alias(None).subquery({
Expand All @@ -215,7 +163,11 @@ impl Unparser<'_> {
LogicalPlan::Filter(filter) => {
let filter_expr = self.expr_to_sql(&filter.predicate)?;

select.selection(Some(filter_expr));
if let LogicalPlan::Aggregate(_) = filter.input.as_ref(){
select.having(Some(filter_expr));
} else {
select.selection(Some(filter_expr));
}

self.select_to_sql_recursively(
filter.input.as_ref(),
Expand Down Expand Up @@ -249,9 +201,18 @@ impl Unparser<'_> {
relation,
)
}
LogicalPlan::Aggregate(_agg) => {
not_impl_err!(
"Unsupported aggregation plan not following a projection: {plan:?}"
LogicalPlan::Aggregate(agg) => {
select.group_by(ast::GroupByExpr::Expressions(
agg.group_expr
.iter()
.map(|expr| self.expr_to_sql(expr))
.collect::<Result<Vec<_>>>()?,
));
self.select_to_sql_recursively(
agg.input.as_ref(),
query,
select,
relation,
)
}
LogicalPlan::Distinct(_distinct) => {
Expand Down
12 changes: 12 additions & 0 deletions datafusion/sql/tests/sql_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4580,6 +4580,18 @@ fn roundtrip_statement() {
"select id, count(*), first_name from person group by first_name, id",
"SELECT person.id, COUNT(*), person.first_name FROM person GROUP BY person.first_name, person.id"
),
(
"select id, sum(age), first_name from person group by first_name, id",
"SELECT person.id, SUM(person.age), person.first_name FROM person GROUP BY person.first_name, person.id"
),
("select id, count(*), first_name
from person
where id!=3 and first_name=='test'
group by first_name, id
having count(*)>5 and count(*)<10
order by count(*)",
"SELECT person.id, COUNT(*), person.first_name FROM person WHERE ((person.id <> 3) AND (person.first_name = 'test')) GROUP BY person.first_name, person.id HAVING ((COUNT(*) > 5) AND (COUNT(*) < 10)) ORDER BY COUNT(*) ASC NULLS LAST"
)
];

let roundtrip = |sql: &str| -> Result<String> {
Expand Down

0 comments on commit 3667b49

Please sign in to comment.