Skip to content

Commit

Permalink
fix(ssa refactor): Fix flattening pass inserting loads before stores …
Browse files Browse the repository at this point in the history
…occur (#1783)

Fix flattening pass inserting loads before stores occur
  • Loading branch information
jfecher authored Jun 21, 2023
1 parent 46facce commit 4293b15
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 44 deletions.
11 changes: 6 additions & 5 deletions crates/noirc_evaluator/src/ssa_refactor/ir/function_inserter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ impl<'f> FunctionInserter<'f> {
instruction: Instruction,
id: InstructionId,
block: BasicBlockId,
) {
) -> InsertInstructionResult {
let results = self.function.dfg.instruction_results(id);
let results = vecmap(results, |id| self.function.dfg.resolve(*id));

Expand All @@ -79,24 +79,25 @@ impl<'f> FunctionInserter<'f> {
let new_results =
self.function.dfg.insert_instruction_and_results(instruction, block, ctrl_typevars);

Self::insert_new_instruction_results(&mut self.values, &results, new_results);
Self::insert_new_instruction_results(&mut self.values, &results, &new_results);
new_results
}

/// Modify the values HashMap to remember the mapping between an instruction result's previous
/// ValueId (from the source_function) and its new ValueId in the destination function.
pub(crate) fn insert_new_instruction_results(
values: &mut HashMap<ValueId, ValueId>,
old_results: &[ValueId],
new_results: InsertInstructionResult,
new_results: &InsertInstructionResult,
) {
assert_eq!(old_results.len(), new_results.len());

match new_results {
InsertInstructionResult::SimplifiedTo(new_result) => {
values.insert(old_results[0], new_result);
values.insert(old_results[0], *new_result);
}
InsertInstructionResult::Results(new_results) => {
for (old_result, new_result) in old_results.iter().zip(new_results) {
for (old_result, new_result) in old_results.iter().zip(*new_results) {
values.insert(*old_result, *new_result);
}
}
Expand Down
208 changes: 169 additions & 39 deletions crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,10 @@
//! v11 = mul v4, Field 12
//! v12 = add v10, v11
//! store v12 at v5 (new store)
use std::collections::HashMap;
use std::{
collections::{HashMap, HashSet},
rc::Rc,
};

use acvm::FieldElement;
use iter_extended::vecmap;
Expand All @@ -144,7 +147,7 @@ use crate::ssa_refactor::{
function::Function,
function_inserter::FunctionInserter,
instruction::{BinaryOp, Instruction, InstructionId, TerminatorInstruction},
types::Type,
types::{CompositeType, Type},
value::ValueId,
},
ssa_gen::Ssa,
Expand Down Expand Up @@ -179,6 +182,13 @@ struct Context<'f> {
/// Maps an address to the old and new value of the element at that address
store_values: HashMap<ValueId, Store>,

/// Stores all allocations local to the current branch.
/// Since these branches are local to the current branch (ie. only defined within one branch of
/// an if expression), they should not be merged with their previous value or stored value in
/// the other branch since there is no such value. The ValueId here is that which is returned
/// by the allocate instruction.
local_allocations: HashSet<ValueId>,

/// A stack of each jmpif condition that was taken to reach a particular point in the program.
/// When two branches are merged back into one, this constitutes a join point, and is analogous
/// to the rest of the program after an if statement. When such a join point / end block is
Expand All @@ -197,6 +207,7 @@ struct Branch {
condition: ValueId,
last_block: BasicBlockId,
store_values: HashMap<ValueId, Store>,
local_allocations: HashSet<ValueId>,
}

fn flatten_function_cfg(function: &mut Function) {
Expand All @@ -211,10 +222,12 @@ fn flatten_function_cfg(function: &mut Function) {
}
let cfg = ControlFlowGraph::with_function(function);
let branch_ends = branch_analysis::find_branch_ends(function, &cfg);

let mut context = Context {
inserter: FunctionInserter::new(function),
cfg,
store_values: HashMap::new(),
local_allocations: HashSet::new(),
branch_ends,
conditions: Vec::new(),
};
Expand Down Expand Up @@ -359,40 +372,60 @@ impl<'f> Context<'f> {
Type::Numeric(_) => {
self.merge_numeric_values(then_condition, else_condition, then_value, else_value)
}
Type::Array(element_types, len) => {
let mut merged = im::Vector::new();

for i in 0..len {
for (element_index, element_type) in element_types.iter().enumerate() {
let index = ((i * element_types.len() + element_index) as u128).into();
let index = self.inserter.function.dfg.make_constant(index, Type::field());

let typevars = Some(vec![element_type.clone()]);

let mut get_element = |array, typevars| {
let get = Instruction::ArrayGet { array, index };
self.insert_instruction_with_typevars(get, typevars).first()
};

let then_element = get_element(then_value, typevars.clone());
let else_element = get_element(else_value, typevars);

merged.push_back(self.merge_values(
then_condition,
else_condition,
then_element,
else_element,
));
}
}

self.inserter.function.dfg.make_array(merged, element_types)
}
Type::Array(element_types, len) => self.merge_array_values(
element_types,
len,
then_condition,
else_condition,
then_value,
else_value,
),
Type::Reference => panic!("Cannot return references from an if expression"),
Type::Function => panic!("Cannot return functions from an if expression"),
}
}

/// Given an if expression that returns an array: `if c { array1 } else { array2 }`,
/// this function will recursively merge array1 and array2 into a single resulting array
/// by creating a new array containing the result of self.merge_values for each element.
fn merge_array_values(
&mut self,
element_types: Rc<CompositeType>,
len: usize,
then_condition: ValueId,
else_condition: ValueId,
then_value: ValueId,
else_value: ValueId,
) -> ValueId {
let mut merged = im::Vector::new();

for i in 0..len {
for (element_index, element_type) in element_types.iter().enumerate() {
let index = ((i * element_types.len() + element_index) as u128).into();
let index = self.inserter.function.dfg.make_constant(index, Type::field());

let typevars = Some(vec![element_type.clone()]);

let mut get_element = |array, typevars| {
let get = Instruction::ArrayGet { array, index };
self.insert_instruction_with_typevars(get, typevars).first()
};

let then_element = get_element(then_value, typevars.clone());
let else_element = get_element(else_value, typevars);

merged.push_back(self.merge_values(
then_condition,
else_condition,
then_element,
else_element,
));
}
}

self.inserter.function.dfg.make_array(merged, element_types)
}

/// Merge two numeric values a and b from separate basic blocks to a single value. This
/// function would return the result of `if c { a } else { b }` as `c*a + (!c)*b`.
fn merge_numeric_values(
Expand Down Expand Up @@ -437,13 +470,18 @@ impl<'f> Context<'f> {
// 'else' case of an if with no else - so there is no else branch.
Branch {
condition: new_condition,
// The last block here is somewhat arbitrary. It only matters that it has no Jmp
// args that will be merged by inline_branch_end. Since jmpifs don't have
// block arguments, it is safe to use the jmpif block here.
last_block: jmpif_block,
store_values: HashMap::new(),
local_allocations: HashSet::new(),
}
} else {
self.push_condition(jmpif_block, new_condition);
self.insert_current_side_effects_enabled();
let old_stores = std::mem::take(&mut self.store_values);
let old_allocations = std::mem::take(&mut self.local_allocations);

// Remember the old condition value is now known to be true/false within this branch
let known_value =
Expand All @@ -453,12 +491,15 @@ impl<'f> Context<'f> {
let final_block = self.inline_block(destination, &[]);

self.conditions.pop();

let stores_in_branch = std::mem::replace(&mut self.store_values, old_stores);
let local_allocations = std::mem::replace(&mut self.local_allocations, old_allocations);

Branch {
condition: new_condition,
last_block: final_block,
store_values: stores_in_branch,
local_allocations,
}
}
}
Expand Down Expand Up @@ -533,14 +574,16 @@ impl<'f> Context<'f> {
}

fn remember_store(&mut self, address: ValueId, new_value: ValueId) {
if let Some(store_value) = self.store_values.get_mut(&address) {
store_value.new_value = new_value;
} else {
let load = Instruction::Load { address };
let load_type = Some(vec![self.inserter.function.dfg.type_of_value(new_value)]);
let old_value = self.insert_instruction_with_typevars(load, load_type).first();
if !self.local_allocations.contains(&address) {
if let Some(store_value) = self.store_values.get_mut(&address) {
store_value.new_value = new_value;
} else {
let load = Instruction::Load { address };
let load_type = Some(vec![self.inserter.function.dfg.type_of_value(new_value)]);
let old_value = self.insert_instruction_with_typevars(load, load_type).first();

self.store_values.insert(address, Store { old_value, new_value });
self.store_values.insert(address, Store { old_value, new_value });
}
}
}

Expand All @@ -558,6 +601,7 @@ impl<'f> Context<'f> {
// If this is not a separate variable, clippy gets confused and says the to_vec is
// unnecessary, when removing it actually causes an aliasing/mutability error.
let instructions = self.inserter.function.dfg[destination].instructions().to_vec();

for instruction in instructions {
self.push_instruction(instruction);
}
Expand All @@ -574,8 +618,16 @@ impl<'f> Context<'f> {
fn push_instruction(&mut self, id: InstructionId) {
let instruction = self.inserter.map_instruction(id);
let instruction = self.handle_instruction_side_effects(instruction);
let is_allocate = matches!(instruction, Instruction::Allocate);

let entry = self.inserter.function.entry_block();
self.inserter.push_instruction_value(instruction, id, entry);
let results = self.inserter.push_instruction_value(instruction, id, entry);

// Remember an allocate was created local to this branch so that we do not try to merge store
// values across branches for it later.
if is_allocate {
self.local_allocations.insert(results.first());
}
}

/// If we are currently in a branch, we need to modify constrain instructions
Expand Down Expand Up @@ -1020,6 +1072,84 @@ mod test {
assert_eq!(merged_values, vec![3, 5, 6]);
}

#[test]
fn allocate_in_single_branch() {
// Regression test for #1756
// fn foo() -> Field {
// let mut x = 0;
// x
// }
//
// fn main(cond:bool) {
// if cond {
// foo();
// };
// }
//
// // Translates to the following before the flattening pass:
// fn main f2 {
// b0(v0: u1):
// jmpif v0 then: b1, else: b2
// b1():
// v2 = allocate
// store Field 0 at v2
// v4 = load v2
// jmp b2()
// b2():
// return
// }
// The bug is that the flattening pass previously inserted a load
// before the first store to allocate, which loaded an uninitialized value.
// In this test we assert the ordering is strictly Allocate then Store then Load.
let main_id = Id::test_new(0);
let mut builder = FunctionBuilder::new("main".into(), main_id, RuntimeType::Acir);

let b1 = builder.insert_block();
let b2 = builder.insert_block();

let v0 = builder.add_parameter(Type::bool());
builder.terminate_with_jmpif(v0, b1, b2);

builder.switch_to_block(b1);
let v2 = builder.insert_allocate();
let zero = builder.field_constant(0u128);
builder.insert_store(v2, zero);
let _v4 = builder.insert_load(v2, Type::field());
builder.terminate_with_jmp(b2, vec![]);

builder.switch_to_block(b2);
builder.terminate_with_return(vec![]);

let ssa = builder.finish().flatten_cfg();
let main = ssa.main();

// Now assert that there is not a load between the allocate and its first store
// The Expected IR is:
//
// fn main f2 {
// b0(v0: u1):
// enable_side_effects v0
// v6 = allocate
// store Field 0 at v6
// v7 = load v6
// v8 = not v0
// enable_side_effects u1 1
// return
// }
let instructions = main.dfg[main.entry_block()].instructions();

let find_instruction = |predicate: fn(&Instruction) -> bool| {
instructions.iter().position(|id| predicate(&main.dfg[*id])).unwrap()
};

let allocate_index = find_instruction(|i| matches!(i, Instruction::Allocate));
let store_index = find_instruction(|i| matches!(i, Instruction::Store { .. }));
let load_index = find_instruction(|i| matches!(i, Instruction::Load { .. }));

assert!(allocate_index < store_index);
assert!(store_index < load_index);
}

/// Work backwards from an instruction to find all the constant values
/// that were used to construct it. E.g for:
///
Expand Down

0 comments on commit 4293b15

Please sign in to comment.