Skip to content

Commit 4c52eb7

Browse files
committed
fix: use indexmap for deterministic output
1 parent 7f9253b commit 4c52eb7

File tree

1 file changed

+56
-46
lines changed

1 file changed

+56
-46
lines changed

datafusion/optimizer/src/decorrelate_general.rs

Lines changed: 56 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ use datafusion_expr::expr::{self, Exists, InSubquery};
3434
use datafusion_expr::select_expr::SelectExpr;
3535
use datafusion_expr::utils::conjunction;
3636
use datafusion_expr::{
37-
binary_expr, col, expr_fn, lit, not, Aggregate, BinaryExpr, DependentJoin,
38-
EmptyRelation, Expr, ExprSchemable, Filter, JoinType, LogicalPlan,
37+
binary_expr, case, col, expr_fn, lit, not, when, Aggregate, BinaryExpr,
38+
DependentJoin, EmptyRelation, Expr, ExprSchemable, Filter, JoinType, LogicalPlan,
3939
LogicalPlanBuilder, Operator, Projection,
4040
};
4141

@@ -70,7 +70,7 @@ pub struct DependentJoinDecorrelator {
7070
// top-most subquery decorrelation has depth 1 and so on
7171
depth: usize,
7272
// hashmap of correlated column by depth
73-
correlated_map: HashMap<usize, Vec<CorrelatedColumnInfo>>,
73+
correlated_map: IndexMap<usize, Vec<CorrelatedColumnInfo>>,
7474
// check if we have to replace any COUNT aggregates into "CASE WHEN X IS NULL THEN 0 ELSE COUNT END"
7575
// store a mapping between an expr and its original index in the logical plan output
7676
replacement_map: IndexMap<String, Expr>,
@@ -101,21 +101,26 @@ fn natural_join(
101101
)
102102
})
103103
.collect();
104+
let require_dedup = !join_exprs.is_empty();
104105

105106
builder = builder.join(
106107
right,
107108
join_type,
108109
(Vec::<Column>::new(), Vec::<Column>::new()),
109110
conjunction(join_exprs).or(Some(lit(true))),
110111
)?;
111-
let remain_cols = builder.schema().columns().into_iter().filter_map(|c| {
112-
if exclude_cols.contains(&c) {
113-
None
114-
} else {
115-
Some(Expr::Column(c))
116-
}
117-
});
118-
builder.project(remain_cols)
112+
if require_dedup {
113+
let remain_cols = builder.schema().columns().into_iter().filter_map(|c| {
114+
if exclude_cols.contains(&c) {
115+
None
116+
} else {
117+
Some(Expr::Column(c))
118+
}
119+
});
120+
builder.project(remain_cols)
121+
} else {
122+
Ok(builder)
123+
}
119124
}
120125

121126
impl DependentJoinDecorrelator {
@@ -156,7 +161,7 @@ impl DependentJoinDecorrelator {
156161
domains: IndexSet::new(),
157162
delim_types: vec![],
158163
is_initial: true,
159-
correlated_map: HashMap::new(),
164+
correlated_map: IndexMap::new(),
160165
replacement_map: IndexMap::new(),
161166
any_join: true,
162167
delim_scan_id: 0,
@@ -165,7 +170,7 @@ impl DependentJoinDecorrelator {
165170
}
166171
fn new(
167172
correlated_columns: &Vec<(usize, Column, DataType)>,
168-
parent_correlated_columns: &HashMap<usize, Vec<CorrelatedColumnInfo>>,
173+
parent_correlated_columns: &IndexMap<usize, Vec<CorrelatedColumnInfo>>,
169174
is_initial: bool,
170175
any_join: bool,
171176
delim_scan_id: usize,
@@ -271,10 +276,10 @@ impl DependentJoinDecorrelator {
271276
// the DELIM join has happened somewhere
272277
// and the new correlated columns now have new names
273278
// using the delim_join side's name
274-
Self::rewrite_correlated_columns(
275-
&mut correlated_columns,
276-
self.delim_scan_relation_name(),
277-
);
279+
// Self::rewrite_correlated_columns(
280+
// &mut correlated_columns,
281+
// self.delim_scan_relation_name(),
282+
// );
278283
new_left
279284
} else {
280285
self.init(node);
@@ -767,14 +772,9 @@ impl DependentJoinDecorrelator {
767772
// Transformed::yes(Expr::Literal(ScalarValue::Int64(Some(0))))
768773
if func.name() == "count" {
769774
let expr_name = agg_expr.to_string();
770-
let expr_to_replace = Expr::Case(expr::Case {
771-
expr: None,
772-
when_then_expr: vec![(
773-
Box::new(agg_expr.clone().is_null()),
774-
Box::new(lit(0)),
775-
)],
776-
else_expr: Some(Box::new(agg_expr.clone())),
777-
});
775+
let expr_to_replace =
776+
when(agg_expr.clone().is_null(), lit(0))
777+
.otherwise(agg_expr.clone())?;
778778
self.replacement_map
779779
.insert(expr_name, expr_to_replace);
780780
continue;
@@ -1769,6 +1769,7 @@ mod tests {
17691769
use super::DependentJoinRewriter;
17701770

17711771
use crate::test::{test_table_scan_with_name, test_table_with_columns};
1772+
use crate::Optimizer;
17721773
use crate::{
17731774
assert_optimized_plan_eq_display_indent_snapshot,
17741775
decorrelate_general::Decorrelation, OptimizerConfig, OptimizerContext,
@@ -1785,6 +1786,15 @@ mod tests {
17851786
use datafusion_functions_aggregate::{count::count, sum::sum};
17861787
use insta::assert_snapshot;
17871788
use std::sync::Arc;
1789+
fn print_graphviz(plan: &LogicalPlan) {
1790+
let rule: Arc<dyn OptimizerRule + Send + Sync> = Arc::new(Decorrelation::new());
1791+
let optimizer = Optimizer::with_rules(vec![rule]);
1792+
let optimized_plan = optimizer
1793+
.optimize(plan.clone(), &OptimizerContext::new(), |_, _| {})
1794+
.expect("failed to optimize plan");
1795+
let formatted_plan = optimized_plan.display_indent_schema();
1796+
println!("{}", optimized_plan.display_graphviz());
1797+
}
17881798

17891799
macro_rules! assert_decorrelate {
17901800
(
@@ -2808,6 +2818,8 @@ mod tests {
28082818
.and(scalar_subquery(scalar_sq_level1).eq(col("outer_table.a"))),
28092819
)?
28102820
.build()?;
2821+
print_graphviz(&plan);
2822+
28112823
// Projection: outer_table.a, outer_table.b, outer_table.c
28122824
// Filter: outer_table.a > Int32(1) AND __scalar_sq_2.output = outer_table.a
28132825
// DependentJoin on [outer_table.a lvl 2, outer_table.c lvl 1] with expr (<subquery>) depth 1
@@ -2832,27 +2844,25 @@ mod tests {
28322844
Projection: count(inner_table_lv1.a), delim_scan_4.outer_table_c, delim_scan_4.outer_table_a [count(inner_table_lv1.a):Int64, outer_table_c:UInt32;N, outer_table_a:UInt32;N]
28332845
Aggregate: groupBy=[[delim_scan_4.outer_table_a, delim_scan_4.outer_table_c]], aggr=[[count(inner_table_lv1.a)]] [outer_table_a:UInt32;N, outer_table_c:UInt32;N, count(inner_table_lv1.a):Int64]
28342846
Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c, delim_scan_4.outer_table_a, delim_scan_4.outer_table_c [a:UInt32, b:UInt32, c:UInt32, outer_table_a:UInt32;N, outer_table_c:UInt32;N]
2835-
Filter: inner_table_lv1.c = delim_scan_4.outer_table_c AND __scalar_sq_1.output = Int32(1) [a:UInt32, b:UInt32, c:UInt32, outer_table_a:UInt32;N, outer_table_c:UInt32;N, CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END:Int32;N, outer_table_c:UInt32;N, outer_table_a:UInt32;N, delim_scan_2_b:UInt32;N, __scalar_sq_1.output:Int32;N]
2836-
Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c, delim_scan_2.outer_table_a, delim_scan_2.outer_table_c, CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END, delim_scan_4.outer_table_c, delim_scan_4.outer_table_a, delim_scan_4.delim_scan_2_b, CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END AS __scalar_sq_1.output [a:UInt32, b:UInt32, c:UInt32, outer_table_a:UInt32;N, outer_table_c:UInt32;N, CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END:Int32;N, outer_table_c:UInt32;N, outer_table_a:UInt32;N, delim_scan_2_b:UInt32;N, __scalar_sq_1.output:Int32;N]
2837-
Left Join: Filter: inner_table_lv1.b IS NOT DISTINCT FROM delim_scan_4.b [a:UInt32, b:UInt32, c:UInt32, outer_table_a:UInt32;N, outer_table_c:UInt32;N, CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END:Int32;N, outer_table_c:UInt32;N, outer_table_a:UInt32;N, delim_scan_2_b:UInt32;N]
2838-
Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c, delim_scan_2.outer_table_a, delim_scan_2.outer_table_c [a:UInt32, b:UInt32, c:UInt32, outer_table_a:UInt32;N, outer_table_c:UInt32;N]
2839-
Inner Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32, outer_table_a:UInt32;N, outer_table_c:UInt32;N]
2840-
TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32]
2841-
SubqueryAlias: delim_scan_2 [outer_table_a:UInt32;N, outer_table_c:UInt32;N]
2842-
DelimGet: outer_table.a, outer_table.c [outer_table_a:UInt32;N, outer_table_c:UInt32;N]
2843-
Projection: CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END, delim_scan_4.outer_table_c, delim_scan_4.outer_table_a, delim_scan_4.delim_scan_2_b [CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END:Int32, outer_table_c:UInt32;N, outer_table_a:UInt32;N, delim_scan_2_b:UInt32;N]
2844-
Projection: count(inner_table_lv2.a), delim_scan_4.outer_table_c, delim_scan_4.outer_table_a, delim_scan_4.delim_scan_2_b [count(inner_table_lv2.a):Int64, outer_table_c:UInt32;N, outer_table_a:UInt32;N, delim_scan_2_b:UInt32;N]
2845-
Inner Join: Filter: delim_scan_4.delim_scan_2_b IS NOT DISTINCT FROM delim_scan_3.delim_scan_2_b AND delim_scan_4.outer_table_a IS NOT DISTINCT FROM delim_scan_3.outer_table_a AND delim_scan_4.outer_table_c IS NOT DISTINCT FROM delim_scan_3.outer_table_c [count(inner_table_lv2.a):Int64, outer_table_c:UInt32;N, outer_table_a:UInt32;N, delim_scan_2_b:UInt32;N, delim_scan_2_b:UInt32;N, outer_table_a:UInt32;N, outer_table_c:UInt32;N]
2846-
Projection: count(inner_table_lv2.a), delim_scan_4.outer_table_c, delim_scan_4.outer_table_a, delim_scan_4.delim_scan_2_b [count(inner_table_lv2.a):Int64, outer_table_c:UInt32;N, outer_table_a:UInt32;N, delim_scan_2_b:UInt32;N]
2847-
Aggregate: groupBy=[[delim_scan_4.delim_scan_2_b, delim_scan_4.outer_table_a, delim_scan_4.outer_table_c]], aggr=[[count(inner_table_lv2.a)]] [delim_scan_2_b:UInt32;N, outer_table_a:UInt32;N, outer_table_c:UInt32;N, count(inner_table_lv2.a):Int64]
2848-
Filter: inner_table_lv2.a = delim_scan_4.outer_table_a AND inner_table_lv2.b = outer_ref(inner_table_lv1.b) [a:UInt32, b:UInt32, c:UInt32, delim_scan_2_b:UInt32;N, outer_table_a:UInt32;N, outer_table_c:UInt32;N]
2849-
Projection: inner_table_lv2.a, inner_table_lv2.b, inner_table_lv2.c, delim_scan_4.delim_scan_2_b, delim_scan_4.outer_table_a, delim_scan_4.outer_table_c [a:UInt32, b:UInt32, c:UInt32, delim_scan_2_b:UInt32;N, outer_table_a:UInt32;N, outer_table_c:UInt32;N]
2850-
Inner Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32, delim_scan_2_b:UInt32;N, outer_table_a:UInt32;N, outer_table_c:UInt32;N]
2851-
TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32]
2852-
SubqueryAlias: delim_scan_4 [delim_scan_2_b:UInt32;N, outer_table_a:UInt32;N, outer_table_c:UInt32;N]
2853-
DelimGet: delim_scan_2.b, outer_table.a, outer_table.c [delim_scan_2_b:UInt32;N, outer_table_a:UInt32;N, outer_table_c:UInt32;N]
2854-
SubqueryAlias: delim_scan_3 [delim_scan_2_b:UInt32;N, outer_table_a:UInt32;N, outer_table_c:UInt32;N]
2855-
DelimGet: delim_scan_2.b, outer_table.a, outer_table.c [delim_scan_2_b:UInt32;N, outer_table_a:UInt32;N, outer_table_c:UInt32;N]
2847+
Filter: inner_table_lv1.c = delim_scan_4.outer_table_c AND __scalar_sq_1.output = Int32(1) [a:UInt32, b:UInt32, c:UInt32, outer_table_a:UInt32;N, outer_table_c:UInt32;N, CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END:Int32;N, outer_table_c:UInt32;N, outer_table_a:UInt32;N, inner_table_lv1_b:UInt32;N, __scalar_sq_1.output:Int32;N]
2848+
Projection: inner_table_lv1.a, inner_table_lv1.b, inner_table_lv1.c, delim_scan_2.outer_table_a, delim_scan_2.outer_table_c, CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END, delim_scan_4.outer_table_c, delim_scan_4.outer_table_a, delim_scan_4.inner_table_lv1_b, CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END AS __scalar_sq_1.output [a:UInt32, b:UInt32, c:UInt32, outer_table_a:UInt32;N, outer_table_c:UInt32;N, CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END:Int32;N, outer_table_c:UInt32;N, outer_table_a:UInt32;N, inner_table_lv1_b:UInt32;N, __scalar_sq_1.output:Int32;N]
2849+
Left Join: Filter: inner_table_lv1.b IS NOT DISTINCT FROM delim_scan_4.b [a:UInt32, b:UInt32, c:UInt32, outer_table_a:UInt32;N, outer_table_c:UInt32;N, CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END:Int32;N, outer_table_c:UInt32;N, outer_table_a:UInt32;N, inner_table_lv1_b:UInt32;N]
2850+
Inner Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32, outer_table_a:UInt32;N, outer_table_c:UInt32;N]
2851+
TableScan: inner_table_lv1 [a:UInt32, b:UInt32, c:UInt32]
2852+
SubqueryAlias: delim_scan_2 [outer_table_a:UInt32;N, outer_table_c:UInt32;N]
2853+
DelimGet: outer_table.a, outer_table.c [outer_table_a:UInt32;N, outer_table_c:UInt32;N]
2854+
Projection: CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END, delim_scan_4.outer_table_c, delim_scan_4.outer_table_a, delim_scan_4.inner_table_lv1_b [CASE WHEN count(inner_table_lv2.a) IS NULL THEN Int32(0) ELSE count(inner_table_lv2.a) END:Int32, outer_table_c:UInt32;N, outer_table_a:UInt32;N, inner_table_lv1_b:UInt32;N]
2855+
Projection: count(inner_table_lv2.a), delim_scan_4.outer_table_c, delim_scan_4.outer_table_a, delim_scan_4.inner_table_lv1_b [count(inner_table_lv2.a):Int64, outer_table_c:UInt32;N, outer_table_a:UInt32;N, inner_table_lv1_b:UInt32;N]
2856+
Inner Join: Filter: delim_scan_4.inner_table_lv1_b IS NOT DISTINCT FROM delim_scan_3.inner_table_lv1_b AND delim_scan_4.outer_table_a IS NOT DISTINCT FROM delim_scan_3.outer_table_a AND delim_scan_4.outer_table_c IS NOT DISTINCT FROM delim_scan_3.outer_table_c [count(inner_table_lv2.a):Int64, outer_table_c:UInt32;N, outer_table_a:UInt32;N, inner_table_lv1_b:UInt32;N, inner_table_lv1_b:UInt32;N, outer_table_a:UInt32;N, outer_table_c:UInt32;N]
2857+
Projection: count(inner_table_lv2.a), delim_scan_4.outer_table_c, delim_scan_4.outer_table_a, delim_scan_4.inner_table_lv1_b [count(inner_table_lv2.a):Int64, outer_table_c:UInt32;N, outer_table_a:UInt32;N, inner_table_lv1_b:UInt32;N]
2858+
Aggregate: groupBy=[[delim_scan_4.inner_table_lv1_b, delim_scan_4.outer_table_a, delim_scan_4.outer_table_c]], aggr=[[count(inner_table_lv2.a)]] [inner_table_lv1_b:UInt32;N, outer_table_a:UInt32;N, outer_table_c:UInt32;N, count(inner_table_lv2.a):Int64]
2859+
Filter: inner_table_lv2.a = delim_scan_4.outer_table_a AND inner_table_lv2.b = delim_scan_4.inner_table_lv1_b [a:UInt32, b:UInt32, c:UInt32, inner_table_lv1_b:UInt32;N, outer_table_a:UInt32;N, outer_table_c:UInt32;N]
2860+
Inner Join: Filter: Boolean(true) [a:UInt32, b:UInt32, c:UInt32, inner_table_lv1_b:UInt32;N, outer_table_a:UInt32;N, outer_table_c:UInt32;N]
2861+
TableScan: inner_table_lv2 [a:UInt32, b:UInt32, c:UInt32]
2862+
SubqueryAlias: delim_scan_4 [inner_table_lv1_b:UInt32;N, outer_table_a:UInt32;N, outer_table_c:UInt32;N]
2863+
DelimGet: inner_table_lv1.b, outer_table.a, outer_table.c [inner_table_lv1_b:UInt32;N, outer_table_a:UInt32;N, outer_table_c:UInt32;N]
2864+
SubqueryAlias: delim_scan_3 [inner_table_lv1_b:UInt32;N, outer_table_a:UInt32;N, outer_table_c:UInt32;N]
2865+
DelimGet: inner_table_lv1.b, outer_table.a, outer_table.c [inner_table_lv1_b:UInt32;N, outer_table_a:UInt32;N, outer_table_c:UInt32;N]
28562866
SubqueryAlias: delim_scan_1 [outer_table_a:UInt32;N, outer_table_c:UInt32;N]
28572867
DelimGet: outer_table.a, outer_table.c [outer_table_a:UInt32;N, outer_table_c:UInt32;N]
28582868
");

0 commit comments

Comments
 (0)