Skip to content

Commit

Permalink
Don't treat OpSelectionMerge as an edge in mem2reg (#571)
Browse files Browse the repository at this point in the history
* Don't treat OpSelectionMerge as an edge in mem2reg

* Make outgoing_edges not allocate

* Update test
  • Loading branch information
khyperia authored Apr 1, 2021
1 parent 4fa73bd commit f0f0f31
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 71 deletions.
16 changes: 15 additions & 1 deletion crates/rustc_codegen_spirv/src/linker/inline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
//! run mem2reg (see mem2reg.rs) on the result to "unwrap" the Function pointer.

use super::apply_rewrite_rules;
use super::mem2reg::compute_preds;
use super::simple_passes::outgoing_edges;
use rspirv::dr::{Block, Function, Instruction, Module, ModuleHeader, Operand};
use rspirv::spirv::{FunctionControl, Op, StorageClass, Word};
use std::collections::{HashMap, HashSet};
Expand Down Expand Up @@ -413,3 +413,17 @@ fn fuse_trivial_branches(function: &mut Function) {
}
function.blocks.retain(|b| !b.instructions.is_empty());
}

fn compute_preds(blocks: &[Block]) -> Vec<Vec<usize>> {
let mut result = vec![vec![]; blocks.len()];
for (source_idx, source) in blocks.iter().enumerate() {
for dest_id in outgoing_edges(source) {
let dest_idx = blocks
.iter()
.position(|b| b.label_id().unwrap() == dest_id)
.unwrap();
result[dest_idx].push(source_idx);
}
}
result
}
78 changes: 48 additions & 30 deletions crates/rustc_codegen_spirv/src/linker/mem2reg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ pub fn mem2reg(
constants: &HashMap<Word, u32>,
func: &mut Function,
) {
let preds = compute_preds(&func.blocks);
let idom = compute_idom(&preds);
let reachable = compute_reachable(&func.blocks);
let preds = compute_preds(&func.blocks, &reachable);
let idom = compute_idom(&preds, &reachable);
let dominance_frontier = compute_dominance_frontier(&preds, &idom);
insert_phis_all(
header,
Expand All @@ -35,24 +36,38 @@ pub fn mem2reg(
);
}

pub fn compute_preds(blocks: &[Block]) -> Vec<Vec<usize>> {
let mut result = vec![vec![]; blocks.len()];
for (source_idx, source) in blocks.iter().enumerate() {
let mut edges = outgoing_edges(source);
// HACK(eddyb) treat `OpSelectionMerge` as an edge, in case it points
// to an otherwise-unreachable block.
if let Some(before_last_idx) = source.instructions.len().checked_sub(2) {
if let Some(before_last) = source.instructions.get(before_last_idx) {
if before_last.class.opcode == Op::SelectionMerge {
edges.push(before_last.operands[0].unwrap_id_ref());
}
fn label_to_index(blocks: &[Block], id: Word) -> usize {
blocks
.iter()
.position(|b| b.label_id().unwrap() == id)
.unwrap()
}

fn compute_reachable(blocks: &[Block]) -> Vec<bool> {
fn recurse(blocks: &[Block], reachable: &mut [bool], block: usize) {
if !reachable[block] {
reachable[block] = true;
for dest_id in outgoing_edges(&blocks[block]) {
let dest_idx = label_to_index(blocks, dest_id);
recurse(blocks, reachable, dest_idx);
}
}
for dest_id in edges {
let dest_idx = blocks
.iter()
.position(|b| b.label_id().unwrap() == dest_id)
.unwrap();
}
let mut reachable = vec![false; blocks.len()];
recurse(blocks, &mut reachable, 0);
reachable
}

fn compute_preds(blocks: &[Block], reachable_blocks: &[bool]) -> Vec<Vec<usize>> {
let mut result = vec![vec![]; blocks.len()];
// Do not count unreachable blocks as valid preds of blocks
for (source_idx, source) in blocks
.iter()
.enumerate()
.filter(|&(b, _)| reachable_blocks[b])
{
for dest_id in outgoing_edges(source) {
let dest_idx = label_to_index(blocks, dest_id);
result[dest_idx].push(source_idx);
}
}
Expand All @@ -62,7 +77,8 @@ pub fn compute_preds(blocks: &[Block]) -> Vec<Vec<usize>> {
// Paper: A Simple, Fast Dominance Algorithm
// https://www.cs.rice.edu/~keith/EMBED/dom.pdf
// Note: requires nodes in reverse postorder
fn compute_idom(preds: &[Vec<usize>]) -> Vec<usize> {
// If a result is None, that means the block is unreachable, and therefore has no idom.
fn compute_idom(preds: &[Vec<usize>], reachable_blocks: &[bool]) -> Vec<Option<usize>> {
fn intersect(doms: &[Option<usize>], mut finger1: usize, mut finger2: usize) -> usize {
// TODO: This may return an optional result?
while finger1 != finger2 {
Expand All @@ -83,7 +99,8 @@ fn compute_idom(preds: &[Vec<usize>]) -> Vec<usize> {
let mut changed = true;
while changed {
changed = false;
for node in 1..(preds.len()) {
// Unreachable blocks have no preds, and therefore no idom
for node in (1..(preds.len())).filter(|&i| reachable_blocks[i]) {
let mut new_idom: Option<usize> = None;
for &pred in &preds[node] {
if idom[pred].is_some() {
Expand All @@ -99,20 +116,25 @@ fn compute_idom(preds: &[Vec<usize>]) -> Vec<usize> {
}
}
}
idom.iter().map(|x| x.unwrap()).collect()
assert!(idom
.iter()
.enumerate()
.all(|(i, x)| x.is_some() == reachable_blocks[i]));
idom
}

// Same paper as above
fn compute_dominance_frontier(preds: &[Vec<usize>], idom: &[usize]) -> Vec<HashSet<usize>> {
fn compute_dominance_frontier(preds: &[Vec<usize>], idom: &[Option<usize>]) -> Vec<HashSet<usize>> {
assert_eq!(preds.len(), idom.len());
let mut dominance_frontier = vec![HashSet::new(); preds.len()];
for node in 0..preds.len() {
if preds[node].len() >= 2 {
let node_idom = idom[node].unwrap();
for &pred in &preds[node] {
let mut runner = pred;
while runner != idom[node] {
while runner != node_idom {
dominance_frontier[runner].insert(node);
runner = idom[runner];
runner = idom[runner].unwrap();
}
}
}
Expand Down Expand Up @@ -442,13 +464,9 @@ impl Renamer<'_> {
}
}

for dest_id in outgoing_edges(&self.blocks[block]) {
for dest_id in outgoing_edges(&self.blocks[block]).collect::<Vec<_>>() {
// TODO: Don't do this find
let dest_idx = self
.blocks
.iter()
.position(|b| b.label_id().unwrap() == dest_id)
.unwrap();
let dest_idx = label_to_index(self.blocks, dest_id);
self.rename(dest_idx, Some(block));
}

Expand Down
4 changes: 3 additions & 1 deletion crates/rustc_codegen_spirv/src/linker/new_structurizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,9 @@ impl Structurizer<'_> {
}
visited[block] = true;

for target in super::simple_passes::outgoing_edges(&self.func.blocks()[block]) {
for target in
super::simple_passes::outgoing_edges(&self.func.blocks()[block]).collect::<Vec<_>>()
{
self.post_order_step(self.block_id_to_idx[&target], visited, post_order)
}

Expand Down
29 changes: 9 additions & 20 deletions crates/rustc_codegen_spirv/src/linker/simple_passes.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use rspirv::dr::{Block, Function, Module};
use rspirv::spirv::{Op, Word};
use std::collections::{HashMap, HashSet};
use std::iter::once;
use std::mem::replace;

pub fn shift_ids(module: &mut Module, add: u32) {
Expand Down Expand Up @@ -44,7 +43,7 @@ pub fn block_ordering_pass(func: &mut Function) {
.iter()
.find(|b| b.label_id().unwrap() == current)
.unwrap();
let mut edges = outgoing_edges(current_block);
let mut edges = outgoing_edges(current_block).collect::<Vec<_>>();
// HACK(eddyb) treat `OpSelectionMerge` as an edge, in case it points
// to an otherwise-unreachable block.
if let Some(before_last_idx) = current_block.instructions.len().checked_sub(2) {
Expand Down Expand Up @@ -80,27 +79,17 @@ pub fn block_ordering_pass(func: &mut Function) {
assert_eq!(func.blocks[0].label_id().unwrap(), entry_label);
}

// FIXME(eddyb) use `Either`, `Cow`, and/or `SmallVec`.
pub fn outgoing_edges(block: &Block) -> Vec<Word> {
pub fn outgoing_edges(block: &Block) -> impl Iterator<Item = Word> + '_ {
let terminator = block.instructions.last().unwrap();
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#Termination
match terminator.class.opcode {
Op::Branch => vec![terminator.operands[0].unwrap_id_ref()],
Op::BranchConditional => vec![
terminator.operands[1].unwrap_id_ref(),
terminator.operands[2].unwrap_id_ref(),
],
Op::Switch => once(terminator.operands[1].unwrap_id_ref())
.chain(
terminator.operands[3..]
.iter()
.step_by(2)
.map(|op| op.unwrap_id_ref()),
)
.collect(),
Op::Return | Op::ReturnValue | Op::Kill | Op::Unreachable => Vec::new(),
let operand_indices = match terminator.class.opcode {
Op::Branch => (0..1).step_by(1),
Op::BranchConditional => (1..3).step_by(1),
Op::Switch => (1..terminator.operands.len()).step_by(2),
Op::Return | Op::ReturnValue | Op::Kill | Op::Unreachable => (0..0).step_by(1),
_ => panic!("Invalid block terminator: {:?}", terminator),
}
};
operand_indices.map(move |i| terminator.operands[i].unwrap_id_ref())
}

pub fn compact_ids(module: &mut Module) -> u32 {
Expand Down
30 changes: 15 additions & 15 deletions crates/rustc_codegen_spirv/src/linker/structurizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@ fn retarget_loop_children_if_needed(builder: &mut Builder, cf_info: &ControlFlow

while let Some(front) = next.pop_front() {
let block_idx = find_block_index_from_id(builder, &front);
let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx));
let mut new_edges =
outgoing_edges(get_block_ref!(builder, block_idx)).collect::<Vec<_>>();

// Make sure we are not looping or going into child loops.
for loop_info in &cf_info.loops {
Expand Down Expand Up @@ -249,8 +250,7 @@ fn retarget_loop_children_if_needed(builder: &mut Builder, cf_info: &ControlFlow
fn incoming_edges(id: Word, builder: &mut Builder) -> Vec<Word> {
let mut incoming_edges = Vec::new();
for block in get_blocks_ref(builder) {
let out = outgoing_edges(block);
if out.contains(&id) {
if outgoing_edges(block).any(|x| x == id) {
incoming_edges.push(block.label_id().unwrap());
}
}
Expand Down Expand Up @@ -292,8 +292,8 @@ fn defer_loop_internals(builder: &mut Builder, cf_info: &ControlFlowInfo) {
// find all blocks that branch to a merge block.
let mut possible_intermediate_block_idexes = Vec::new();
for (i, block) in get_blocks_ref(builder).iter().enumerate() {
let out = outgoing_edges(block);
if out.len() == 1 && out[0] == loop_info.merge_id {
let mut out = outgoing_edges(block);
if out.next() == Some(loop_info.merge_id) && out.next() == None {
possible_intermediate_block_idexes.push(i)
}
}
Expand Down Expand Up @@ -390,7 +390,7 @@ fn eliminate_multiple_continue_blocks(builder: &mut Builder, header: Word) -> Wo
for block in get_blocks_ref(builder) {
let block_id = block.label_id().unwrap();
if ends_in_branch(block) {
let edge = outgoing_edges(block)[0];
let edge = outgoing_edges(block).next().unwrap();
if edge == header && block_is_parent_of(builder, header, block_id) {
continue_blocks.push(block_id);
}
Expand Down Expand Up @@ -421,7 +421,7 @@ fn block_leads_into_break(builder: &Builder, cf_info: &ControlFlowInfo, start: W

while let Some(front) = next.pop_front() {
let block_idx = find_block_index_from_id(builder, &front);
let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx));
let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx)).collect::<Vec<_>>();

// Make sure we are not looping.
for loop_info in &cf_info.loops {
Expand Down Expand Up @@ -456,7 +456,7 @@ fn block_leads_into_break(builder: &Builder, cf_info: &ControlFlowInfo, start: W

fn block_leads_into_continue(builder: &Builder, cf_info: &ControlFlowInfo, start: Word) -> bool {
let start_idx = find_block_index_from_id(builder, &start);
let new_edges = outgoing_edges(get_block_ref!(builder, start_idx));
let new_edges = outgoing_edges(get_block_ref!(builder, start_idx)).collect::<Vec<_>>();
for loop_info in &cf_info.loops {
if new_edges.len() == 1 && loop_info.continue_id == new_edges[0] {
return true;
Expand Down Expand Up @@ -485,7 +485,7 @@ fn block_is_reverse_idom_of(
continue;
}

let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx));
let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx)).collect::<Vec<_>>();

// Skip loop bodies by jumping to the merge block is we hit a header block.
for loop_info in &cf_info.loops {
Expand Down Expand Up @@ -552,7 +552,7 @@ fn block_is_parent_of(builder: &Builder, parent: Word, child: Word) -> bool {

while let Some(front) = next.pop_front() {
let block_idx = find_block_index_from_id(builder, &front);
let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx));
let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx)).collect::<Vec<_>>();

for id in &processed {
if let Some(i) = new_edges.iter().position(|x| x == id) {
Expand Down Expand Up @@ -589,7 +589,7 @@ fn get_looping_branch_from_block(
}

let block_idx = find_block_index_from_id(builder, &front);
let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx));
let mut new_edges = outgoing_edges(get_block_ref!(builder, block_idx)).collect::<Vec<_>>();

let edge_it = new_edges.iter().find(|&x| x == &start); // Check if the new_edges contain the start
if new_edges.len() == 1 {
Expand All @@ -598,8 +598,8 @@ fn get_looping_branch_from_block(
let start_idx = find_block_index_from_id(builder, &front);
let start_edges = outgoing_edges(get_block_ref!(builder, start_idx));

for (i, start_edge) in start_edges.iter().enumerate() {
if start_edge == edge_it || block_is_parent_of(builder, *start_edge, *edge_it) {
for (i, start_edge) in start_edges.enumerate() {
if start_edge == *edge_it || block_is_parent_of(builder, start_edge, *edge_it) {
return Some(i);
}
}
Expand Down Expand Up @@ -696,7 +696,7 @@ pub fn insert_selection_merge_on_conditional_branch(
};

let bi = find_block_index_from_id(builder, id);
let out = outgoing_edges(&get_blocks_ref(builder)[bi]);
let out = outgoing_edges(&get_blocks_ref(builder)[bi]).collect::<Vec<_>>();
let id = idx_to_id(builder, bi);
let a_nexts = get_possible_merge_positions(builder, cf_info, out[0]);
let b_nexts = get_possible_merge_positions(builder, cf_info, out[1]);
Expand Down Expand Up @@ -809,7 +809,7 @@ pub fn insert_loop_merge_on_conditional_branch(

let merge_branch_idx = (looping_branch_idx + 1) % 2;
let bi = find_block_index_from_id(builder, &id);
let out = outgoing_edges(&get_blocks_ref(builder)[bi]);
let out = outgoing_edges(&get_blocks_ref(builder)[bi]).collect::<Vec<_>>();

let continue_block_id = eliminate_multiple_continue_blocks(builder, id);
let merge_block_id = out[merge_branch_idx];
Expand Down
6 changes: 2 additions & 4 deletions crates/spirv-builder/src/test/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,14 +242,12 @@ OpSelectionMerge %24 None
OpBranchConditional %22 %25 %26
%25 = OpLabel
%27 = OpIMul %2 %28 %14
%29 = OpIAdd %2 %27 %5
%30 = OpIAdd %10 %9 %31
%15 = OpIAdd %2 %27 %5
%12 = OpIAdd %10 %9 %29
OpBranch %24
%26 = OpLabel
OpReturnValue %14
%24 = OpLabel
%12 = OpPhi %10 %30 %25
%15 = OpPhi %2 %29 %25
%19 = OpPhi %17 %18 %25
OpBranch %13
%13 = OpLabel
Expand Down

0 comments on commit f0f0f31

Please sign in to comment.