Skip to content

Commit

Permalink
fix(ssa refactor): Speedup find-branch-ends (#1786)
Browse files Browse the repository at this point in the history
* Speedup find-branch-ends

* Remove timing printlns
  • Loading branch information
jfecher authored Jun 21, 2023
1 parent e560cd2 commit 861e42c
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 117 deletions.
39 changes: 28 additions & 11 deletions crates/noirc_evaluator/src/ssa_refactor/ir/dom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ pub(crate) struct DominatorTree {
/// After dominator tree computation has completed, this will contain a node for every
/// reachable block, and no nodes for unreachable blocks.
nodes: HashMap<BasicBlockId, DominatorTreeNode>,

/// Subsequent calls to `dominates` are cached to speed up access
cache: HashMap<(BasicBlockId, BasicBlockId), bool>,
}

/// Methods for querying the dominator tree.
Expand Down Expand Up @@ -83,7 +86,21 @@ impl DominatorTree {
/// This function panics if either of the blocks are unreachable.
///
/// A block is considered to dominate itself.
pub(crate) fn dominates(&self, block_a_id: BasicBlockId, mut block_b_id: BasicBlockId) -> bool {
/// Memoizing front-end for `dominates_helper`.
///
/// The answer for each `(block_a_id, block_b_id)` pair is computed at most once by
/// `dominates_helper` and then stored in `self.cache`; repeated queries for the same
/// pair are answered straight from the cache. This is why the method needs `&mut self`.
pub(crate) fn dominates(&mut self, block_a_id: BasicBlockId, block_b_id: BasicBlockId) -> bool {
    let key = (block_a_id, block_b_id);

    match self.cache.get(&key).copied() {
        // Already answered once: reuse the stored result.
        Some(cached) => cached,
        // First query for this pair: compute it, remember it, and hand it back.
        None => {
            let computed = self.dominates_helper(block_a_id, block_b_id);
            self.cache.insert(key, computed);
            computed
        }
    }
}

pub(crate) fn dominates_helper(
&self,
block_a_id: BasicBlockId,
mut block_b_id: BasicBlockId,
) -> bool {
// Walk up the dominator tree from "b" until we encounter or pass "a". Doing the
// comparison on the reverse post-order allows us to test whether we have passed "a"
// without waiting until we reach the root of the tree.
Expand All @@ -104,7 +121,7 @@ impl DominatorTree {
/// Allocate and compute a dominator tree from a pre-computed control flow graph and
/// post-order counterpart.
pub(crate) fn with_cfg_and_post_order(cfg: &ControlFlowGraph, post_order: &PostOrder) -> Self {
let mut dom_tree = DominatorTree { nodes: HashMap::new() };
let mut dom_tree = DominatorTree { nodes: HashMap::new(), cache: HashMap::new() };
dom_tree.compute_dominator_tree(cfg, post_order);
dom_tree
}
Expand Down Expand Up @@ -249,7 +266,7 @@ mod tests {
block0_id,
TerminatorInstruction::Return { return_values: vec![] },
);
let dom_tree = DominatorTree::with_function(&func);
let mut dom_tree = DominatorTree::with_function(&func);
assert!(dom_tree.dominates(block0_id, block0_id));
}

Expand Down Expand Up @@ -308,7 +325,7 @@ mod tests {
// unreachable, performing this query indicates an internal compiler error.
#[test]
fn unreachable_node_asserts() {
let (dt, b0, _b1, b2, b3) = unreachable_node_setup();
let (mut dt, b0, _b1, b2, b3) = unreachable_node_setup();

assert!(dt.dominates(b0, b0));
assert!(dt.dominates(b0, b2));
Expand All @@ -326,42 +343,42 @@ mod tests {
#[test]
#[should_panic]
fn unreachable_node_panic_b0_b1() {
let (dt, b0, b1, _b2, _b3) = unreachable_node_setup();
let (mut dt, b0, b1, _b2, _b3) = unreachable_node_setup();
dt.dominates(b0, b1);
}

#[test]
#[should_panic]
fn unreachable_node_panic_b1_b0() {
let (dt, b0, b1, _b2, _b3) = unreachable_node_setup();
let (mut dt, b0, b1, _b2, _b3) = unreachable_node_setup();
dt.dominates(b1, b0);
}

#[test]
#[should_panic]
fn unreachable_node_panic_b1_b1() {
let (dt, _b0, b1, _b2, _b3) = unreachable_node_setup();
let (mut dt, _b0, b1, _b2, _b3) = unreachable_node_setup();
dt.dominates(b1, b1);
}

#[test]
#[should_panic]
fn unreachable_node_panic_b1_b2() {
let (dt, _b0, b1, b2, _b3) = unreachable_node_setup();
let (mut dt, _b0, b1, b2, _b3) = unreachable_node_setup();
dt.dominates(b1, b2);
}

#[test]
#[should_panic]
fn unreachable_node_panic_b1_b3() {
let (dt, _b0, b1, _b2, b3) = unreachable_node_setup();
let (mut dt, _b0, b1, _b2, b3) = unreachable_node_setup();
dt.dominates(b1, b3);
}

#[test]
#[should_panic]
fn unreachable_node_panic_b3_b1() {
let (dt, _b0, b1, b2, _b3) = unreachable_node_setup();
let (mut dt, _b0, b1, b2, _b3) = unreachable_node_setup();
dt.dominates(b2, b1);
}

Expand Down Expand Up @@ -390,7 +407,7 @@ mod tests {
let func = ssa.main();
let block0_id = func.entry_block();

let dt = DominatorTree::with_function(func);
let mut dt = DominatorTree::with_function(func);

// Expected dominance tree:
// block0 {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,28 @@
//! This is an algorithm for identifying branch starts and ends.
use std::collections::{HashMap, HashSet};
//!
//! The algorithm is split into two parts:
//! 1. The outer part:
//! A. An (unrolled) CFG can be thought of as a linear sequence of blocks where some nodes split
//! off, but eventually rejoin to a new node and continue the linear sequence.
//! B. Follow this sequence in order, and whenever a split is found call
//! `find_join_point_of_branches` and then recur from the join point it returns until the
//! return instruction is found.
//!
//! 2. The inner part defined by `find_join_point_of_branches`:
//! A. For each of the two branches in a jmpif block:
//! - Check if either has multiple predecessors. If so, it is a join point.
//! - If not, continue to search the linear sequence of successor blocks from that block.
//! - If another split point is found, recur in `find_join_point_of_branches`
//! - If a block with multiple predecessors is found, return it.
//! - After, we should have identified a join point for both branches. This is expected to be
//! the same block for both and can be returned from here to continue iteration.
//!
//! This algorithm will remember each join point found in `find_join_point_of_branches` and
//! the resulting map from each split block to each join block is returned.
use std::collections::HashMap;

use crate::ssa_refactor::ir::{
basic_block::BasicBlockId, cfg::ControlFlowGraph, dom::DominatorTree, function::Function,
post_order::PostOrder,
basic_block::BasicBlockId, cfg::ControlFlowGraph, function::Function,
};

/// Returns a `HashMap` mapping blocks that start a branch (i.e. blocks terminated with jmpif) to
Expand All @@ -16,121 +35,78 @@ pub(super) fn find_branch_ends(
function: &Function,
cfg: &ControlFlowGraph,
) -> HashMap<BasicBlockId, BasicBlockId> {
let post_order = PostOrder::with_function(function);
let dom_tree = DominatorTree::with_cfg_and_post_order(cfg, &post_order);
let mut stepper = Stepper::new(function.entry_block());
// This outer `visited` set is inconsequential, and simply here to satisfy the recursive
// stepper interface.
let mut visited = HashSet::new();
let mut branch_ends = HashMap::new();
while !stepper.finished {
stepper.step(cfg, &dom_tree, &mut visited, &mut branch_ends);
}
branch_ends
}
let mut block = function.entry_block();
let mut context = Context::new(cfg);

/// Returns the block at which `left` and `right` converge, at the same time identifying branch
/// ends in any sub branches.
///
/// This function is called by `Stepper::step` and is thus recursive.
fn step_until_rejoin(
cfg: &ControlFlowGraph,
dom_tree: &DominatorTree,
branch_ends: &mut HashMap<BasicBlockId, BasicBlockId>,
left: BasicBlockId,
right: BasicBlockId,
) -> BasicBlockId {
let mut visited = HashSet::new();
let mut left_stepper = Stepper::new(left);
let mut right_stepper = Stepper::new(right);
loop {
let mut successors = cfg.successors(block);

while !left_stepper.finished || !right_stepper.finished {
left_stepper.step(cfg, dom_tree, &mut visited, branch_ends);
right_stepper.step(cfg, dom_tree, &mut visited, branch_ends);
if successors.len() == 2 {
block = context.find_join_point_of_branches(block, successors);
} else if successors.len() == 1 {
block = successors.next().unwrap();
} else if successors.len() == 0 {
// return encountered. We have nothing to join, so we're done
break;
} else {
unreachable!("A block can only have 0, 1, or 2 successors");
}
}
let collision = match (left_stepper.collision, right_stepper.collision) {
(Some(collision), None) | (None, Some(collision)) => collision,
(Some(_),Some(_))=> unreachable!("A collision on both branches indicates a loop"),
_ => unreachable!(
"Until we support multiple returns, branches always re-converge. Once supported this case should return `None`"
),
};
collision

context.branch_ends
}

/// Tracks traversal along the arm of a branch. Steppers are progressed in pairs, such that the
/// re-convergence point of two arms is discovered as soon as possible. The exceptional case is
/// that of the top level stepper, which conveniently steps the whole CFG as if it were a single
/// arm.
struct Stepper {
/// The block that will be interrogated when calling `step`
current_block: BasicBlockId,
/// Indicates that the stepper has no more block successors to process, either because it has
/// reached the end of the CFG, or because it encountered a block already visited by its
/// sibling stepper.
finished: bool,
/// Once finished this option indicates whether a collision was encountered before reaching
/// the end of the CFG.
collision: Option<BasicBlockId>,
/// State shared by the join-point search.
struct Context<'cfg> {
/// Maps each block that starts a branch (a block with two successors) to the block
/// where its two arms join again. Entries are added by `find_join_point_of_branches`.
branch_ends: HashMap<BasicBlockId, BasicBlockId>,
/// The control flow graph queried for the predecessors and successors of each block.
cfg: &'cfg ControlFlowGraph,
}

impl Stepper {
/// Creates a fresh stepper instance
fn new(current_block: BasicBlockId) -> Self {
Stepper { current_block, finished: false, collision: None }
impl<'cfg> Context<'cfg> {
fn new(cfg: &'cfg ControlFlowGraph) -> Self {
Self { cfg, branch_ends: HashMap::new() }
}

/// Checks the current block to see if it has already been visited and if so marks it as a
/// collision. If a sub-branch is encountered `step_until_rejoin` is called to start a pair
/// of child steppers stepping along its arms.
///
/// It is safe to call this even when the stepper has reached its end.
fn step(
fn find_join_point_of_branches(
&mut self,
cfg: &ControlFlowGraph,
dom_tree: &DominatorTree,
visited: &mut HashSet<BasicBlockId>,
branch_ends: &mut HashMap<BasicBlockId, BasicBlockId>,
) {
if self.finished {
// The caller still needs to progress the other stepper, while this one sits idle.
return;
}
if visited.contains(&self.current_block) {
// The other stepper has already visited this block - thus this block is the
// re-convergence point.
self.collision = Some(self.current_block);
self.finished = true;
start: BasicBlockId,
mut successors: impl Iterator<Item = BasicBlockId>,
) -> BasicBlockId {
let left = successors.next().unwrap();
let right = successors.next().unwrap();

let left_join = self.find_join_point(left);
let right_join = self.find_join_point(right);

assert_eq!(left_join, right_join, "Expected two blocks to join to the same block");
self.branch_ends.insert(start, left_join);

left_join
}

fn find_join_point(&mut self, block: BasicBlockId) -> BasicBlockId {
let predecessors = self.cfg.predecessors(block);
if predecessors.len() > 1 {
return block;
}
visited.insert(self.current_block);
// The join point is not this block, so continue on
self.skip_then_find_join_point(block)
}

let mut successors = cfg.successors(self.current_block);
match successors.len() {
0 => {
// Reached the end of the CFG without a collision - this will happen in the other
// stepper assuming the CFG contains no early returns.
self.finished = true;
}
1 => {
// This block doesn't describe any branch starts or ends - move on.
self.current_block = successors.next().unwrap();
}
2 => {
// Sub-branch start encountered - recurse to find the end of the sub branch
let left = successors.next().unwrap();
let right = successors.next().unwrap();
let sub_branch_end = step_until_rejoin(cfg, dom_tree, branch_ends, left, right);
for collision_predecessor in cfg.predecessors(sub_branch_end) {
assert!(dom_tree.dominates(self.current_block, collision_predecessor));
}
branch_ends.insert(self.current_block, sub_branch_end);
fn skip_then_find_join_point(&mut self, block: BasicBlockId) -> BasicBlockId {
let mut successors = self.cfg.successors(block);

// Resume stepping through the current arm from where the sub-branch left off
self.current_block = sub_branch_end;
}
_ => {
unreachable!("Basic blocks never have more than 2 successors")
}
if successors.len() == 2 {
let join = self.find_join_point_of_branches(block, successors);
// Note that we call skip_then_find_join_point here instead of find_join_point.
// We already know this `join` is a join point, but it cannot be for the current block
// since we already know it is the join point of the successors of the current block.
self.skip_then_find_join_point(join)
} else if successors.len() == 1 {
self.find_join_point(successors.next().unwrap())
} else if successors.len() == 0 {
unreachable!("return encountered before a join point was found. This can only happen if early-return was added to the language without implementing it by jmping to a join block first")
} else {
unreachable!("A block can only have 0, 1, or 2 successors");
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/noirc_evaluator/src/ssa_refactor/opt/unrolling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ struct Loops {
fn find_all_loops(function: &Function) -> Loops {
let cfg = ControlFlowGraph::with_function(function);
let post_order = PostOrder::with_function(function);
let dom_tree = DominatorTree::with_cfg_and_post_order(&cfg, &post_order);
let mut dom_tree = DominatorTree::with_cfg_and_post_order(&cfg, &post_order);

let mut loops = vec![];

Expand Down

0 comments on commit 861e42c

Please sign in to comment.