Skip to content

Commit

Permalink
refactor(optimizer): cleanup LogicalAgg::prune_col (risingwavelabs#3663)
Browse files Browse the repository at this point in the history
* refactor(optimizer): cleanup LogicalAgg::prune_col

* planner test of prune-filter bug

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
xiangjinwu and mergify[bot] authored Jul 7, 2022
1 parent 82ef7ec commit 078c2f6
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 67 deletions.
80 changes: 28 additions & 52 deletions src/frontend/src/optimizer/plan_node/logical_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -480,11 +480,6 @@ impl LogicalAgg {
self.o2i_col_mapping().inverse()
}

/// Returns the column-index mapping from input column index to output column index.
///
/// The mapping is obtained by inverting `o2i_col_mapping()` through
/// `inverse_with_required`, with `required` forwarded unchanged to that call.
pub fn i2o_col_mapping_with_required_out(&self, required: &FixedBitSet) -> ColIndexMapping {
self.o2i_col_mapping().inverse_with_required(required)
}

fn derive_schema(
input: &Schema,
group_key: &[usize],
Expand Down Expand Up @@ -551,25 +546,15 @@ impl LogicalAgg {
/// Consumes the node and returns its parts as `(agg_calls, group_key, input)`.
pub fn decompose(self) -> (Vec<PlanAggCall>, Vec<usize>, PlanRef) {
    let calls = self.agg_calls;
    let keys = self.group_key;
    let child = self.input;
    (calls, keys, child)
}
}

impl PlanTreeNodeUnary for LogicalAgg {
fn input(&self) -> PlanRef {
self.input.clone()
}

/// Builds a new node with copies of this node's aggregate calls and group key,
/// placed on top of the given `input`.
fn clone_with_input(&self, input: PlanRef) -> Self {
    let calls = self.agg_calls().to_vec();
    let keys = self.group_key().to_vec();
    Self::new(calls, keys, input)
}

#[must_use]
fn rewrite_with_input(
fn rewrite_with_input_agg(
&self,
input: PlanRef,
agg_calls: &[PlanAggCall],
mut input_col_change: ColIndexMapping,
) -> (Self, ColIndexMapping) {
let agg_calls = self
.agg_calls
) -> Self {
let agg_calls = agg_calls
.iter()
.cloned()
.map(|mut agg_call| {
Expand All @@ -586,7 +571,26 @@ impl PlanTreeNodeUnary for LogicalAgg {
.cloned()
.map(|key| input_col_change.map(key))
.collect();
let agg = Self::new(agg_calls, group_key, input);
Self::new(agg_calls, group_key, input)
}
}

impl PlanTreeNodeUnary for LogicalAgg {
fn input(&self) -> PlanRef {
self.input.clone()
}

/// Builds a new node with copies of this node's aggregate calls and group key,
/// placed on top of the given `input`.
fn clone_with_input(&self, input: PlanRef) -> Self {
    let calls = self.agg_calls().to_vec();
    let keys = self.group_key().to_vec();
    Self::new(calls, keys, input)
}

#[must_use]
fn rewrite_with_input(
&self,
input: PlanRef,
input_col_change: ColIndexMapping,
) -> (Self, ColIndexMapping) {
let agg = self.rewrite_with_input_agg(input, &self.agg_calls, input_col_change);
// changing the input column indices does not change the output column indices
let out_col_change = ColIndexMapping::identity(agg.schema().len());
(agg, out_col_change)
Expand All @@ -606,14 +610,6 @@ impl fmt::Display for LogicalAgg {

impl ColPrunable for LogicalAgg {
fn prune_col(&self, required_cols: &[usize]) -> PlanRef {
let upstream_required_cols = {
let mapping = self.o2i_col_mapping();
FixedBitSet::from_iter(
required_cols
.iter()
.filter_map(|&output_idx| mapping.try_map(output_idx)),
)
};
let group_key_required_cols = FixedBitSet::from_iter(self.group_key.iter().copied());

let (agg_call_required_cols, agg_calls) = {
Expand All @@ -637,38 +633,18 @@ impl ColPrunable for LogicalAgg {
};

let input_required_cols = {
let mut tmp: FixedBitSet = upstream_required_cols;
let mut tmp = FixedBitSet::with_capacity(self.input.schema().len());
tmp.union_with(&group_key_required_cols);
tmp.union_with(&agg_call_required_cols);
tmp.ones().collect_vec()
};
let mapping = ColIndexMapping::with_remaining_columns(
let input_col_change = ColIndexMapping::with_remaining_columns(
&input_required_cols,
self.input().schema().len(),
);
let agg = {
let agg_calls = agg_calls
.iter()
.cloned()
.map(|mut agg_call| {
agg_call
.inputs
.iter_mut()
.for_each(|i| *i = InputRef::new(mapping.map(i.index()), i.return_type()));
agg_call
})
.collect();
let group_key = self
.group_key
.iter()
.cloned()
.map(|key| mapping.map(key))
.collect();
LogicalAgg::new(
agg_calls,
group_key,
self.input.prune_col(&input_required_cols),
)
let input = self.input.prune_col(&input_required_cols);
self.rewrite_with_input_agg(input, &agg_calls, input_col_change)
};
let new_output_cols = {
// group keys are never pruned or even re-ordered in the current impl
Expand Down
15 changes: 0 additions & 15 deletions src/frontend/src/utils/column_index_mapping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,21 +218,6 @@ impl ColIndexMapping {
Self::with_target_size(map, self.source_size())
}

/// Inverts the mapping, using `required` to break ties between sources.
///
/// Pairs are visited in the order yielded by `mapping_pairs()`. When more than one source
/// maps to the same target, a source already recorded for that target is kept if it is
/// contained in `required`; otherwise the source visited later replaces it.
#[must_use]
pub fn inverse_with_required(&self, required: &FixedBitSet) -> Self {
    let mut inverted = vec![None; self.target_size()];
    for (src, dst) in self.mapping_pairs() {
        // Only an already-recorded *required* source is protected from being overwritten.
        let keep_existing = matches!(inverted[dst], Some(prev) if required.contains(prev));
        if !keep_existing {
            inverted[dst] = Some(src);
        }
    }
    Self::with_target_size(inverted, self.source_size())
}

/// return iter of (src, dst) order by src
pub fn mapping_pairs(&self) -> impl Iterator<Item = (usize, usize)> + '_ {
self.map
Expand Down
15 changes: 15 additions & 0 deletions src/frontend/test_runner/tests/testdata/agg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -481,3 +481,18 @@
create table t(a int, b int);
select abs(a) FILTER (WHERE a > 0) AS avga from t;
binder_error: 'Invalid input syntax: filter clause is only allowed in aggregation functions, but `abs` is not an aggregation function'
- sql: |
/* prune column before filter */
create table t(v1 int, v2 int);
with sub(a, b) as (select min(v1), sum(v2) filter (where v2 < 5) from t) select b from sub;
batch_plan: |
BatchSimpleAgg { aggs: [sum($0)] }
BatchExchange { order: [], dist: Single }
BatchSimpleAgg { aggs: [sum($0) filter(($0 < 5:Int32))] }
BatchScan { table: t, columns: [v2] }
stream_plan: |
StreamMaterialize { columns: [agg#0(hidden), b], pk_columns: [] }
StreamGlobalSimpleAgg { aggs: [sum($0), sum($1)] }
StreamExchange { dist: Single }
StreamLocalSimpleAgg { aggs: [count, sum($0) filter(($0 < 5:Int32))] }
StreamTableScan { table: t, columns: [v2, _row_id], pk_indices: [1] }

0 comments on commit 078c2f6

Please sign in to comment.