Skip to content

Commit

Permalink
feat: Bushy tree join ordering (risingwavelabs#8316)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Axel <kevinaxel@163.com>
  • Loading branch information
KveinAxel authored Mar 18, 2023
1 parent a3dc882 commit 53261c5
Show file tree
Hide file tree
Showing 8 changed files with 1,736 additions and 14 deletions.
21 changes: 20 additions & 1 deletion src/common/src/session_config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use crate::util::epoch::Epoch;

// This is a hack, &'static str is not allowed as a const generics argument.
// TODO: refine this using the adt_const_params feature.
const CONFIG_KEYS: [&str; 21] = [
const CONFIG_KEYS: [&str; 22] = [
"RW_IMPLICIT_FLUSH",
"CREATE_COMPACTION_GROUP_FOR_MV",
"QUERY_MODE",
Expand All @@ -56,6 +56,7 @@ const CONFIG_KEYS: [&str; 21] = [
"RW_ENABLE_SHARE_PLAN",
"INTERVALSTYLE",
"BATCH_PARALLELISM",
"RW_STREAMING_ENABLE_BUSHY_JOIN",
];

// MUST HAVE 1v1 relationship to CONFIG_KEYS. e.g. CONFIG_KEYS[IMPLICIT_FLUSH] =
Expand All @@ -81,6 +82,7 @@ const FORCE_TWO_PHASE_AGG: usize = 17;
const RW_ENABLE_SHARE_PLAN: usize = 18;
const INTERVAL_STYLE: usize = 19;
const BATCH_PARALLELISM: usize = 20;
const STREAMING_ENABLE_BUSHY_JOIN: usize = 21;

trait ConfigEntry: Default + for<'a> TryFrom<&'a [&'a str], Error = RwError> {
fn entry_name() -> &'static str;
Expand Down Expand Up @@ -277,6 +279,7 @@ type QueryEpoch = ConfigU64<QUERY_EPOCH, 0>;
type Timezone = ConfigString<TIMEZONE>;
type StreamingParallelism = ConfigU64<STREAMING_PARALLELISM, 0>;
type StreamingEnableDeltaJoin = ConfigBool<STREAMING_ENABLE_DELTA_JOIN, false>;
type StreamingEnableBushyJoin = ConfigBool<STREAMING_ENABLE_BUSHY_JOIN, false>;
type EnableTwoPhaseAgg = ConfigBool<ENABLE_TWO_PHASE_AGG, true>;
type ForceTwoPhaseAgg = ConfigBool<FORCE_TWO_PHASE_AGG, false>;
type EnableSharePlan = ConfigBool<RW_ENABLE_SHARE_PLAN, true>;
Expand Down Expand Up @@ -342,6 +345,9 @@ pub struct ConfigMap {
/// Enable delta join in streaming query. Defaults to false.
streaming_enable_delta_join: StreamingEnableDeltaJoin,

/// Enable bushy join in streaming query. Defaults to false.
streaming_enable_bushy_join: StreamingEnableBushyJoin,

/// Enable two phase agg optimization. Defaults to true.
/// Setting this to true will always set `FORCE_TWO_PHASE_AGG` to false.
enable_two_phase_agg: EnableTwoPhaseAgg,
Expand Down Expand Up @@ -402,6 +408,8 @@ impl ConfigMap {
self.streaming_parallelism = val.as_slice().try_into()?;
} else if key.eq_ignore_ascii_case(StreamingEnableDeltaJoin::entry_name()) {
self.streaming_enable_delta_join = val.as_slice().try_into()?;
} else if key.eq_ignore_ascii_case(StreamingEnableBushyJoin::entry_name()) {
self.streaming_enable_bushy_join = val.as_slice().try_into()?;
} else if key.eq_ignore_ascii_case(EnableTwoPhaseAgg::entry_name()) {
self.enable_two_phase_agg = val.as_slice().try_into()?;
if !*self.enable_two_phase_agg {
Expand Down Expand Up @@ -458,6 +466,8 @@ impl ConfigMap {
Ok(self.streaming_parallelism.to_string())
} else if key.eq_ignore_ascii_case(StreamingEnableDeltaJoin::entry_name()) {
Ok(self.streaming_enable_delta_join.to_string())
} else if key.eq_ignore_ascii_case(StreamingEnableBushyJoin::entry_name()) {
Ok(self.streaming_enable_bushy_join.to_string())
} else if key.eq_ignore_ascii_case(EnableTwoPhaseAgg::entry_name()) {
Ok(self.enable_two_phase_agg.to_string())
} else if key.eq_ignore_ascii_case(ForceTwoPhaseAgg::entry_name()) {
Expand Down Expand Up @@ -550,6 +560,11 @@ impl ConfigMap {
setting : self.streaming_enable_delta_join.to_string(),
description: String::from("Enable delta join in streaming query.")
},
VariableInfo{
name : StreamingEnableBushyJoin::entry_name().to_lowercase(),
setting : self.streaming_enable_bushy_join.to_string(),
description: String::from("Enable bushy join in streaming query.")
},
VariableInfo{
name : EnableTwoPhaseAgg::entry_name().to_lowercase(),
setting : self.enable_two_phase_agg.to_string(),
Expand Down Expand Up @@ -648,6 +663,10 @@ impl ConfigMap {
*self.streaming_enable_delta_join
}

/// Returns the current value of the `streaming_enable_bushy_join` session setting,
/// i.e. whether bushy join is enabled for streaming queries.
pub fn get_streaming_enable_bushy_join(&self) -> bool {
*self.streaming_enable_bushy_join
}

/// Returns the current value of the `enable_two_phase_agg` session setting,
/// i.e. whether the two phase agg optimization is enabled.
pub fn get_enable_two_phase_agg(&self) -> bool {
*self.enable_two_phase_agg
}
Expand Down
1,408 changes: 1,408 additions & 0 deletions src/frontend/planner_test/tests/testdata/bushy_join.yaml

Large diffs are not rendered by default.

26 changes: 20 additions & 6 deletions src/frontend/src/optimizer/logical_optimization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,15 @@ lazy_static! {
ApplyOrder::TopDown,
);

static ref JOIN_REORDER: OptimizationStage = OptimizationStage::new(
static ref LEFT_DEEP_JOIN_REORDER: OptimizationStage = OptimizationStage::new(
"Join Reorder".to_string(),
vec![ReorderMultiJoinRule::create()],
vec![LeftDeepTreeJoinOrderingRule::create()],
ApplyOrder::TopDown,
);

static ref BUSHY_TREE_JOIN_REORDER: OptimizationStage = OptimizationStage::new(
"Bushy tree join ordering Rule".to_string(),
vec![BushyTreeJoinOrderingRule::create()],
ApplyOrder::TopDown,
);

Expand Down Expand Up @@ -365,9 +371,17 @@ impl LogicalOptimizer {
// their relevant joins.
plan = plan.optimize_by_rules(&TO_MULTI_JOIN);

// Reorder multijoin into left-deep join tree.
plan = plan.optimize_by_rules(&JOIN_REORDER);

// Reorder multijoin into join tree.
if plan
.ctx()
.session_ctx()
.config()
.get_streaming_enable_bushy_join()
{
plan = plan.optimize_by_rules(&BUSHY_TREE_JOIN_REORDER);
} else {
plan = plan.optimize_by_rules(&LEFT_DEEP_JOIN_REORDER);
}
// Predicate Push-down: apply filter pushdown rules again since we pullup all join
// conditions into a filter above the multijoin.
plan = Self::predicate_pushdown(plan, explain_trace, &ctx);
Expand Down Expand Up @@ -438,7 +452,7 @@ impl LogicalOptimizer {
plan = plan.optimize_by_rules(&TO_MULTI_JOIN);

// Reorder multijoin into left-deep join tree.
plan = plan.optimize_by_rules(&JOIN_REORDER);
plan = plan.optimize_by_rules(&LEFT_DEEP_JOIN_REORDER);

// Predicate Push-down: apply filter pushdown rules again since we pullup all join
// conditions into a filter above the multijoin.
Expand Down
239 changes: 239 additions & 0 deletions src/frontend/src/optimizer/plan_node/logical_multi_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::cmp::Ordering;
use std::collections::{BTreeMap, BTreeSet, VecDeque};
use std::fmt;

use itertools::Itertools;
Expand Down Expand Up @@ -483,6 +485,243 @@ impl LogicalMultiJoin {
Ok(join_ordering)
}

/// Rewrites this multi-join into a tree of binary inner joins, preferring the join tree
/// with the smallest height among the candidates it explores.
///
/// The inputs are modelled as an undirected graph whose edges come from the join
/// conditions in `self.on`. The graph is contracted one node at a time (each contraction
/// is one binary join); every contraction choice is explored through a FIFO work queue
/// and the lowest complete tree seen is kept.
///
/// The produced plan is, from bottom to top:
/// 1. the join tree, where every join is `JoinType::Inner` with a `true` condition,
/// 2. a projection restoring the column order the multi-join expects,
/// 3. a filter carrying `self.on` (pushed back into the joins later by `FilterJoinRule`),
/// 4. a projection onto `self.output_indices`.
///
/// # Errors
///
/// Returns an `InternalError` if the internal join tree is malformed (a leaf without an
/// input index, or a node with exactly one child), or if no plan could be built at all.
///
/// # Panics
///
/// Panics if a join condition refers to an input index that is not in the graph, or if
/// some column is not covered by the final join ordering (see the `unwrap` calls below).
pub fn as_bushy_tree_join(&self) -> Result<PlanRef> {
// Join tree internal representation.
// A leaf has `idx = Some(input index)` and no children; an inner node has `idx = None`
// and both children set. `height` is 0 for a leaf and `max(left, right) + 1` otherwise.
#[derive(Clone, Default, Debug)]
struct JoinTreeNode {
idx: Option<usize>,
left: Option<Box<JoinTreeNode>>,
right: Option<Box<JoinTreeNode>>,
height: usize,
}

// Join graph internal representation.
// Each node owns the join tree built so far for the inputs merged into it, plus the
// ids of the nodes it still has to be joined with.
#[derive(Clone, Debug)]
struct GraphNode {
id: usize,
join_tree: JoinTreeNode,
// Use a BTreeSet so that neighbours are visited in a deterministic (sorted) order.
relations: BTreeSet<usize>,
}

// Start with one graph node per input; the map key equals the node id (input index).
let mut nodes: BTreeMap<_, _> = (0..self.inputs.len())
.map(|idx| GraphNode {
id: idx,
relations: BTreeSet::new(),
join_tree: JoinTreeNode {
idx: Some(idx),
left: None,
right: None,
height: 0,
},
})
.enumerate()
.collect();
// Only the first element of the split is used; it is keyed by `(src, dst)` input pairs.
// NOTE(review): presumably these are the equi-join conditions between two inputs —
// confirm against `split_by_input_col_nums`.
let (eq_join_conditions, _) = self
.on
.clone()
.split_by_input_col_nums(&self.input_col_nums(), true);

// Every condition between two inputs becomes an undirected edge in the graph.
for ((src, dst), _) in eq_join_conditions {
nodes.get_mut(&src).unwrap().relations.insert(dst);
nodes.get_mut(&dst).unwrap().relations.insert(src);
}

// Isolated nodes (no join condition with any other input) can be joined anywhere.
let iso_nodes = nodes
.iter()
.filter_map(|n| {
if n.1.relations.is_empty() {
Some(*n.0)
} else {
None
}
})
.collect_vec();

// Connect each isolated node to every other node so it can take part in the search.
for n in iso_nodes {
for adj in 0..nodes.len() {
if adj != n {
nodes.get_mut(&n).unwrap().relations.insert(adj);
nodes.get_mut(&adj).unwrap().relations.insert(n);
}
}
}

// Lowest complete join tree found so far.
let mut optimized_bushy_tree = None;
// FIFO work queue of graph states; each state is a partially contracted graph.
let mut que = VecDeque::from([nodes]);
// Ids of nodes that were removed from a state because they had no neighbour left.
// This set lives outside the loop, so it accumulates across all explored states.
let mut isolated = BTreeSet::new();

while let Some(mut nodes) = que.pop_front() {
// A single remaining node means the graph is fully contracted: its join tree is a
// complete candidate. Keep it only if it is strictly lower than the current best.
if nodes.len() == 1 {
let node = nodes.into_values().next().unwrap();
optimized_bushy_tree = Some(optimized_bushy_tree.map_or(
node.clone(),
|old_tree: GraphNode| {
if node.join_tree.height < old_tree.join_tree.height {
node
} else {
old_tree
}
},
));
continue;
}

// Pick the node with the fewest neighbours, breaking ties by the lower join tree.
let (idx, _) = nodes
.iter()
.min_by(
|(_, x), (_, y)| match x.relations.len().cmp(&y.relations.len()) {
Ordering::Less => Ordering::Less,
Ordering::Greater => Ordering::Greater,
Ordering::Equal => x.join_tree.height.cmp(&y.join_tree.height),
},
)
.unwrap();
let n = nodes.remove(&idx.clone()).unwrap();

// The chosen node has nothing left to join with: set it aside and keep contracting
// the rest of the graph.
// NOTE(review): only `n.id` is recorded and `n.join_tree` is dropped. If `n` has
// already absorbed other inputs, those inputs would be missing from the final plan
// (and the `unwrap` on `reorder_mapping` below would panic) — confirm whether this
// branch can be reached with a non-leaf `n`.
if n.relations.is_empty() {
isolated.insert(n.id);
que.push_back(nodes);
continue;
}

// Try merging `n` into each of its neighbours; every choice yields a new state.
for merge_node in &n.relations {
let mut nodes = nodes.clone();
// Re-wire the other neighbours of `n`: drop their edge to `n` and connect them to
// `merge_node` instead, in both directions.
for adjacent_node in &n.relations {
if *adjacent_node != *merge_node {
nodes
.get_mut(adjacent_node)
.unwrap()
.relations
.remove(&n.id);
nodes
.get_mut(adjacent_node)
.unwrap()
.relations
.insert(*merge_node);
nodes
.get_mut(merge_node)
.unwrap()
.relations
.insert(*adjacent_node);
}
}
let mut merge_graph_node = nodes.get_mut(merge_node).unwrap();
merge_graph_node.relations.remove(&n.id);
// `n` becomes the left subtree; the neighbour's current tree becomes the right one.
let l_tree = n.join_tree.clone();
let r_tree = std::mem::take(&mut merge_graph_node.join_tree);
let new_height = usize::max(l_tree.height, r_tree.height) + 1;

// Prune: this state is already taller than the best complete tree found so far.
if let Some(min_height) = optimized_bushy_tree.as_ref().map(|t| t.join_tree.height) && min_height < new_height {
continue;
}

merge_graph_node.join_tree = JoinTreeNode {
idx: None,
left: Some(Box::new(l_tree)),
right: Some(Box::new(r_tree)),
height: new_height,
};
que.push_back(nodes);
}
}

// Recursively turns a `JoinTreeNode` into a plan of nested `LogicalJoin`s.
// Leaves are appended to `join_ordering` in the order they are visited (left to right),
// which is the order of the inputs in the resulting plan.
fn create_logical_join(
s: &LogicalMultiJoin,
mut join_tree: JoinTreeNode,
join_ordering: &mut Vec<usize>,
) -> Result<PlanRef> {
Ok(match (join_tree.left.take(), join_tree.right.take()) {
// Inner node: join both subtrees; the real conditions are applied by the filter
// placed above the join tree.
(Some(l), Some(r)) => LogicalJoin::new(
create_logical_join(s, *l, join_ordering)?,
create_logical_join(s, *r, join_ordering)?,
JoinType::Inner,
Condition::true_cond(),
)
.into(),
// Leaf: must reference one of the multi-join inputs.
(None, None) => {
if let Some(idx) = join_tree.idx {
join_ordering.push(idx);
s.inputs[idx].clone()
} else {
return Err(RwError::from(ErrorCode::InternalError(
"id of the leaf node not found in the join tree".into(),
)));
}
}
// Exactly one child is never produced by the search above.
(_, _) => {
return Err(RwError::from(ErrorCode::InternalError(
"only leaf node can have None subtree".into(),
)))
}
})
}

let isolated = isolated.into_iter().collect_vec();
let mut join_ordering = vec![];
let mut output = if let Some(optimized_bushy_tree) = optimized_bushy_tree {
// Build the best tree, then join the set-aside nodes on top of it one by one.
let mut output =
create_logical_join(self, optimized_bushy_tree.join_tree, &mut join_ordering)?;

output = isolated.into_iter().fold(output, |chain, n| {
join_ordering.push(n);
LogicalJoin::new(
chain,
self.inputs[n].clone(),
JoinType::Inner,
Condition::true_cond(),
)
.into()
});
output
} else if !isolated.is_empty() {
// No complete tree was found: chain the set-aside nodes in ascending id order.
let base = isolated[0];
join_ordering.push(isolated[0]);
isolated[1..]
.iter()
.fold(self.inputs[base].clone(), |chain, n| {
join_ordering.push(*n);
LogicalJoin::new(
chain,
self.inputs[*n].clone(),
JoinType::Inner,
Condition::true_cond(),
)
.into()
})
} else {
return Err(RwError::from(ErrorCode::InternalError(
"no plan remain".into(),
)));
};
let total_col_num = self.inner2output.source_size();
// `reorder_mapping[src] = Some(tar)`: the column at position `src` of the multi-join
// is found at position `tar` in the output of the reordered join tree.
let reorder_mapping = {
let mut reorder_mapping = vec![None; total_col_num];

join_ordering
.iter()
.cloned()
.flat_map(|input_idx| {
(0..self.inputs[input_idx].schema().len())
.map(move |col_idx| self.inner_i2o_mappings[input_idx].map(col_idx))
})
.enumerate()
.for_each(|(tar, src)| reorder_mapping[src] = Some(tar));
reorder_mapping
};
// Restore the original column order. The `unwrap` panics if any column was not covered
// by `join_ordering`, i.e. if an input is missing from the join tree.
output =
LogicalProject::with_out_col_idx(output, reorder_mapping.iter().map(|i| i.unwrap()))
.into();

// We will later push down all of the filters back to the individual joins via the
// `FilterJoinRule`.
output = LogicalFilter::create(output, self.on.clone());
output =
LogicalProject::with_out_col_idx(output, self.output_indices.iter().cloned()).into();
Ok(output)
}

/// Returns the number of schema columns of every input, in input order.
pub(crate) fn input_col_nums(&self) -> Vec<usize> {
    let mut col_nums = Vec::with_capacity(self.inputs.len());
    for input in self.inputs.iter() {
        col_nums.push(input.schema().len());
    }
    col_nums
}
Expand Down
Loading

0 comments on commit 53261c5

Please sign in to comment.