diff --git a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs index 13e1e181dec..9c60be37bc9 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs @@ -61,12 +61,44 @@ //! SSA optimization pipeline, although it will be more successful the simpler the program's CFG is. //! This pass is currently performed several times to enable other passes - most notably being //! performed before loop unrolling to try to allow for mutable variables used for loop indices. +//! +//! As stated above, the algorithm above can sometimes miss known references. +//! This most commonly occurs in the case of loops, where we may have allocations preceding a loop that are known, +//! but the loop body's blocks are predecessors to the loop header block, causing those known allocations to be marked unknown. +//! In certain cases we may be able to remove these allocations that precede a loop. +//! For example, if a reference is not stored to again in the loop we should be able to remove that store which precedes the loop. +//! +//! To handle cases such as the one laid out above, we maintain some extra state per function, +//! that we will analyze after the initial run through all of the blocks. +//! We refer to this as the "function cleanup" and it requires having already iterated through all blocks. +//! +//! The state contains the following: +//! - For each load address we store the number of loads from a given address, +//! the last load instruction from a given address across all blocks, and the respective block id of that instruction. +//! - A mapping of each load result to its number of uses, the load instruction that produced the given result, and the respective block id of that instruction. +//! - A set of the references and their aliases passed as an argument to a call. +//! - Maps the references which have been aliased to the instructions that aliased that reference. +//! - As we go through each instruction, if a load result has been used we increment its usage counter. +//! Upon removing an instruction, we decrement the load result counter. +//! After analyzing all of a function's blocks we can analyze the per function state: +//! - If we find that a load result's usage counter equals zero, we can remove that load. +//! - We can then remove a store if the following conditions are met: +//! - All loads to a given address have been removed +//! - None of the aliases of a reference are used in any of the following: +//! - Block parameters, function parameters, call arguments, terminator arguments +//! - The store address is not aliased. +//! - If a store is in a return block, we can have special handling that only checks if there is a load after +//! that store in the return block. In the case of a return block, even if there are other loads +//! in preceding blocks we can safely remove those stores. +//! - To further catch any stores to references which are never loaded, we can count the number of stores +//! that were removed in the previous step. If there is only a single store leftover, we can safely map +//! the value of this final store to any loads of that store. mod alias_set; mod block; use std::collections::{BTreeMap, BTreeSet}; -use fxhash::FxHashMap as HashMap; +use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; use crate::ssa::{ ir::{ @@ -96,6 +128,7 @@ impl Ssa { context.remove_instructions(); context.update_data_bus(); } + self } } @@ -112,11 +145,58 @@ struct PerFunctionContext<'f> { /// /// We avoid removing individual instructions as we go since removing elements /// from the middle of Vecs many times will be slower than a single call to `retain`. - instructions_to_remove: BTreeSet, + instructions_to_remove: HashSet, /// Track a value's last load across all blocks. /// If a value is not used in anymore loads we can remove the last store to that value. - last_loads: HashMap, + last_loads: HashMap, + + /// Track whether a load result was used across all blocks. + load_results: HashMap, + + /// Track whether a reference was passed into another entry point + /// This is needed to determine whether we can remove a store. + calls_reference_input: HashSet, + + /// Track whether a reference has been aliased, and store the respective + /// instruction that aliased that reference. + /// If that store has been set for removal, we can also remove this instruction. + aliased_references: HashMap>, + + // The index of the last load instruction in a given block + return_block_load_locations: HashMap<(ValueId, BasicBlockId), usize>, +} + +#[derive(Debug, Clone)] +struct PerFuncLastLoadContext { + /// Reference counter that keeps track of how many times we loaded from a given address + num_loads: u32, + /// Last load instruction from a given address + load_instruction: InstructionId, + /// Block of the last load instruction + block_id: BasicBlockId, +} + +impl PerFuncLastLoadContext { + fn new(load_instruction: InstructionId, block_id: BasicBlockId, num_loads: u32) -> Self { + Self { num_loads, load_instruction, block_id } + } +} + +#[derive(Debug, Clone)] +struct PerFuncLoadResultContext { + /// Reference counter that keeps track of how many times a load was used in other instructions + uses: u32, + /// Load instruction that produced a given load result + load_instruction: InstructionId, + /// Block of the load instruction that produced a given result + block_id: BasicBlockId, +} + +impl PerFuncLoadResultContext { + fn new(load_instruction: InstructionId, block_id: BasicBlockId) -> Self { + Self { uses: 0, load_instruction, block_id } + } } impl<'f> PerFunctionContext<'f> { @@ -129,8 +209,12 @@ impl<'f> PerFunctionContext<'f> { post_order, inserter: FunctionInserter::new(function), blocks: BTreeMap::new(), - instructions_to_remove: BTreeSet::new(), + instructions_to_remove: HashSet::default(), last_loads: HashMap::default(), + load_results: HashMap::default(), + calls_reference_input: HashSet::default(), + aliased_references: HashMap::default(), + return_block_load_locations: HashMap::default(), } } @@ -148,44 +232,7 @@ impl<'f> PerFunctionContext<'f> { self.analyze_block(block, references); } - // If we never load from an address within a function we can remove all stores to that address. - // This rule does not apply to reference parameters, which we must also check for before removing these stores. - for (block_id, block) in self.blocks.iter() { - let block_params = self.inserter.function.dfg.block_parameters(*block_id); - for (store_address, store_instruction) in block.last_stores.iter() { - let is_reference_param = block_params.contains(store_address); - let terminator = self.inserter.function.dfg[*block_id].unwrap_terminator(); - - let is_return = matches!(terminator, TerminatorInstruction::Return { .. }); - let remove_load = if is_return { - // Determine whether the last store is used in the return value - let mut is_return_value = false; - terminator.for_each_value(|return_value| { - is_return_value = return_value == *store_address || is_return_value; - }); - - // If the last load of a store is not part of the block with a return terminator, - // we can safely remove this store. - let last_load_not_in_return = self - .last_loads - .get(store_address) - .map(|(_, last_load_block)| *last_load_block != *block_id) - .unwrap_or(true); - !is_return_value && last_load_not_in_return - } else { - self.last_loads.get(store_address).is_none() - }; - - let is_reference_alias = block - .expressions - .get(store_address) - .map_or(false, |expression| matches!(expression, Expression::Dereference(_))); - - if remove_load && !is_reference_param && !is_reference_alias { - self.instructions_to_remove.insert(*store_instruction); - } - } - } + self.cleanup_function(); } /// The value of each reference at the start of the given block is the unification @@ -250,12 +297,27 @@ impl<'f> PerFunctionContext<'f> { // If `allocation_aliases_parameter` is known to be false if allocation_aliases_parameter == Some(false) { self.instructions_to_remove.insert(*instruction); + if let Some(context) = self.load_results.get_mut(allocation) { + context.uses -= 1; + } } } } } } + fn increase_load_ref_counts(&mut self, value: ValueId) { + if let Some(context) = self.load_results.get_mut(&value) { + context.uses += 1; + } + let array_const = self.inserter.function.dfg.get_array_constant(value); + if let Some((values, _)) = array_const { + for array_value in values { + self.increase_load_ref_counts(array_value); + } + } + } + fn analyze_instruction( &mut self, block_id: BasicBlockId, @@ -271,6 +333,16 @@ impl<'f> PerFunctionContext<'f> { return; } + let mut collect_values = Vec::new(); + // Track whether any load results were used in the instruction + self.inserter.function.dfg[instruction].for_each_value(|value| { + collect_values.push(value); + }); + + for value in collect_values { + self.increase_load_ref_counts(value); + } + match &self.inserter.function.dfg[instruction] { Instruction::Load { address } => { let address = self.inserter.function.dfg.resolve(*address); @@ -280,7 +352,6 @@ impl<'f> PerFunctionContext<'f> { // If the load is known, replace it with the known value and remove the load if let Some(value) = references.get_known_value(address) { - let result = self.inserter.function.dfg.instruction_results(instruction)[0]; self.inserter.map_value(result, value); self.instructions_to_remove.insert(instruction); } else { @@ -301,7 +372,23 @@ impl<'f> PerFunctionContext<'f> { // Mark that we know a load result is equivalent to the address of a load. references.set_known_value(result, address); - self.last_loads.insert(address, (instruction, block_id)); + self.load_results + .insert(result, PerFuncLoadResultContext::new(instruction, block_id)); + + let num_loads = + self.last_loads.get(&address).map_or(1, |context| context.num_loads + 1); + let last_load = PerFuncLastLoadContext::new(instruction, block_id, num_loads); + self.last_loads.insert(address, last_load); + + // If we are in a return block we want to save the last location of a load + let terminator = self.inserter.function.dfg[block_id].unwrap_terminator(); + let is_return = matches!(terminator, TerminatorInstruction::Return { .. }); + if is_return { + let instruction_index = + self.inserter.function.dfg[block_id].instructions().len(); + self.return_block_load_locations + .insert((address, block_id), instruction_index); + } } } Instruction::Store { address, value } => { @@ -310,10 +397,21 @@ impl<'f> PerFunctionContext<'f> { self.check_array_aliasing(references, value); - // If there was another store to this instruction without any (unremoved) loads or + // If there was another store to this address without any (unremoved) loads or // function calls in-between, we can remove the previous store. if let Some(last_store) = references.last_stores.get(&address) { self.instructions_to_remove.insert(*last_store); + let Instruction::Store { address, value } = + self.inserter.function.dfg[*last_store] + else { + panic!("Should have a store instruction here"); + }; + if let Some(context) = self.load_results.get_mut(&address) { + context.uses -= 1; + } + if let Some(context) = self.load_results.get_mut(&value) { + context.uses -= 1; + } } let known_value = references.get_known_value(value); @@ -321,11 +419,33 @@ impl<'f> PerFunctionContext<'f> { let known_value_is_address = known_value == address; if known_value_is_address { self.instructions_to_remove.insert(instruction); + if let Some(context) = self.load_results.get_mut(&address) { + context.uses -= 1; + } + if let Some(context) = self.load_results.get_mut(&value) { + context.uses -= 1; + } + } else { + references.last_stores.insert(address, instruction); + } + } else { + references.last_stores.insert(address, instruction); + } + + if self.inserter.function.dfg.value_is_reference(value) { + if let Some(expression) = references.expressions.get(&value) { + if let Some(aliases) = references.aliases.get(expression) { + aliases.for_each(|alias| { + self.aliased_references + .entry(alias) + .or_default() + .insert(instruction); + }); + } } } references.set_known_value(address, value); - references.last_stores.insert(address, instruction); } Instruction::Allocate => { // Register the new reference @@ -375,7 +495,20 @@ impl<'f> PerFunctionContext<'f> { references.aliases.insert(expression, aliases); } } - Instruction::Call { arguments, .. } => self.mark_all_unknown(arguments, references), + Instruction::Call { arguments, .. } => { + for arg in arguments { + if self.inserter.function.dfg.value_is_reference(*arg) { + if let Some(expression) = references.expressions.get(arg) { + if let Some(aliases) = references.aliases.get(expression) { + aliases.for_each(|alias| { + self.calls_reference_input.insert(alias); + }); + } + } + } + } + self.mark_all_unknown(arguments, references); + } _ => (), } } @@ -443,7 +576,20 @@ impl<'f> PerFunctionContext<'f> { fn handle_terminator(&mut self, block: BasicBlockId, references: &mut Block) { self.inserter.map_terminator_in_place(block); - match self.inserter.function.dfg[block].unwrap_terminator() { + let terminator: &TerminatorInstruction = + self.inserter.function.dfg[block].unwrap_terminator(); + + let mut collect_values = Vec::new(); + terminator.for_each_value(|value| { + collect_values.push(value); + }); + + let terminator = terminator.clone(); + for value in collect_values.iter() { + self.increase_load_ref_counts(*value); + } + + match &terminator { TerminatorInstruction::JmpIf { .. } => (), // Nothing to do TerminatorInstruction::Jmp { destination, arguments, .. } => { let destination_parameters = self.inserter.function.dfg[*destination].parameters(); @@ -471,6 +617,303 @@ impl<'f> PerFunctionContext<'f> { } } } + + fn recursively_add_values(&self, value: ValueId, set: &mut HashSet) { + set.insert(value); + if let Some((elements, _)) = self.inserter.function.dfg.get_array_constant(value) { + for array_element in elements { + self.recursively_add_values(array_element, set); + } + } + } + + /// The mem2reg pass is sometimes unable to determine certain known values + /// when iterating over a function's block in reverse post order. + /// We collect state about any final loads and stores to a given address during the initial mem2reg pass. + /// We can then utilize this state to clean up any loads and stores that may have been missed. + fn cleanup_function(&mut self) { + // Removing remaining unused loads during mem2reg can help expose removable stores that the initial + // mem2reg pass deemed we could not remove due to the existence of those unused loads. + let removed_loads = self.remove_unused_loads(); + let remaining_last_stores = self.remove_unloaded_last_stores(&removed_loads); + let stores_were_removed = + self.remove_remaining_last_stores(&removed_loads, &remaining_last_stores); + + // When removing some last loads with the last stores we will map the load result to the store value. + // We need to then map all the instructions again as we do not know which instructions are reliant on the load result. + if stores_were_removed { + let mut block_order = PostOrder::with_function(self.inserter.function).into_vec(); + block_order.reverse(); + for block in block_order { + let instructions = self.inserter.function.dfg[block].take_instructions(); + for instruction in instructions { + if !self.instructions_to_remove.contains(&instruction) { + self.inserter.push_instruction(instruction, block); + } + } + self.inserter.map_terminator_in_place(block); + } + } + } + + /// Cleanup remaining loads across the entire function + /// Remove any loads whose reference counter is zero. + /// Returns a map of the removed load address to the number of load instructions removed for that address + fn remove_unused_loads(&mut self) -> HashMap { + let mut removed_loads = HashMap::default(); + for (_, PerFuncLoadResultContext { uses, load_instruction, block_id, .. }) in + self.load_results.iter() + { + let Instruction::Load { address } = self.inserter.function.dfg[*load_instruction] + else { + unreachable!("Should only have a load instruction here"); + }; + // If the load result's counter is equal to zero we can safely remove that load instruction. + if *uses == 0 { + self.return_block_load_locations.remove(&(address, *block_id)); + + removed_loads.entry(address).and_modify(|counter| *counter += 1).or_insert(1); + self.instructions_to_remove.insert(*load_instruction); + } + } + removed_loads + } + + fn recursively_check_address_in_terminator( + &self, + return_value: ValueId, + store_address: ValueId, + is_return_value: &mut bool, + ) { + *is_return_value = return_value == store_address || *is_return_value; + let array_const = self.inserter.function.dfg.get_array_constant(return_value); + if let Some((values, _)) = array_const { + for array_value in values { + self.recursively_check_address_in_terminator( + array_value, + store_address, + is_return_value, + ); + } + } + } + + /// Cleanup remaining stores across the entire function. + /// If we never load from an address within a function we can remove all stores to that address. + /// This rule does not apply to reference parameters, which we must also check for before removing these stores. + /// Returns a map of any remaining stores which may still have loads in use. + fn remove_unloaded_last_stores( + &mut self, + removed_loads: &HashMap, + ) -> HashMap { + let mut all_terminator_values = HashSet::default(); + let mut per_func_block_params: HashSet = HashSet::default(); + for (block_id, _) in self.blocks.iter() { + let block_params = self.inserter.function.dfg.block_parameters(*block_id); + per_func_block_params.extend(block_params.iter()); + + let terminator = self.inserter.function.dfg[*block_id].unwrap_terminator(); + terminator.for_each_value(|value| { + self.recursively_add_values(value, &mut all_terminator_values); + }); + } + + let mut remaining_last_stores: HashMap = HashMap::default(); + for (block_id, block) in self.blocks.iter() { + for (store_address, store_instruction) in block.last_stores.iter() { + if self.instructions_to_remove.contains(store_instruction) { + continue; + } + + let all_loads_removed = self.all_loads_removed_for_address( + store_address, + *store_instruction, + *block_id, + removed_loads, + ); + + let store_alias_used = self.is_store_alias_used( + store_address, + block, + &all_terminator_values, + &per_func_block_params, + ); + + if all_loads_removed && !store_alias_used { + self.instructions_to_remove.insert(*store_instruction); + if let Some((_, counter)) = remaining_last_stores.get_mut(store_address) { + *counter -= 1; + } + } else if let Some((_, counter)) = remaining_last_stores.get_mut(store_address) { + *counter += 1; + } else { + remaining_last_stores.insert(*store_address, (*store_instruction, 1)); + } + } + } + remaining_last_stores + } + + fn all_loads_removed_for_address( + &self, + store_address: &ValueId, + store_instruction: InstructionId, + block_id: BasicBlockId, + removed_loads: &HashMap, + ) -> bool { + let terminator = self.inserter.function.dfg[block_id].unwrap_terminator(); + let is_return = matches!(terminator, TerminatorInstruction::Return { .. }); + // Determine whether any loads that reference this store address + // have been removed while cleaning up unused loads. + if is_return { + // If we are in a return terminator, and the last loads of a reference + // come before a store to that reference, we can safely remove that store. + let store_after_load = if let Some(max_load_index) = + self.return_block_load_locations.get(&(*store_address, block_id)) + { + let store_index = self.inserter.function.dfg[block_id] + .instructions() + .iter() + .position(|id| *id == store_instruction) + .expect("Store instruction should exist in the return block"); + store_index > *max_load_index + } else { + // Otherwise there is no load in this block + true + }; + store_after_load + } else if let (Some(context), Some(loads_removed_counter)) = + (self.last_loads.get(store_address), removed_loads.get(store_address)) + { + // `last_loads` contains the total number of loads for a given load address + // If the number of removed loads for a given address is equal to the total number of loads for that address, + // we know we can safely remove any stores to that load address. + context.num_loads == *loads_removed_counter + } else { + self.last_loads.get(store_address).is_none() + } + } + + // Extra checks on where a reference can be used aside a load instruction. + // Even if all loads to a reference have been removed we need to make sure that + // an allocation did not come from an entry point or was passed to an entry point. + fn is_store_alias_used( + &self, + store_address: &ValueId, + block: &Block, + all_terminator_values: &HashSet, + per_func_block_params: &HashSet, + ) -> bool { + let func_params = self.inserter.function.parameters(); + let reference_parameters = func_params + .iter() + .filter(|param| self.inserter.function.dfg.value_is_reference(**param)) + .collect::>(); + + let mut store_alias_used = false; + if let Some(expression) = block.expressions.get(store_address) { + if let Some(aliases) = block.aliases.get(expression) { + let allocation_aliases_parameter = + aliases.any(|alias| reference_parameters.contains(&alias)); + if allocation_aliases_parameter == Some(true) { + store_alias_used = true; + } + + let allocation_aliases_parameter = + aliases.any(|alias| per_func_block_params.contains(&alias)); + if allocation_aliases_parameter == Some(true) { + store_alias_used = true; + } + + let allocation_aliases_parameter = + aliases.any(|alias| self.calls_reference_input.contains(&alias)); + if allocation_aliases_parameter == Some(true) { + store_alias_used = true; + } + + let allocation_aliases_parameter = + aliases.any(|alias| all_terminator_values.contains(&alias)); + if allocation_aliases_parameter == Some(true) { + store_alias_used = true; + } + + let allocation_aliases_parameter = aliases.any(|alias| { + if let Some(alias_instructions) = self.aliased_references.get(&alias) { + self.instructions_to_remove.is_disjoint(alias_instructions) + } else { + false + } + }); + if allocation_aliases_parameter == Some(true) { + store_alias_used = true; + } + } + } + + store_alias_used + } + + /// Check if any remaining last stores are only used in a single load + /// Returns true if any stores were removed. + fn remove_remaining_last_stores( + &mut self, + removed_loads: &HashMap, + remaining_last_stores: &HashMap, + ) -> bool { + let mut stores_were_removed = false; + // Filter out any still in use load results and any load results that do not contain addresses from the remaining last stores + self.load_results.retain(|_, PerFuncLoadResultContext { load_instruction, uses, .. }| { + let Instruction::Load { address } = self.inserter.function.dfg[*load_instruction] + else { + unreachable!("Should only have a load instruction here"); + }; + remaining_last_stores.contains_key(&address) && *uses > 0 + }); + + for (store_address, (store_instruction, store_counter)) in remaining_last_stores { + let Instruction::Store { value, .. } = self.inserter.function.dfg[*store_instruction] + else { + unreachable!("Should only have a store instruction"); + }; + + if let (Some(context), Some(loads_removed_counter)) = + (self.last_loads.get(store_address), removed_loads.get(store_address)) + { + assert!( + context.num_loads >= *loads_removed_counter, + "The number of loads removed should not be more than all loads" + ); + } + + // We only want to remove last stores referencing a single address. + if *store_counter != 0 { + continue; + } + + self.instructions_to_remove.insert(*store_instruction); + + // Map any remaining load results to the value from the removed store + for (result, context) in self.load_results.iter() { + let Instruction::Load { address } = + self.inserter.function.dfg[context.load_instruction] + else { + unreachable!("Should only have a load instruction here"); + }; + if address != *store_address { + continue; + } + + // Map the load result to its respective store value + // We will have to map all instructions following this method + // as we do not know what instructions depend upon this result + self.inserter.map_value(*result, value); + self.instructions_to_remove.insert(context.load_instruction); + + stores_were_removed = true; + } + } + stores_were_removed + } } #[cfg(test)] @@ -756,7 +1199,7 @@ mod tests { // return // } let ssa = ssa.mem2reg(); - + println!("{}", ssa); let main = ssa.main(); assert_eq!(main.reachable_blocks().len(), 2); @@ -778,6 +1221,163 @@ mod tests { assert_eq!(b1_instructions.len(), 0); } + #[test] + fn remove_unused_loads_and_stores() { + // acir(inline) fn main f0 { + // b0(): + // v0 = allocate + // store Field 1 at v0 + // v2 = allocate + // store Field 1 at v2 + // v4 = allocate + // store u1 0 at v4 + // v5 = allocate + // store u1 0 at v5 + // v6 = allocate + // store u1 0 at v6 + // jmp b1(u1 0) + // b1(v7: u32): + // v9 = eq v7, u32 0 + // jmpif v9 then: b3, else: b2 + // b3(): + // v20 = load v0 + // v21 = load v2 + // v22 = load v4 + // v23 = load v5 + // v24 = load v6 + // constrain v20 == Field 1 + // v25 = eq v21, Field 1 + // constrain v21 == Field 1 + // v26 = eq v7, u32 0 + // jmp b1(v26) + // b2(): + // v10 = load v0 + // v11 = load v2 + // v12 = load v4 + // v13 = load v5 + // v14 = load v6 + // store Field 1 at v0 + // store Field 1 at v2 + // store v12 at v4 + // store v13 at v5 + // store v14 at v6 + // v15 = load v0 + // v16 = load v2 + // v17 = load v4 + // v18 = load v5 + // v19 = load v6 + // constrain v15 == Field 1 + // return v16 + // } + let main_id = Id::test_new(0); + let mut builder = FunctionBuilder::new("main".into(), main_id); + + let v0 = builder.insert_allocate(Type::field()); + let one = builder.numeric_constant(1u128, Type::field()); + builder.insert_store(v0, one); + + let v2 = builder.insert_allocate(Type::field()); + builder.insert_store(v2, one); + + let zero_bool = builder.numeric_constant(0u128, Type::bool()); + let v4 = builder.insert_allocate(Type::bool()); + builder.insert_store(v4, zero_bool); + + let v6 = builder.insert_allocate(Type::bool()); + builder.insert_store(v6, zero_bool); + + let v8 = builder.insert_allocate(Type::bool()); + builder.insert_store(v8, zero_bool); + + let b1 = builder.insert_block(); + builder.terminate_with_jmp(b1, vec![zero_bool]); + + builder.switch_to_block(b1); + + let v7 = builder.add_block_parameter(b1, Type::unsigned(32)); + let zero_u32 = builder.numeric_constant(0u128, Type::unsigned(32)); + let is_zero = builder.insert_binary(v7, BinaryOp::Eq, zero_u32); + + let b2 = builder.insert_block(); + let b3 = builder.insert_block(); + builder.terminate_with_jmpif(is_zero, b3, b2); + + builder.switch_to_block(b2); + + let _ = builder.insert_load(v0, Type::field()); + let _ = builder.insert_load(v2, Type::field()); + let v12 = builder.insert_load(v4, Type::bool()); + let v13 = builder.insert_load(v6, Type::bool()); + let v14 = builder.insert_load(v8, Type::bool()); + + builder.insert_store(v0, one); + builder.insert_store(v2, one); + builder.insert_store(v4, v12); + builder.insert_store(v6, v13); + builder.insert_store(v8, v14); + + let v15 = builder.insert_load(v0, Type::field()); + // Insert unused loads + let v16 = builder.insert_load(v2, Type::field()); + let _ = builder.insert_load(v4, Type::bool()); + let _ = builder.insert_load(v6, Type::bool()); + let _ = builder.insert_load(v8, Type::bool()); + + builder.insert_constrain(v15, one, None); + builder.terminate_with_return(vec![v16]); + + builder.switch_to_block(b3); + + let v26 = builder.insert_load(v0, Type::field()); + // Insert unused loads + let v27 = builder.insert_load(v2, Type::field()); + let _ = builder.insert_load(v4, Type::bool()); + let _ = builder.insert_load(v6, Type::bool()); + let _ = builder.insert_load(v8, Type::bool()); + + builder.insert_constrain(v26, one, None); + let _ = builder.insert_binary(v27, BinaryOp::Eq, one); + builder.insert_constrain(v27, one, None); + let one_u32 = builder.numeric_constant(0u128, Type::unsigned(32)); + let plus_one = builder.insert_binary(v7, BinaryOp::Eq, one_u32); + builder.terminate_with_jmp(b1, vec![plus_one]); + + let ssa = builder.finish(); + + // Expected result: + // acir(inline) fn main f0 { + // b0(): + // v27 = allocate + // v28 = allocate + // v29 = allocate + // v30 = allocate + // v31 = allocate + // jmp b1(u1 0) + // b1(v7: u32): + // v32 = eq v7, u32 0 + // jmpif v32 then: b3, else: b2 + // b3(): + // v49 = eq v7, u32 0 + // jmp b1(v49) + // b2(): + // return Field 1 + // } + let ssa = ssa.mem2reg(); + + let main = ssa.main(); + assert_eq!(main.reachable_blocks().len(), 4); + + // All loads should be removed + assert_eq!(count_loads(b2, &main.dfg), 0); + assert_eq!(count_loads(b3, &main.dfg), 0); + + // All stores should be removed + assert_eq!(count_stores(main.entry_block(), &main.dfg), 0); + assert_eq!(count_stores(b2, &main.dfg), 0); + // Should only have one instruction in b3 + assert_eq!(main.dfg[b3].instructions().len(), 1); + } + #[test] fn keep_store_to_alias_in_loop_block() { // This test makes sure the instruction `store Field 2 at v5` in b2 remains after mem2reg. @@ -871,4 +1471,240 @@ mod tests { assert_eq!(count_loads(b2, &main.dfg), 1); assert_eq!(count_loads(b3, &main.dfg), 3); } + + #[test] + fn accurate_tracking_of_load_results() { + // acir(inline) fn main f0 { + // b0(): + // v0 = allocate + // store Field 5 at v0 + // v2 = allocate + // store u32 10 at v2 + // v4 = load v0 + // v5 = load v2 + // v6 = allocate + // store v4 at v6 + // v7 = allocate + // store v5 at v7 + // v8 = load v6 + // v9 = load v7 + // v10 = load v6 + // v11 = load v7 + // v12 = allocate + // store Field 0 at v12 + // v15 = eq v11, u32 0 + // jmpif v15 then: b1, else: b2 + // b1(): + // v16 = load v12 + // v17 = add v16, v8 + // store v17 at v12 + // jmp b2() + // b2(): + // v18 = load v12 + // return [v18] + // } + let main_id = Id::test_new(0); + let mut builder = FunctionBuilder::new("main".into(), main_id); + + let v0 = builder.insert_allocate(Type::field()); + let five = builder.numeric_constant(5u128, Type::field()); + builder.insert_store(v0, five); + + let v2 = builder.insert_allocate(Type::unsigned(32)); + let ten = builder.numeric_constant(10u128, Type::unsigned(32)); + builder.insert_store(v2, ten); + + let v4 = builder.insert_load(v0, Type::field()); + let v5 = builder.insert_load(v2, Type::unsigned(32)); + let v4_type = builder.current_function.dfg.type_of_value(v4); + let v5_type = builder.current_function.dfg.type_of_value(v5); + + let v6 = builder.insert_allocate(Type::field()); + builder.insert_store(v6, v4); + let v7 = builder.insert_allocate(Type::unsigned(32)); + builder.insert_store(v7, v5); + + let v8 = builder.insert_load(v6, v4_type.clone()); + let _v9 = builder.insert_load(v7, v5_type.clone()); + + let _v10 = builder.insert_load(v6, v4_type); + let v11 = builder.insert_load(v7, v5_type); + + let v12 = builder.insert_allocate(Type::field()); + let zero = builder.numeric_constant(0u128, Type::field()); + builder.insert_store(v12, zero); + + let zero_u32 = builder.numeric_constant(0u128, Type::unsigned(32)); + let v15 = builder.insert_binary(v11, BinaryOp::Eq, zero_u32); + + let b1 = builder.insert_block(); + let b2 = builder.insert_block(); + builder.terminate_with_jmpif(v15, b1, b2); + + builder.switch_to_block(b1); + + let v16 = builder.insert_load(v12, Type::field()); + let v17 = builder.insert_binary(v16, BinaryOp::Add, v8); + builder.insert_store(v12, v17); + + builder.terminate_with_jmp(b2, vec![]); + + builder.switch_to_block(b2); + let v18 = builder.insert_load(v12, Type::field()); + + // Include the load result as part of an array constant to check that we are accounting for arrays + // when updating the reference counts of load results. + // + // If we were not accounting for arrays appropriately, the load of v18 would be removed. + // If v18 is the last load of a reference and is inadvertently removed, + // any stores to v12 will then be potentially removed as well and the program will be broken. + let return_array = + builder.array_constant(vector![v18], Type::Array(Arc::new(vec![Type::field()]), 1)); + builder.terminate_with_return(vec![return_array]); + + let ssa = builder.finish(); + + // Expected result: + // acir(inline) fn main f0 { + // b0(): + // v20 = allocate + // v21 = allocate + // v24 = allocate + // v25 = allocate + // v30 = allocate + // store Field 0 at v30 + // jmpif u1 0 then: b1, else: b2 + // b1(): + // store Field 5 at v30 + // jmp b2() + // b2(): + // v33 = load v30 + // return [v33] + // } + let ssa = ssa.mem2reg(); + + let main = ssa.main(); + assert_eq!(main.reachable_blocks().len(), 3); + + // A single store from the entry block should remain. + // If we are not appropriately handling unused stores across a function, + // we would expect all five stores from the original SSA to remain. + assert_eq!(count_stores(main.entry_block(), &main.dfg), 1); + // The store from the conditional block should remain, + // as it is loaded from in a successor block and used in the return terminator. + assert_eq!(count_stores(b1, &main.dfg), 1); + + assert_eq!(count_loads(main.entry_block(), &main.dfg), 0); + assert_eq!(count_loads(b1, &main.dfg), 0); + assert_eq!(count_loads(b2, &main.dfg), 1); + } + + #[test] + fn keep_unused_store_only_used_as_an_alias_across_blocks() { + // acir(inline) fn main f0 { + // b0(v0: u32): + // v1 = allocate + // store u32 0 at v1 + // v3 = allocate + // store v1 at v3 + // v4 = allocate + // store v0 at v4 + // v5 = allocate + // store v4 at v5 + // jmp b1(u32 0) + // b1(v6: u32): + // v7 = eq v6, u32 0 + // jmpif v7 then: b2, else: b3 + // b2(): + // v8 = load v5 + // store v8 at u2 2 + // v11 = add v6, u32 1 + // jmp b1(v11) + // b3(): + // v12 = load v4 + // constrain v12 == u2 2 + // v13 = load v5 + // v14 = load v13 + // constrain v14 == u2 2 + // v15 = load v3 + // v16 = load v15 + // v18 = lt v16, u32 4 + // constrain v18 == u32 1 + // return + // } + let main_id = Id::test_new(0); + let mut builder = FunctionBuilder::new("main".into(), main_id); + + let v0 = builder.add_parameter(Type::unsigned(32)); + + let v1 = builder.insert_allocate(Type::unsigned(32)); + let zero = builder.numeric_constant(0u128, Type::unsigned(32)); + builder.insert_store(v1, zero); + + let v1_type = builder.type_of_value(v1); + let v3 = builder.insert_allocate(v1_type.clone()); + builder.insert_store(v3, v1); + + let v4 = builder.insert_allocate(Type::unsigned(32)); + builder.insert_store(v4, v0); + + let v5 = builder.insert_allocate(Type::Reference(Arc::new(Type::unsigned(32)))); + builder.insert_store(v5, v4); + + let b1 = builder.insert_block(); + builder.terminate_with_jmp(b1, vec![zero]); + builder.switch_to_block(b1); + + let v6 = builder.add_block_parameter(b1, Type::unsigned(32)); + let is_zero = builder.insert_binary(v6, BinaryOp::Eq, zero); + + let b2 = builder.insert_block(); + let b3 = builder.insert_block(); + builder.terminate_with_jmpif(is_zero, b2, b3); + + builder.switch_to_block(b2); + let v4_type = builder.type_of_value(v4); + // let v0_type = builder.type_of_value(v4); + let v8 = builder.insert_load(v5, v4_type); + let two = builder.numeric_constant(2u128, Type::unsigned(2)); + builder.insert_store(v8, two); + let one = builder.numeric_constant(1u128, Type::unsigned(32)); + let v11 = builder.insert_binary(v6, BinaryOp::Add, one); + builder.terminate_with_jmp(b1, vec![v11]); + + builder.switch_to_block(b3); + + let v12 = builder.insert_load(v4, Type::unsigned(32)); + builder.insert_constrain(v12, two, None); + + let v3_type = builder.type_of_value(v3); + let v13 = builder.insert_load(v5, v3_type); + let v14 = builder.insert_load(v13, Type::unsigned(32)); + builder.insert_constrain(v14, two, None); + + let v15 = builder.insert_load(v3, v1_type); + let v16 = builder.insert_load(v15, Type::unsigned(32)); + let four = builder.numeric_constant(4u128, Type::unsigned(32)); + let less_than_four = builder.insert_binary(v16, BinaryOp::Lt, four); + builder.insert_constrain(less_than_four, one, None); + + builder.terminate_with_return(vec![]); + let ssa = builder.finish(); + + // We expect the same result as above. + let ssa = ssa.mem2reg(); + let main = ssa.main(); + + // We expect all the stores to remain. + // The references in b0 are aliased and those are aliases may never be stored to again, + // but they are loaded from and used in later instructions. + // We need to make sure that the store of the address being aliased, is not removed from the program. + assert_eq!(count_stores(main.entry_block(), &main.dfg), 4); + // The store inside of the loop should remain + assert_eq!(count_stores(b2, &main.dfg), 1); + + // We expect the loads to remain the same + assert_eq!(count_loads(b2, &main.dfg), 1); + assert_eq!(count_loads(b3, &main.dfg), 5); + } } diff --git a/test_programs/compile_success_empty/references_aliasing/src/main.nr b/test_programs/compile_success_empty/references_aliasing/src/main.nr index d3e4257851b..b2b625477e7 100644 --- a/test_programs/compile_success_empty/references_aliasing/src/main.nr +++ b/test_programs/compile_success_empty/references_aliasing/src/main.nr @@ -6,6 +6,7 @@ fn main() { regression_2445(); single_alias_inside_loop(); + assert(5 == struct_field_refs_across_blocks(MyStruct { a: 5, b: 10 })[0]); } fn increment(mut r: &mut Field) { @@ -39,3 +40,20 @@ fn single_alias_inside_loop() { assert(var == 2); assert(**ref == 2); } + +struct MyStruct { + a: Field, + b: u32, +} + +fn struct_field_refs_across_blocks(mut my_struct: MyStruct) -> [Field; 1] { + [compute_dummy_hash(my_struct.a, my_struct.b, 20)] +} + +fn compute_dummy_hash(input: Field, rhs: u32, in_len: u32) -> Field { + let mut res = 0; + if rhs < in_len { + res += input; + } + res +} diff --git a/test_programs/execution_success/reference_only_used_as_alias/Nargo.toml b/test_programs/execution_success/reference_only_used_as_alias/Nargo.toml new file mode 100644 index 00000000000..d7531756822 --- /dev/null +++ b/test_programs/execution_success/reference_only_used_as_alias/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "reference_only_used_as_alias" +type = "bin" +authors = [""] +compiler_version = ">=0.33.0" + +[dependencies] \ No newline at end of file diff --git a/test_programs/execution_success/reference_only_used_as_alias/Prover.toml b/test_programs/execution_success/reference_only_used_as_alias/Prover.toml new file mode 100644 index 00000000000..ef34b9eba70 --- /dev/null +++ b/test_programs/execution_success/reference_only_used_as_alias/Prover.toml @@ -0,0 +1,3 @@ +input = [0, 1, 2, 3] +context_input = 4 +randomness = 5 diff --git a/test_programs/execution_success/reference_only_used_as_alias/src/main.nr b/test_programs/execution_success/reference_only_used_as_alias/src/main.nr new file mode 100644 index 00000000000..c04d68d1748 --- /dev/null +++ b/test_programs/execution_success/reference_only_used_as_alias/src/main.nr @@ -0,0 +1,88 @@ +struct ExampleEvent0 { + value0: Field, + value1: Field, +} + +trait EventInterface { + fn emit(self, _emit: fn[Env](Self) -> ()); +} + +impl EventInterface for ExampleEvent0 { + fn emit(self: ExampleEvent0, _emit: fn[Env](Self) -> ()) { + _emit(self); + } +} + +struct ExampleEvent1 { + value2: u8, + value3: u8, +} + +struct Context { + a: u32, + b: [u32; 3], + log_hashes: BoundedVec, +} + +struct LogHash { + value: Field, + counter: u32, + length: Field, + randomness: Field, +} + +impl Context { + fn emit_raw_log(&mut self, randomness: Field, _log: [u8; M], log_hash: Field) { + let log_hash = LogHash { value: log_hash, counter: 0, length: 0, randomness }; + self.log_hashes.push(log_hash); + } +} + +fn compute(_event: Event) -> ([u8; 5], Field) where Event: EventInterface { + ([0 as u8; 5], 0) +} + +fn emit_with_keys( + context: &mut Context, + randomness: Field, + event: Event, + inner_compute: fn(Event) -> ([u8; OB], Field) +) where Event: EventInterface { + let (log, log_hash) = inner_compute(event); + context.emit_raw_log(randomness, log, log_hash); +} + +pub fn encode_event_with_randomness( + context: &mut Context, + randomness: Field +) -> fn[(&mut Context, Field)](Event) -> () where Event: EventInterface { + | e: Event | { + unsafe { + func(context.a); + } + emit_with_keys(context, randomness, e, compute); + } +} + +unconstrained fn func(input: u32) { + let mut var = input; + let ref = &mut &mut var; + + for _ in 0..1 { + **ref = 2; + } + + assert(var == 2); + assert(**ref == 2); +} + +// This test aims to allocate a reference which is aliased and only accessed through its alias +// across multiple blocks. We want to guarantee that this allocation is not removed. +fn main(input: [Field; 4], randomness: Field, context_input: u32) { + let b = [context_input, context_input, context_input]; + let mut context = Context { a: context_input, b, log_hashes: BoundedVec::new() }; + + let event0 = ExampleEvent0 { value0: input[0], value1: input[1] }; + event0.emit(encode_event_with_randomness(&mut context, randomness)); +} + diff --git a/tooling/debugger/ignored-tests.txt b/tooling/debugger/ignored-tests.txt index 745971d9b28..78e14397938 100644 --- a/tooling/debugger/ignored-tests.txt +++ b/tooling/debugger/ignored-tests.txt @@ -3,4 +3,5 @@ debug_logs is_unconstrained macros references -regression_4709 \ No newline at end of file +regression_4709 +reference_only_used_as_alias \ No newline at end of file