@@ -186,9 +186,23 @@ impl Analysis for EquivalencePropagation {
186
186
}
187
187
MirRelationExpr :: Reduce {
188
188
group_key,
189
- aggregates : _ ,
189
+ aggregates,
190
190
..
191
191
} => {
192
+ // All of the output columns are the result of applying expressions to input columns,
193
+ // and we require a game plan to emerge with any equivalences about the output columns.
194
+ //
195
+ // Our plan is to (pretend to) extend the input with the expressions in `group_key`,
196
+ // and then apply a projection that retains only these columns, preserving any equivalences
197
+ // among the grouping columns.
198
+ // Additionally, for each "equality preserving" aggregation (MIN, MAX, ANY, ALL) we will
199
+ // append their expression to those of the grouping keys, project to the keys and the
200
+ // aggregate expression, and permute it into place (as it may not follow the grouping keys).
201
+ // We will finally stitch all of these classes together.
202
+ //
203
+ // Importantly, we handle the aggregate expressions in isolation, as equivalences that
204
+ // may hold (e.g. `col1 = col2`) may not hold after aggregation (`MIN(col1)` vs. `MAX(col2)`).
205
+
192
206
let input_arity = depends. results :: < Arity > ( ) . unwrap ( ) [ results. len ( ) - 1 ] ;
193
207
let mut equivalences = results. last ( ) . unwrap ( ) . clone ( ) ;
194
208
// Introduce keys column equivalences as a map, then project to them as a projection.
@@ -197,13 +211,30 @@ impl Analysis for EquivalencePropagation {
197
211
. classes
198
212
. push ( vec ! [ MirScalarExpr :: Column ( input_arity + pos) , expr. clone( ) ] ) ;
199
213
}
200
- // TODO: MIN, MAX, ANY, ALL aggregates pass through all certain properties of their columns.
201
- // They also pass through equivalences of them and other constant columns (e.g. key columns).
202
- // However, it is not correct to simply project onto these columns, as relationships amongst
203
- // aggregate columns may no longer be preserved. MAX(col) != MIN(col) even though col = col.
204
- // TODO: COUNT ensures a non-null value.
205
- equivalences. project ( input_arity..( input_arity + group_key. len ( ) ) ) ;
206
- equivalences
214
+
215
+ // Accumulate the equivalence classes for the output of `Reduce`.
216
+ let mut result = equivalences. clone ( ) ;
217
+ result. project ( input_arity..( input_arity + group_key. len ( ) ) ) ;
218
+
219
+ // Introduce aggregate expressions as columns and equate the columns and expressions.
220
+ for ( pos, aggr) in aggregates. iter ( ) . enumerate ( ) {
221
+ if preserves_equivalences ( & aggr. func ) {
222
+ let mut equivs = equivalences. clone ( ) ;
223
+ equivs. classes . push ( vec ! [
224
+ MirScalarExpr :: Column ( input_arity + group_key. len( ) ) ,
225
+ aggr. expr. clone( ) ,
226
+ ] ) ;
227
+ equivs. project ( input_arity..( input_arity + group_key. len ( ) + 1 ) ) ;
228
+ let permutation = ( 0 ..group_key. len ( ) )
229
+ . chain ( Some ( group_key. len ( ) + pos) )
230
+ . collect :: < Vec < _ > > ( ) ;
231
+ equivs. permute ( & permutation[ ..] ) ;
232
+ // Fold the classes for this aggregate in to `result`.
233
+ result. classes . extend ( equivs. classes . drain ( ..) ) ;
234
+ }
235
+ }
236
+
237
+ result
207
238
}
208
239
MirRelationExpr :: TopK { .. } => results. last ( ) . unwrap ( ) . clone ( ) ,
209
240
MirRelationExpr :: Negate { .. } => results. last ( ) . unwrap ( ) . clone ( ) ,
@@ -775,6 +806,8 @@ impl EquivalenceClasses {
775
806
}
776
807
777
808
/// Permutes each expression, looking up each column reference in `permutation` and replacing with what it finds.
809
+ ///
810
+ /// This is simpler than `project`, in that it only relabels column references and it does not la
778
811
fn permute ( & mut self , permutation : & [ usize ] ) {
779
812
for class in self . classes . iter_mut ( ) {
780
813
for expr in class. iter_mut ( ) {
@@ -933,38 +966,42 @@ impl EquivalenceClasses {
933
966
}
934
967
}
935
968
936
- // fn preserves_equivalences(func: &AggregateFunc) -> bool {
937
- // match func {
938
- // AggregateFunc::MaxInt16
939
- // | AggregateFunc::MaxInt32
940
- // | AggregateFunc::MaxInt64
941
- // | AggregateFunc::MaxUInt16
942
- // | AggregateFunc::MaxUInt32
943
- // | AggregateFunc::MaxUInt64
944
- // | AggregateFunc::MaxMzTimestamp
945
- // | AggregateFunc::MaxFloat32
946
- // | AggregateFunc::MaxFloat64
947
- // | AggregateFunc::MaxBool
948
- // | AggregateFunc::MaxString
949
- // | AggregateFunc::MaxDate
950
- // | AggregateFunc::MaxTimestamp
951
- // | AggregateFunc::MaxTimestampTz
952
- // | AggregateFunc::MinInt16
953
- // | AggregateFunc::MinInt32
954
- // | AggregateFunc::MinInt64
955
- // | AggregateFunc::MinUInt16
956
- // | AggregateFunc::MinUInt32
957
- // | AggregateFunc::MinUInt64
958
- // | AggregateFunc::MinMzTimestamp
959
- // | AggregateFunc::MinFloat32
960
- // | AggregateFunc::MinFloat64
961
- // | AggregateFunc::MinBool
962
- // | AggregateFunc::MinString
963
- // | AggregateFunc::MinDate
964
- // | AggregateFunc::MinTimestamp
965
- // | AggregateFunc::MinTimestampTz
966
- // | AggregateFunc::Any
967
- // | AggregateFunc::All => true,
968
- // _ => false,
969
- // }
970
- // }
969
+ use mz_expr:: AggregateFunc ;
970
+
971
+ /// True iff the aggregate function produces an instance of its input expression,
972
+ /// and consequently preserves any equivalence involving that expression (the `Min*` and `Max*` variants, `Any`, and `All`).
973
+ fn preserves_equivalences ( func : & AggregateFunc ) -> bool {
974
+ match func {
975
+ AggregateFunc :: MaxInt16
976
+ | AggregateFunc :: MaxInt32
977
+ | AggregateFunc :: MaxInt64
978
+ | AggregateFunc :: MaxUInt16
979
+ | AggregateFunc :: MaxUInt32
980
+ | AggregateFunc :: MaxUInt64
981
+ | AggregateFunc :: MaxMzTimestamp
982
+ | AggregateFunc :: MaxFloat32
983
+ | AggregateFunc :: MaxFloat64
984
+ | AggregateFunc :: MaxBool
985
+ | AggregateFunc :: MaxString
986
+ | AggregateFunc :: MaxDate
987
+ | AggregateFunc :: MaxTimestamp
988
+ | AggregateFunc :: MaxTimestampTz
989
+ | AggregateFunc :: MinInt16
990
+ | AggregateFunc :: MinInt32
991
+ | AggregateFunc :: MinInt64
992
+ | AggregateFunc :: MinUInt16
993
+ | AggregateFunc :: MinUInt32
994
+ | AggregateFunc :: MinUInt64
995
+ | AggregateFunc :: MinMzTimestamp
996
+ | AggregateFunc :: MinFloat32
997
+ | AggregateFunc :: MinFloat64
998
+ | AggregateFunc :: MinBool
999
+ | AggregateFunc :: MinString
1000
+ | AggregateFunc :: MinDate
1001
+ | AggregateFunc :: MinTimestamp
1002
+ | AggregateFunc :: MinTimestampTz
1003
+ | AggregateFunc :: Any
1004
+ | AggregateFunc :: All => true ,
1005
+ _ => false , // every aggregate function not listed above is reported as not preserving equivalences
1006
+ }
1007
+ }
0 commit comments