Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit da98e1e

Browse files
committedFeb 14, 2024·
Propagate aggregate equivalences
1 parent 7562bc6 commit da98e1e

File tree

2 files changed

+81
-44
lines changed

2 files changed

+81
-44
lines changed
 

‎src/transform/src/equivalence_propagation.rs

+80-43
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,23 @@ impl Analysis for EquivalencePropagation {
186186
}
187187
MirRelationExpr::Reduce {
188188
group_key,
189-
aggregates: _,
189+
aggregates,
190190
..
191191
} => {
192+
// All of the output columns are the result of applying expressions to input columns,
193+
// and we require a game plan to emerge with any equivalences about the output columns.
194+
//
195+
// Our plan is to (pretend to) extend the input with the expressions in `group_key`,
196+
// and then apply a projection that retains only these columns, preserving any equivalences
197+
// among the grouping columns.
198+
// Additionally, for each "equality preserving" aggregation (MIN, MAX, ANY, ALL) we will
199+
// append their expression to those of the grouping keys, project to the keys and the
200+
// aggregate expression, and permute it into place (as it may not follow the grouping keys).
201+
// We will finally stitch all of these classes together.
202+
//
203+
// Importantly, we handle the aggregate expressions in isolation, as equivalences that
204+
// may hold (e.g. `col1 = col2`) may not hold after aggregation (`MIN(col1)` vs. `MAX(col2)`).
205+
192206
let input_arity = depends.results::<Arity>().unwrap()[results.len() - 1];
193207
let mut equivalences = results.last().unwrap().clone();
194208
// Introduce keys column equivalences as a map, then project to them as a projection.
@@ -197,13 +211,30 @@ impl Analysis for EquivalencePropagation {
197211
.classes
198212
.push(vec![MirScalarExpr::Column(input_arity + pos), expr.clone()]);
199213
}
200-
// TODO: MIN, MAX, ANY, ALL aggregates pass through all certain properties of their columns.
201-
// They also pass through equivalences of them and other constant columns (e.g. key columns).
202-
// However, it is not correct to simply project onto these columns, as relationships amongst
203-
// aggregate columns may no longer be preserved. MAX(col) != MIN(col) even though col = col.
204-
// TODO: COUNT ensures a non-null value.
205-
equivalences.project(input_arity..(input_arity + group_key.len()));
206-
equivalences
214+
215+
// Accumulate the equivalence classes for the output of `Reduce`.
216+
let mut result = equivalences.clone();
217+
result.project(input_arity..(input_arity + group_key.len()));
218+
219+
// Introduce aggregate expressions as columns and equate the columns and expressions.
220+
for (pos, aggr) in aggregates.iter().enumerate() {
221+
if preserves_equivalences(&aggr.func) {
222+
let mut equivs = equivalences.clone();
223+
equivs.classes.push(vec![
224+
MirScalarExpr::Column(input_arity + group_key.len()),
225+
aggr.expr.clone(),
226+
]);
227+
equivs.project(input_arity..(input_arity + group_key.len() + 1));
228+
let permutation = (0..group_key.len())
229+
.chain(Some(group_key.len() + pos))
230+
.collect::<Vec<_>>();
231+
equivs.permute(&permutation[..]);
232+
// Fold the classes into the accumulated `result`.
233+
result.classes.extend(equivs.classes.drain(..));
234+
}
235+
}
236+
237+
result
207238
}
208239
MirRelationExpr::TopK { .. } => results.last().unwrap().clone(),
209240
MirRelationExpr::Negate { .. } => results.last().unwrap().clone(),
@@ -775,6 +806,8 @@ impl EquivalenceClasses {
775806
}
776807

777808
/// Permutes each expression, looking up each column reference in `permutation` and replacing with what it finds.
809+
///
810+
/// This is simpler than `project`, in that it only relabels column references.
778811
fn permute(&mut self, permutation: &[usize]) {
779812
for class in self.classes.iter_mut() {
780813
for expr in class.iter_mut() {
@@ -933,38 +966,42 @@ impl EquivalenceClasses {
933966
}
934967
}
935968

936-
// fn preserves_equivalences(func: &AggregateFunc) -> bool {
937-
// match func {
938-
// AggregateFunc::MaxInt16
939-
// | AggregateFunc::MaxInt32
940-
// | AggregateFunc::MaxInt64
941-
// | AggregateFunc::MaxUInt16
942-
// | AggregateFunc::MaxUInt32
943-
// | AggregateFunc::MaxUInt64
944-
// | AggregateFunc::MaxMzTimestamp
945-
// | AggregateFunc::MaxFloat32
946-
// | AggregateFunc::MaxFloat64
947-
// | AggregateFunc::MaxBool
948-
// | AggregateFunc::MaxString
949-
// | AggregateFunc::MaxDate
950-
// | AggregateFunc::MaxTimestamp
951-
// | AggregateFunc::MaxTimestampTz
952-
// | AggregateFunc::MinInt16
953-
// | AggregateFunc::MinInt32
954-
// | AggregateFunc::MinInt64
955-
// | AggregateFunc::MinUInt16
956-
// | AggregateFunc::MinUInt32
957-
// | AggregateFunc::MinUInt64
958-
// | AggregateFunc::MinMzTimestamp
959-
// | AggregateFunc::MinFloat32
960-
// | AggregateFunc::MinFloat64
961-
// | AggregateFunc::MinBool
962-
// | AggregateFunc::MinString
963-
// | AggregateFunc::MinDate
964-
// | AggregateFunc::MinTimestamp
965-
// | AggregateFunc::MinTimestampTz
966-
// | AggregateFunc::Any
967-
// | AggregateFunc::All => true,
968-
// _ => false,
969-
// }
970-
// }
969+
use mz_expr::AggregateFunc;
970+
971+
/// True iff the aggregate function produces an instance of its input expression,
972+
/// and consequently preserves any equivalence involving that expression.
973+
fn preserves_equivalences(func: &AggregateFunc) -> bool {
974+
match func {
975+
AggregateFunc::MaxInt16
976+
| AggregateFunc::MaxInt32
977+
| AggregateFunc::MaxInt64
978+
| AggregateFunc::MaxUInt16
979+
| AggregateFunc::MaxUInt32
980+
| AggregateFunc::MaxUInt64
981+
| AggregateFunc::MaxMzTimestamp
982+
| AggregateFunc::MaxFloat32
983+
| AggregateFunc::MaxFloat64
984+
| AggregateFunc::MaxBool
985+
| AggregateFunc::MaxString
986+
| AggregateFunc::MaxDate
987+
| AggregateFunc::MaxTimestamp
988+
| AggregateFunc::MaxTimestampTz
989+
| AggregateFunc::MinInt16
990+
| AggregateFunc::MinInt32
991+
| AggregateFunc::MinInt64
992+
| AggregateFunc::MinUInt16
993+
| AggregateFunc::MinUInt32
994+
| AggregateFunc::MinUInt64
995+
| AggregateFunc::MinMzTimestamp
996+
| AggregateFunc::MinFloat32
997+
| AggregateFunc::MinFloat64
998+
| AggregateFunc::MinBool
999+
| AggregateFunc::MinString
1000+
| AggregateFunc::MinDate
1001+
| AggregateFunc::MinTimestamp
1002+
| AggregateFunc::MinTimestampTz
1003+
| AggregateFunc::Any
1004+
| AggregateFunc::All => true,
1005+
_ => false,
1006+
}
1007+
}

‎test/testdrive/mz-arrangement-sharing.td

+1-1
Original file line numberDiff line numberDiff line change
@@ -705,7 +705,7 @@ ReduceMinsMaxes
705705
"Arrange ReduceMinsMaxes" 1
706706
"ArrangeAccumulable [val: empty]" 3
707707
"ArrangeBy[[Column(0), Column(2)]]" 1
708-
"ArrangeBy[[Column(0)]] [val: empty]" 2
708+
"ArrangeBy[[Column(0)]] [val: empty]" 1
709709
"ArrangeBy[[Column(0)]]-errors" 20
710710
"ArrangeBy[[Column(0)]]" 39
711711
"ArrangeBy[[Column(1), Column(3)]]" 2

0 commit comments

Comments
 (0)
Please sign in to comment.