feat(optimizer): Implement ProjectJoin rule (#4385)
* success

* success

* fix merge main
jon-chuang authored Aug 3, 2022
1 parent 072b6e0 commit ba66d65
Showing 8 changed files with 242 additions and 248 deletions.
3 changes: 3 additions & 0 deletions src/frontend/src/optimizer/mod.rs
@@ -264,6 +264,9 @@ impl PlanRoot {
                 // merge should be applied before eliminate
                 ProjectMergeRule::create(),
                 ProjectEliminateRule::create(),
+                // project-join merge should be applied after merge
+                // and eliminate
+                ProjectJoinRule::create(),
             ],
             ApplyOrder::BottomUp,
         );
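To make the ordering comment in the hunk above concrete, here is a minimal standalone sketch (toy index vectors and a hypothetical compose helper, not the optimizer's real rule API): both the project-merge and project-join rewrites boil down to composing column-reference mappings, and collapsing stacked projections first leaves at most one projection for ProjectJoinRule to fold into the join.

// Toy illustration only: a projection of bare column references is a Vec<usize>,
// and a join's output_indices is likewise a Vec<usize> into its internal schema.
fn compose(outer: &[usize], inner: &[usize]) -> Vec<usize> {
    // The k-th output of `outer` reads column outer[k] of the plan described by `inner`.
    outer.iter().map(|&i| inner[i]).collect()
}

fn main() {
    // Project([0, 2]) on top of Project([2, 1, 0, 3]) on top of a join whose
    // output_indices are currently the identity [0, 1, 2, 3].
    let outer_project = vec![0, 2];
    let inner_project = vec![2, 1, 0, 3];
    let join_output_indices = vec![0, 1, 2, 3];

    // ProjectMerge first: the two projections collapse into one.
    let merged = compose(&outer_project, &inner_project);
    assert_eq!(merged, vec![2, 0]);

    // ProjectJoin afterwards: the surviving projection is folded into the join,
    // which now outputs internal columns [2, 0] directly, and the Project node disappears.
    let folded = compose(&merged, &join_output_indices);
    assert_eq!(folded, vec![2, 0]);
}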
1 change: 0 additions & 1 deletion src/frontend/src/optimizer/plan_node/logical_expand.rs
@@ -120,7 +120,6 @@ impl PlanTreeNodeUnary for LogicalExpand {
             })
             .collect_vec();
         let (mut map, new_input_col_num) = input_col_change.into_parts();
-        assert_eq!(new_input_col_num, input.schema().len());
         map.push(Some(new_input_col_num));
 
         (Self::new(input, column_subsets), ColIndexMapping::new(map))
10 changes: 0 additions & 10 deletions src/frontend/src/optimizer/plan_node/logical_join.rs
@@ -97,15 +97,6 @@ impl fmt::Display for LogicalJoin {
     }
 }
 
-fn has_duplicate_index(indices: &[usize]) -> bool {
-    for i in 1..indices.len() {
-        if indices[i..].contains(&indices[i - 1]) {
-            return true;
-        }
-    }
-    false
-}
-
 impl LogicalJoin {
     pub(crate) fn new(left: PlanRef, right: PlanRef, join_type: JoinType, on: Condition) -> Self {
         let out_column_num =
@@ -120,7 +111,6 @@ impl LogicalJoin {
         on: Condition,
         output_indices: Vec<usize>,
     ) -> Self {
-        assert!(!has_duplicate_index(&output_indices));
         let ctx = left.ctx();
         let schema = Self::derive_schema(left.schema(), right.schema(), join_type, &output_indices);
         let pk_indices = Self::derive_pk(
24 changes: 21 additions & 3 deletions src/frontend/src/optimizer/rule/project_join.rs
@@ -13,12 +13,30 @@
 // limitations under the License.
 
 use super::super::plan_node::*;
-use super::Rule;
+use super::{BoxedRule, Rule};
 pub struct ProjectJoinRule {}
 
+impl ProjectJoinRule {
+    pub fn create() -> BoxedRule {
+        Box::new(Self {})
+    }
+}
+
 impl Rule for ProjectJoinRule {
     fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
         let project = plan.as_logical_project()?;
-        let _join = project.input().clone().as_logical_join()?;
-        todo!()
+        let input = project.input();
+        let join = input.as_logical_join()?;
+        if project.exprs().iter().all(|e| e.as_input_ref().is_some()) {
+            let out_indices = project
+                .exprs()
+                .iter()
+                .map(|e| e.as_input_ref().unwrap().index());
+            let mapping = join.o2i_col_mapping();
+            let new_output_indices = out_indices.map(|idx| mapping.map(idx)).collect();
+            Some(join.clone_with_output_indices(new_output_indices).into())
+        } else {
+            None
+        }
     }
 }
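As a reading aid for the rule body above, here is a minimal self-contained sketch in which plain index vectors stand in for RisingWave's plan nodes and ColIndexMapping (the function fold_project_into_join and both of its parameters are hypothetical names): a projection whose expressions are all bare InputRefs is folded into the join by rewriting each referenced column through the join's output-to-input mapping, after which the Project node can be dropped.

/// Hypothetical stand-in for the rule: `project_exprs[k]` is Some(i) when the
/// k-th projected expression is a bare reference to join output column i, and
/// None for any computed expression; `join_o2i` plays the role of
/// join.o2i_col_mapping(), mapping a join output column to its index in the
/// concatenated left ++ right schema.
fn fold_project_into_join(
    project_exprs: &[Option<usize>],
    join_o2i: &[usize],
) -> Option<Vec<usize>> {
    // The rule only fires when every projected expression is a plain column reference.
    let refs: Vec<usize> = project_exprs.iter().copied().collect::<Option<_>>()?;
    // Rewrite each reference through the join's output-to-input mapping; the
    // result becomes the join's new output_indices and the projection is dropped.
    Some(refs.iter().map(|&i| join_o2i[i]).collect())
}

fn main() {
    // The join currently outputs its internal columns 0..4 unchanged, and the
    // projection on top asks for columns [2, 1, 0, 3] of that output.
    let join_o2i = vec![0, 1, 2, 3];
    let pure_refs = vec![Some(2), Some(1), Some(0), Some(3)];
    assert_eq!(
        fold_project_into_join(&pure_refs, &join_o2i),
        Some(vec![2, 1, 0, 3])
    );

    // A projection containing any non-trivial expression is left untouched.
    let with_expr = vec![Some(0), None];
    assert_eq!(fold_project_into_join(&with_expr, &join_o2i), None);
}

The test-data diffs below show the resulting plans: the BatchProject / StreamProject / LogicalProject wrappers directly above joins disappear, and the joins' output lists carry the reordering instead.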
19 changes: 9 additions & 10 deletions src/frontend/test_runner/tests/testdata/join.yaml
@@ -152,16 +152,15 @@
     select * from ab join bc using(b) join ca using(c);
   batch_plan: |
     BatchExchange { order: [], dist: Single }
-      BatchProject { exprs: [bc.c, ab.b, ab.a, ca.a] }
-        BatchHashJoin { type: Inner, predicate: bc.c = ca.c, output: [ab.a, ab.b, bc.c, ca.a] }
-          BatchExchange { order: [], dist: HashShard(bc.c) }
-            BatchHashJoin { type: Inner, predicate: ab.b = bc.b, output: [ab.a, ab.b, bc.c] }
-              BatchExchange { order: [], dist: HashShard(ab.b) }
-                BatchScan { table: ab, columns: [ab.a, ab.b], distribution: SomeShard }
-              BatchExchange { order: [], dist: HashShard(bc.b) }
-                BatchScan { table: bc, columns: [bc.b, bc.c], distribution: SomeShard }
-          BatchExchange { order: [], dist: HashShard(ca.c) }
-            BatchScan { table: ca, columns: [ca.c, ca.a], distribution: SomeShard }
+      BatchHashJoin { type: Inner, predicate: bc.c = ca.c, output: [bc.c, ab.b, ab.a, ca.a] }
+        BatchExchange { order: [], dist: HashShard(bc.c) }
+          BatchHashJoin { type: Inner, predicate: ab.b = bc.b, output: [ab.a, ab.b, bc.c] }
+            BatchExchange { order: [], dist: HashShard(ab.b) }
+              BatchScan { table: ab, columns: [ab.a, ab.b], distribution: SomeShard }
+            BatchExchange { order: [], dist: HashShard(bc.b) }
+              BatchScan { table: bc, columns: [bc.b, bc.c], distribution: SomeShard }
+        BatchExchange { order: [], dist: HashShard(ca.c) }
+          BatchScan { table: ca, columns: [ca.c, ca.a], distribution: SomeShard }
 - sql: |
     /* Only push to left */
     create table t1 (v1 int, v2 int);
34 changes: 16 additions & 18 deletions src/frontend/test_runner/tests/testdata/nexmark.yaml
@@ -93,27 +93,25 @@
     A.category = 10 and (P.state = 'or' OR P.state = 'id' OR P.state = 'ca');
   batch_plan: |
     BatchExchange { order: [], dist: Single }
-      BatchProject { exprs: [person.name, person.city, person.state, auction.id] }
-        BatchHashJoin { type: Inner, predicate: auction.seller = person.id, output: [auction.id, person.name, person.city, person.state] }
-          BatchExchange { order: [], dist: HashShard(auction.seller) }
-            BatchProject { exprs: [auction.id, auction.seller] }
-              BatchFilter { predicate: (auction.category = 10:Int32) }
-                BatchScan { table: auction, columns: [auction.id, auction.seller, auction.category], distribution: SomeShard }
-          BatchExchange { order: [], dist: HashShard(person.id) }
-            BatchFilter { predicate: (((person.state = 'or':Varchar) OR (person.state = 'id':Varchar)) OR (person.state = 'ca':Varchar)) }
-              BatchScan { table: person, columns: [person.id, person.name, person.city, person.state], distribution: SomeShard }
+      BatchHashJoin { type: Inner, predicate: auction.seller = person.id, output: [person.name, person.city, person.state, auction.id] }
+        BatchExchange { order: [], dist: HashShard(auction.seller) }
+          BatchProject { exprs: [auction.id, auction.seller] }
+            BatchFilter { predicate: (auction.category = 10:Int32) }
+              BatchScan { table: auction, columns: [auction.id, auction.seller, auction.category], distribution: SomeShard }
+        BatchExchange { order: [], dist: HashShard(person.id) }
+          BatchFilter { predicate: (((person.state = 'or':Varchar) OR (person.state = 'id':Varchar)) OR (person.state = 'ca':Varchar)) }
+            BatchScan { table: person, columns: [person.id, person.name, person.city, person.state], distribution: SomeShard }
   stream_plan: |
     StreamMaterialize { columns: [name, city, state, id, auction._row_id(hidden), person._row_id(hidden)], pk_columns: [auction._row_id, person._row_id] }
       StreamExchange { dist: HashShard(auction._row_id, person._row_id) }
-        StreamProject { exprs: [person.name, person.city, person.state, auction.id, auction._row_id, person._row_id] }
-          StreamHashJoin { type: Inner, predicate: auction.seller = person.id, output: [auction.id, person.name, person.city, person.state, auction._row_id, person._row_id] }
-            StreamExchange { dist: HashShard(auction.seller) }
-              StreamProject { exprs: [auction.id, auction.seller, auction._row_id] }
-                StreamFilter { predicate: (auction.category = 10:Int32) }
-                  StreamTableScan { table: auction, columns: [auction.id, auction.seller, auction._row_id, auction.category], pk: [auction._row_id], distribution: HashShard(auction._row_id) }
-            StreamExchange { dist: HashShard(person.id) }
-              StreamFilter { predicate: (((person.state = 'or':Varchar) OR (person.state = 'id':Varchar)) OR (person.state = 'ca':Varchar)) }
-                StreamTableScan { table: person, columns: [person.id, person.name, person.city, person.state, person._row_id], pk: [person._row_id], distribution: HashShard(person._row_id) }
+        StreamHashJoin { type: Inner, predicate: auction.seller = person.id, output: [person.name, person.city, person.state, auction.id, auction._row_id, person._row_id] }
+          StreamExchange { dist: HashShard(auction.seller) }
+            StreamProject { exprs: [auction.id, auction.seller, auction._row_id] }
+              StreamFilter { predicate: (auction.category = 10:Int32) }
+                StreamTableScan { table: auction, columns: [auction.id, auction.seller, auction._row_id, auction.category], pk: [auction._row_id], distribution: HashShard(auction._row_id) }
+          StreamExchange { dist: HashShard(person.id) }
+            StreamFilter { predicate: (((person.state = 'or':Varchar) OR (person.state = 'id':Varchar)) OR (person.state = 'ca':Varchar)) }
+              StreamTableScan { table: person, columns: [person.id, person.name, person.city, person.state, person._row_id], pk: [person._row_id], distribution: HashShard(person._row_id) }
 - id: nexmark_q4
   before:
   - create_tables
@@ -319,10 +319,9 @@
       LogicalScan { table: t1, columns: [t1._row_id, t1.x, t1.y] }
   optimized_logical_plan: |
     LogicalTopN { order: "[t2.x ASC]", limit: 100, offset: 0 }
-      LogicalProject { exprs: [t2.x, t1.x, t2.x] }
-        LogicalJoin { type: LeftOuter, on: (t1.y = t2.y), output: [t2.x, t1.x] }
-          LogicalScan { table: t2, output_columns: [t2.x, t2.y], required_columns: [x, y], predicate: (t2.x > 100:Int32) }
-          LogicalScan { table: t1, columns: [t1.x, t1.y] }
+      LogicalJoin { type: LeftOuter, on: (t1.y = t2.y), output: [t2.x, t1.x, t2.x] }
+        LogicalScan { table: t2, output_columns: [t2.x, t2.y], required_columns: [x, y], predicate: (t2.x > 100:Int32) }
+        LogicalScan { table: t1, columns: [t1.x, t1.y] }
 - sql: |
     create table t1(x int, y int);
     create table t2(x int, y int);
@@ -476,11 +475,10 @@
   optimized_logical_plan: |
     LogicalJoin { type: LeftSemi, on: (t1.y = t1.y) AND (t1.x = t1.x) AND (t1.y = t1.y), output: all }
       LogicalScan { table: t1, columns: [t1.x, t1.y] }
-      LogicalProject { exprs: [t1.x, t1.y, t1.y] }
-        LogicalJoin { type: Inner, on: (t1.x = t2.x), output: [t1.x, t1.y] }
-          LogicalAgg { group_key: [t1.x, t1.y], aggs: [] }
-            LogicalScan { table: t1, columns: [t1.x, t1.y] }
-          LogicalScan { table: t2, columns: [t2.x] }
+      LogicalJoin { type: Inner, on: (t1.x = t2.x), output: [t1.x, t1.y, t1.y] }
+        LogicalAgg { group_key: [t1.x, t1.y], aggs: [] }
+          LogicalScan { table: t1, columns: [t1.x, t1.y] }
+        LogicalScan { table: t2, columns: [t2.x] }
 - sql: |
     create table t1(x int, y int);
     create table t2(x int, y int);
@@ -554,7 +552,7 @@
     LogicalJoin { type: LeftSemi, on: (t2.y = (t3.y + t2.y)) AND (t2.y = t2.y) AND (t2.x = t3.x), output: [t2.x] }
       LogicalScan { table: t2, columns: [t2.x, t2.y] }
       LogicalProject { exprs: [t3.x, t2.y, (t3.y + t2.y)] }
-        LogicalJoin { type: Inner, on: true, output: all }
+        LogicalJoin { type: Inner, on: true, output: [t3.x, t2.y, t3.y] }
           LogicalAgg { group_key: [t2.y], aggs: [] }
             LogicalScan { table: t2, columns: [t2.y] }
           LogicalScan { table: t3, columns: [t3.x, t3.y] }
@@ -563,13 +561,11 @@
     create table t2 (b int, c int);
     select a, (select t1.a), c from t1, t2 where t1.b = t2.b order by c;
   optimized_logical_plan: |
-    LogicalProject { exprs: [t1.a, t1.a, t2.c] }
-      LogicalJoin { type: LeftOuter, on: (t1.a = t1.a), output: [t1.a, t2.c, t1.a] }
-        LogicalJoin { type: Inner, on: (t1.b = t2.b), output: [t1.a, t2.c] }
-          LogicalScan { table: t1, columns: [t1.a, t1.b] }
-          LogicalScan { table: t2, columns: [t2.b, t2.c] }
-        LogicalProject { exprs: [t1.a, t1.a] }
-          LogicalJoin { type: Inner, on: true, output: all }
-            LogicalAgg { group_key: [t1.a], aggs: [] }
-              LogicalScan { table: t1, columns: [t1.a] }
-            LogicalValues { rows: [[]], schema: Schema { fields: [] } }
+    LogicalJoin { type: LeftOuter, on: (t1.a = t1.a), output: [t1.a, t1.a, t2.c] }
+      LogicalJoin { type: Inner, on: (t1.b = t2.b), output: [t1.a, t2.c] }
+        LogicalScan { table: t1, columns: [t1.a, t1.b] }
+        LogicalScan { table: t2, columns: [t2.b, t2.c] }
+      LogicalJoin { type: Inner, on: true, output: [t1.a, t1.a] }
+        LogicalAgg { group_key: [t1.a], aggs: [] }
+          LogicalScan { table: t1, columns: [t1.a] }
+        LogicalValues { rows: [[]], schema: Schema { fields: [] } }