Skip to content

Commit a622b02

Browse files
Support unparsing UNION for distinct results (#75)
UPSTREAM NOTE: This change was submitted upstream as apache#15814 and will be included in DataFusion (DF) 48.
1 parent 391fb71 commit a622b02

File tree

3 files changed

+44
-1
lines changed

3 files changed

+44
-1
lines changed

datafusion/sql/src/unparser/ast.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ pub struct QueryBuilder {
3131
fetch: Option<ast::Fetch>,
3232
locks: Vec<ast::LockClause>,
3333
for_clause: Option<ast::ForClause>,
34+
// If true, we need to unparse LogicalPlan::Union as a SQL `UNION` rather than a `UNION ALL`.
35+
distinct_union: bool,
3436
}
3537

3638
#[allow(dead_code)]
@@ -77,6 +79,13 @@ impl QueryBuilder {
7779
self.for_clause = value;
7880
self
7981
}
82+
pub fn distinct_union(&mut self) -> &mut Self {
83+
self.distinct_union = true;
84+
self
85+
}
86+
pub fn is_distinct_union(&self) -> bool {
87+
self.distinct_union
88+
}
8089
pub fn build(&self) -> Result<ast::Query, BuilderError> {
8190
let order_by = if self.order_by.is_empty() {
8291
None
@@ -115,6 +124,7 @@ impl QueryBuilder {
115124
fetch: Default::default(),
116125
locks: Default::default(),
117126
for_clause: Default::default(),
127+
distinct_union: false,
118128
}
119129
}
120130
}

datafusion/sql/src/unparser/plan.rs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,23 @@ impl Unparser<'_> {
594594
false,
595595
);
596596
}
597+
598+
// If this distinct is the parent of a Union and we're in a query context,
599+
// then we need to unparse as a `UNION` rather than a `UNION ALL`.
600+
if let Distinct::All(input) = distinct {
601+
if matches!(input.as_ref(), LogicalPlan::Union(_)) {
602+
if let Some(query_mut) = query.as_mut() {
603+
query_mut.distinct_union();
604+
return self.select_to_sql_recursively(
605+
input.as_ref(),
606+
query,
607+
select,
608+
relation,
609+
);
610+
}
611+
}
612+
}
613+
597614
let (select_distinct, input) = match distinct {
598615
Distinct::All(input) => (ast::Distinct::Distinct, input.as_ref()),
599616
Distinct::On(on) => {
@@ -785,14 +802,23 @@ impl Unparser<'_> {
785802
return internal_err!("UNION operator requires at least 2 inputs");
786803
}
787804

805+
let set_quantifier =
806+
if query.as_ref().is_some_and(|q| q.is_distinct_union()) {
807+
// Setting the SetQuantifier to None will unparse as a `UNION`
808+
// rather than a `UNION ALL`.
809+
ast::SetQuantifier::None
810+
} else {
811+
ast::SetQuantifier::All
812+
};
813+
788814
// Build the union expression tree bottom-up by reversing the order
789815
// note that we are also swapping left and right inputs because of the rev
790816
let union_expr = input_exprs
791817
.into_iter()
792818
.rev()
793819
.reduce(|a, b| SetExpr::SetOperation {
794820
op: ast::SetOperator::Union,
795-
set_quantifier: ast::SetQuantifier::All,
821+
set_quantifier,
796822
left: Box::new(b),
797823
right: Box::new(a),
798824
})

datafusion/sql/tests/cases/plan_to_sql.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,13 @@ fn roundtrip_statement() -> Result<()> {
166166
UNION ALL
167167
SELECT j3_string AS col1, j3_id AS id FROM j3
168168
) AS subquery GROUP BY col1, id ORDER BY col1 ASC, id ASC"#,
169+
r#"SELECT col1, id FROM (
170+
SELECT j1_string AS col1, j1_id AS id FROM j1
171+
UNION
172+
SELECT j2_string AS col1, j2_id AS id FROM j2
173+
UNION
174+
SELECT j3_string AS col1, j3_id AS id FROM j3
175+
) AS subquery ORDER BY col1 ASC, id ASC"#,
169176
"SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
170177
last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
171178
first_name from person",

0 commit comments

Comments
 (0)