diff --git a/crates/noirc_evaluator/src/ssa_refactor/ir/function_inserter.rs b/crates/noirc_evaluator/src/ssa_refactor/ir/function_inserter.rs index 31f236765c9..530e8d25523 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/ir/function_inserter.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/ir/function_inserter.rs @@ -68,7 +68,7 @@ impl<'f> FunctionInserter<'f> { instruction: Instruction, id: InstructionId, block: BasicBlockId, - ) { + ) -> InsertInstructionResult { let results = self.function.dfg.instruction_results(id); let results = vecmap(results, |id| self.function.dfg.resolve(*id)); @@ -79,7 +79,8 @@ impl<'f> FunctionInserter<'f> { let new_results = self.function.dfg.insert_instruction_and_results(instruction, block, ctrl_typevars); - Self::insert_new_instruction_results(&mut self.values, &results, new_results); + Self::insert_new_instruction_results(&mut self.values, &results, &new_results); + new_results } /// Modify the values HashMap to remember the mapping between an instruction result's previous @@ -87,16 +88,16 @@ impl<'f> FunctionInserter<'f> { pub(crate) fn insert_new_instruction_results( values: &mut HashMap, old_results: &[ValueId], - new_results: InsertInstructionResult, + new_results: &InsertInstructionResult, ) { assert_eq!(old_results.len(), new_results.len()); match new_results { InsertInstructionResult::SimplifiedTo(new_result) => { - values.insert(old_results[0], new_result); + values.insert(old_results[0], *new_result); } InsertInstructionResult::Results(new_results) => { - for (old_result, new_result) in old_results.iter().zip(new_results) { + for (old_result, new_result) in old_results.iter().zip(*new_results) { values.insert(*old_result, *new_result); } } diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs index d8fc52f6f92..39cae09922c 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs +++ 
b/crates/noirc_evaluator/src/ssa_refactor/opt/flatten_cfg.rs @@ -131,7 +131,10 @@ //! v11 = mul v4, Field 12 //! v12 = add v10, v11 //! store v12 at v5 (new store) -use std::collections::HashMap; +use std::{ + collections::{HashMap, HashSet}, + rc::Rc, +}; use acvm::FieldElement; use iter_extended::vecmap; @@ -144,7 +147,7 @@ use crate::ssa_refactor::{ function::Function, function_inserter::FunctionInserter, instruction::{BinaryOp, Instruction, InstructionId, TerminatorInstruction}, - types::Type, + types::{CompositeType, Type}, value::ValueId, }, ssa_gen::Ssa, @@ -179,6 +182,13 @@ struct Context<'f> { /// Maps an address to the old and new value of the element at that address store_values: HashMap, + /// Stores all allocations local to the current branch. + /// Since these allocations are local to the current branch (i.e. only defined within one branch of + /// an if expression), they should not be merged with their previous value or stored value in + /// the other branch since there is no such value. The ValueId here is that which is returned + /// by the allocate instruction. + local_allocations: HashSet, + /// A stack of each jmpif condition that was taken to reach a particular point in the program. /// When two branches are merged back into one, this constitutes a join point, and is analogous /// to the rest of the program after an if statement. 
When such a join point / end block is @@ -197,6 +207,7 @@ struct Branch { condition: ValueId, last_block: BasicBlockId, store_values: HashMap, + local_allocations: HashSet, } fn flatten_function_cfg(function: &mut Function) { @@ -211,10 +222,12 @@ fn flatten_function_cfg(function: &mut Function) { } let cfg = ControlFlowGraph::with_function(function); let branch_ends = branch_analysis::find_branch_ends(function, &cfg); + let mut context = Context { inserter: FunctionInserter::new(function), cfg, store_values: HashMap::new(), + local_allocations: HashSet::new(), branch_ends, conditions: Vec::new(), }; @@ -359,40 +372,60 @@ impl<'f> Context<'f> { Type::Numeric(_) => { self.merge_numeric_values(then_condition, else_condition, then_value, else_value) } - Type::Array(element_types, len) => { - let mut merged = im::Vector::new(); - - for i in 0..len { - for (element_index, element_type) in element_types.iter().enumerate() { - let index = ((i * element_types.len() + element_index) as u128).into(); - let index = self.inserter.function.dfg.make_constant(index, Type::field()); - - let typevars = Some(vec![element_type.clone()]); - - let mut get_element = |array, typevars| { - let get = Instruction::ArrayGet { array, index }; - self.insert_instruction_with_typevars(get, typevars).first() - }; - - let then_element = get_element(then_value, typevars.clone()); - let else_element = get_element(else_value, typevars); - - merged.push_back(self.merge_values( - then_condition, - else_condition, - then_element, - else_element, - )); - } - } - - self.inserter.function.dfg.make_array(merged, element_types) - } + Type::Array(element_types, len) => self.merge_array_values( + element_types, + len, + then_condition, + else_condition, + then_value, + else_value, + ), Type::Reference => panic!("Cannot return references from an if expression"), Type::Function => panic!("Cannot return functions from an if expression"), } } + /// Given an if expression that returns an array: `if c { array1 } 
else { array2 }`, + /// this function will recursively merge array1 and array2 into a single resulting array + /// by creating a new array containing the result of self.merge_values for each element. + fn merge_array_values( + &mut self, + element_types: Rc, + len: usize, + then_condition: ValueId, + else_condition: ValueId, + then_value: ValueId, + else_value: ValueId, + ) -> ValueId { + let mut merged = im::Vector::new(); + + for i in 0..len { + for (element_index, element_type) in element_types.iter().enumerate() { + let index = ((i * element_types.len() + element_index) as u128).into(); + let index = self.inserter.function.dfg.make_constant(index, Type::field()); + + let typevars = Some(vec![element_type.clone()]); + + let mut get_element = |array, typevars| { + let get = Instruction::ArrayGet { array, index }; + self.insert_instruction_with_typevars(get, typevars).first() + }; + + let then_element = get_element(then_value, typevars.clone()); + let else_element = get_element(else_value, typevars); + + merged.push_back(self.merge_values( + then_condition, + else_condition, + then_element, + else_element, + )); + } + } + + self.inserter.function.dfg.make_array(merged, element_types) + } + /// Merge two numeric values a and b from separate basic blocks to a single value. This /// function would return the result of `if c { a } else { b }` as `c*a + (!c)*b`. fn merge_numeric_values( @@ -437,13 +470,18 @@ impl<'f> Context<'f> { // 'else' case of an if with no else - so there is no else branch. Branch { condition: new_condition, + // The last block here is somewhat arbitrary. It only matters that it has no Jmp + // args that will be merged by inline_branch_end. Since jmpifs don't have + // block arguments, it is safe to use the jmpif block here. 
last_block: jmpif_block, store_values: HashMap::new(), + local_allocations: HashSet::new(), } } else { self.push_condition(jmpif_block, new_condition); self.insert_current_side_effects_enabled(); let old_stores = std::mem::take(&mut self.store_values); + let old_allocations = std::mem::take(&mut self.local_allocations); // Remember the old condition value is now known to be true/false within this branch let known_value = @@ -453,12 +491,15 @@ impl<'f> Context<'f> { let final_block = self.inline_block(destination, &[]); self.conditions.pop(); + let stores_in_branch = std::mem::replace(&mut self.store_values, old_stores); + let local_allocations = std::mem::replace(&mut self.local_allocations, old_allocations); Branch { condition: new_condition, last_block: final_block, store_values: stores_in_branch, + local_allocations, } } } @@ -533,14 +574,16 @@ impl<'f> Context<'f> { } fn remember_store(&mut self, address: ValueId, new_value: ValueId) { - if let Some(store_value) = self.store_values.get_mut(&address) { - store_value.new_value = new_value; - } else { - let load = Instruction::Load { address }; - let load_type = Some(vec![self.inserter.function.dfg.type_of_value(new_value)]); - let old_value = self.insert_instruction_with_typevars(load, load_type).first(); + if !self.local_allocations.contains(&address) { + if let Some(store_value) = self.store_values.get_mut(&address) { + store_value.new_value = new_value; + } else { + let load = Instruction::Load { address }; + let load_type = Some(vec![self.inserter.function.dfg.type_of_value(new_value)]); + let old_value = self.insert_instruction_with_typevars(load, load_type).first(); - self.store_values.insert(address, Store { old_value, new_value }); + self.store_values.insert(address, Store { old_value, new_value }); + } } } @@ -558,6 +601,7 @@ impl<'f> Context<'f> { // If this is not a separate variable, clippy gets confused and says the to_vec is // unnecessary, when removing it actually causes an aliasing/mutability 
error. let instructions = self.inserter.function.dfg[destination].instructions().to_vec(); + for instruction in instructions { self.push_instruction(instruction); } @@ -574,8 +618,16 @@ impl<'f> Context<'f> { fn push_instruction(&mut self, id: InstructionId) { let instruction = self.inserter.map_instruction(id); let instruction = self.handle_instruction_side_effects(instruction); + let is_allocate = matches!(instruction, Instruction::Allocate); + let entry = self.inserter.function.entry_block(); - self.inserter.push_instruction_value(instruction, id, entry); + let results = self.inserter.push_instruction_value(instruction, id, entry); + + // Remember an allocate was created local to this branch so that we do not try to merge store + // values across branches for it later. + if is_allocate { + self.local_allocations.insert(results.first()); + } } /// If we are currently in a branch, we need to modify constrain instructions @@ -1020,6 +1072,84 @@ mod test { assert_eq!(merged_values, vec![3, 5, 6]); } + #[test] + fn allocate_in_single_branch() { + // Regression test for #1756 + // fn foo() -> Field { + // let mut x = 0; + // x + // } + // + // fn main(cond:bool) { + // if cond { + // foo(); + // }; + // } + // + // // Translates to the following before the flattening pass: + // fn main f2 { + // b0(v0: u1): + // jmpif v0 then: b1, else: b2 + // b1(): + // v2 = allocate + // store Field 0 at v2 + // v4 = load v2 + // jmp b2() + // b2(): + // return + // } + // The bug is that the flattening pass previously inserted a load + // before the first store to allocate, which loaded an uninitialized value. + // In this test we assert the ordering is strictly Allocate then Store then Load. 
+ let main_id = Id::test_new(0); + let mut builder = FunctionBuilder::new("main".into(), main_id, RuntimeType::Acir); + + let b1 = builder.insert_block(); + let b2 = builder.insert_block(); + + let v0 = builder.add_parameter(Type::bool()); + builder.terminate_with_jmpif(v0, b1, b2); + + builder.switch_to_block(b1); + let v2 = builder.insert_allocate(); + let zero = builder.field_constant(0u128); + builder.insert_store(v2, zero); + let _v4 = builder.insert_load(v2, Type::field()); + builder.terminate_with_jmp(b2, vec![]); + + builder.switch_to_block(b2); + builder.terminate_with_return(vec![]); + + let ssa = builder.finish().flatten_cfg(); + let main = ssa.main(); + + // Now assert that there is not a load between the allocate and its first store + // The Expected IR is: + // + // fn main f2 { + // b0(v0: u1): + // enable_side_effects v0 + // v6 = allocate + // store Field 0 at v6 + // v7 = load v6 + // v8 = not v0 + // enable_side_effects u1 1 + // return + // } + let instructions = main.dfg[main.entry_block()].instructions(); + + let find_instruction = |predicate: fn(&Instruction) -> bool| { + instructions.iter().position(|id| predicate(&main.dfg[*id])).unwrap() + }; + + let allocate_index = find_instruction(|i| matches!(i, Instruction::Allocate)); + let store_index = find_instruction(|i| matches!(i, Instruction::Store { .. })); + let load_index = find_instruction(|i| matches!(i, Instruction::Load { .. })); + + assert!(allocate_index < store_index); + assert!(store_index < load_index); + } + /// Work backwards from an instruction to find all the constant values /// that were used to construct it. E.g for: ///