Skip to content

Commit

Permalink
Eliminate some MRE::arity() calls
Browse files Browse the repository at this point in the history
  • Loading branch information
ggevay committed Dec 20, 2024
1 parent bcde0e3 commit a26a94e
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 9 deletions.
5 changes: 3 additions & 2 deletions src/compute-types/src/plan/lowering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -548,8 +548,6 @@ impl Context {
equivalences,
implementation,
} => {
let input_mapper = JoinInputMapper::new(inputs);

// Plan each of the join inputs independently.
// The `plans` get surfaced upwards, and the `input_keys` should
// be used as part of join planning / to validate the existing
Expand All @@ -564,6 +562,9 @@ impl Context {
input_keys.push(keys);
}

let input_mapper =
JoinInputMapper::new_from_input_arities(input_arities.iter().copied());

// Extract temporal predicates as joins cannot currently absorb them.
let (plan, missing) = match implementation {
IndexedFilter(_coll_id, _idx_id, key, _val) => {
Expand Down
2 changes: 2 additions & 0 deletions src/expr/src/linear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,8 @@ impl MapFilterProject {
let (mfp, expr) = Self::extract_from_expression(input);
(mfp.project(outputs.iter().cloned()), expr)
}
// TODO: The recursion is quadratic in the number of Map/Filter/Project operators due to
// this call to `arity()`.
x => (Self::new(x.arity()), x),
}
}
Expand Down
9 changes: 4 additions & 5 deletions src/expr/src/relation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1798,22 +1798,21 @@ impl MirRelationExpr {
.unzip();
assert_eq!(keys_and_values.arity() - self.arity(), data.len());
self.let_in(id_gen, |_id_gen, get_keys| {
let get_keys_arity = get_keys.arity();
Ok(MirRelationExpr::join(
vec![
// all the missing keys (with count 1)
keys_and_values
.distinct_by((0..get_keys.arity()).collect())
.distinct_by((0..get_keys_arity).collect())
.negate()
.union(get_keys.clone().distinct()),
// join with keys to get the correct counts
get_keys.clone(),
],
(0..get_keys.arity())
.map(|i| vec![(0, i), (1, i)])
.collect(),
(0..get_keys_arity).map(|i| vec![(0, i), (1, i)]).collect(),
)
// get rid of the extra copies of columns from keys
.project((0..get_keys.arity()).collect())
.project((0..get_keys_arity).collect())
// This join is logically equivalent to
// `.map(<default_expr>)`, but using a join allows for
// potential predicate pushdown and elision in the
Expand Down
5 changes: 4 additions & 1 deletion src/sql/src/plan/lowering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1574,12 +1574,14 @@ impl HirScalarExpr {
let inner_arity = get_inner.arity();
let mut total_arity = inner_arity;
let mut join_inputs = vec![get_inner];
let mut join_input_arities = vec![inner_arity];
for (expr, subquery) in subqueries.into_iter() {
// Avoid lowering duplicated subqueries
if !subquery_map.contains_key(&expr) {
let subquery_arity = subquery.arity();
assert_eq!(subquery_arity, inner_arity + 1);
join_inputs.push(subquery);
join_input_arities.push(subquery_arity);
total_arity += subquery_arity;

// Column with the value of the subquery
Expand All @@ -1589,7 +1591,8 @@ impl HirScalarExpr {
// Each subquery projects all the columns of the outer context (distinct_inner)
// plus 1 column, containing the result of the subquery. Those columns must be
// joined with the outer/main relation (get_inner).
let input_mapper = mz_expr::JoinInputMapper::new(&join_inputs);
let input_mapper =
mz_expr::JoinInputMapper::new_from_input_arities(join_input_arities);
let equivalences = (0..inner_arity)
.map(|col| {
join_inputs
Expand Down
3 changes: 2 additions & 1 deletion src/transform/src/fusion/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ impl Project {
*outputs = outputs.iter().map(|i| outputs2[*i]).collect();
**input = inner.take_dangerous();
}
if outputs.iter().enumerate().all(|(a, b)| a == *b) && outputs.len() == input.arity() {
let input_arity = input.arity();
if outputs.iter().enumerate().all(|(a, b)| a == *b) && outputs.len() == input_arity {
*relation = input.take_dangerous();
}
}
Expand Down

0 comments on commit a26a94e

Please sign in to comment.