diff --git a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs index a48c57cdb5..53a31ae57c 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs @@ -18,6 +18,7 @@ //! - A reference with 0 aliases means we were unable to find which reference this reference //! refers to. If such a reference is stored to, we must conservatively invalidate every //! reference in the current block. +//! - We also track the last load instruction to each address per block. //! //! From there, to figure out the value of each reference at the end of block, iterate each instruction: //! - On `Instruction::Allocate`: @@ -28,6 +29,13 @@ //! - Furthermore, if the result of the load is a reference, mark the result as an alias //! of the reference it dereferences to (if known). //! - If which reference it dereferences to is not known, this load result has no aliases. +//! - We also track the last instance of a load instruction to each address in a block. +//! If we see that the last load instruction was from the same address as the current load instruction, +//! we move to replace the result of the current load with the result of the previous load. +//! This removal requires a couple of conditions: +//! - No store occurs to that address before the next load, +//! - The address is not used as an argument to a call. +//! This optimization helps us remove repeated loads for which there are no known values. //! - On `Instruction::Store { address, value }`: //! - If the address of the store is known: //! - If the address has exactly 1 alias: @@ -40,11 +48,13 @@ //! - Conservatively mark every alias in the block to `Unknown`. //! - Additionally, if there were no Loads to any alias of the address between this Store and //! the previous Store to the same address, the previous store can be removed. +//! - Remove the instance of the last load instruction to the address and its aliases //!
- On `Instruction::Call { arguments }`: //! - If any argument of the call is a reference, set the value of each alias of that //! reference to `Unknown` //! - Any builtin functions that may return aliases if their input also contains a //! reference should be tracked. Examples: `slice_push_back`, `slice_insert`, `slice_remove`, etc. +//! - Remove the instance of the last load instruction for any reference arguments and their aliases //! //! On a terminator instruction: //! - If the terminator is a `Jmp`: @@ -274,6 +284,9 @@ impl<'f> PerFunctionContext<'f> { if let Some(first_predecessor) = predecessors.next() { let mut first = self.blocks.get(&first_predecessor).cloned().unwrap_or_default(); first.last_stores.clear(); + // Last loads are tracked per block. The new block's state starts as a clone of the first predecessor's state, + // so we must clear the cloned last loads to keep them from carrying over into the new block. + first.last_loads.clear(); // Note that we have to start folding with the first block as the accumulator. // If we started with an empty block, an empty block union'd with any other block @@ -410,6 +423,28 @@ impl<'f> PerFunctionContext<'f> { self.last_loads.insert(address, (instruction, block_id)); } + + // Check whether the block has a repeat load from the same address (w/ no calls or stores in between the loads). + // If we do have a repeat load, we can remove the current load and map its result to the previous load's result.
+ if let Some(last_load) = references.last_loads.get(&address) { + let Instruction::Load { address: previous_address } = + &self.inserter.function.dfg[*last_load] + else { + panic!("Expected a Load instruction here"); + }; + let result = self.inserter.function.dfg.instruction_results(instruction)[0]; + let previous_result = + self.inserter.function.dfg.instruction_results(*last_load)[0]; + if *previous_address == address { + self.inserter.map_value(result, previous_result); + self.instructions_to_remove.insert(instruction); + } + } + // We want to set the load for every load even if the address has a known value + // and the previous load instruction was removed. + // We are safe to still remove a repeat load in this case as we are mapping from the current load's + // result to the previous load, which if it was removed should already have a mapping to the known value. + references.set_last_load(address, instruction); } Instruction::Store { address, value } => { let address = self.inserter.function.dfg.resolve(*address); @@ -435,6 +470,8 @@ impl<'f> PerFunctionContext<'f> { } references.set_known_value(address, value); + // If we see a store to an address, the last load to that address needs to remain. + references.keep_last_load_for(address, self.inserter.function); references.last_stores.insert(address, instruction); } Instruction::Allocate => { @@ -542,6 +579,9 @@ impl<'f> PerFunctionContext<'f> { let value = self.inserter.function.dfg.resolve(*value); references.set_unknown(value); references.mark_value_used(value, self.inserter.function); + + // If a reference is an argument to a call, the last load to that address and its aliases needs to remain. 
+ references.keep_last_load_for(value, self.inserter.function); } } } @@ -572,6 +612,12 @@ impl<'f> PerFunctionContext<'f> { let destination_parameters = self.inserter.function.dfg[*destination].parameters(); assert_eq!(destination_parameters.len(), arguments.len()); + // If we have multiple parameters that alias the same argument value, + // then those parameters also alias each other. + // We save parameters with repeat arguments to later mark those + // parameters as aliasing one another. + let mut arg_set: HashMap> = HashMap::default(); + // Add an alias for each reference parameter for (parameter, argument) in destination_parameters.iter().zip(arguments) { if self.inserter.function.dfg.value_is_reference(*parameter) { @@ -581,10 +627,27 @@ impl<'f> PerFunctionContext<'f> { if let Some(aliases) = references.aliases.get_mut(expression) { // The argument reference is possibly aliased by this block parameter aliases.insert(*parameter); + + // Look up (or create) the set of parameters already seen for this argument + let seen_parameters = arg_set.entry(argument).or_default(); + // Add the current parameter to the parameters we have seen for this argument. + // The previous parameters and the current one alias one another. + seen_parameters.insert(*parameter); } } } } + + // Set the aliases of the parameters + for (_, aliased_params) in arg_set { + for param in aliased_params.iter() { + self.set_aliases( + references, + *param, + AliasSet::known_multiple(aliased_params.clone()), + ); + } + } } TerminatorInstruction::Return { return_values, ..
} => { // Removing all `last_stores` for each returned reference is more important here @@ -900,7 +963,7 @@ mod tests { // v10 = eq v9, Field 2 // constrain v9 == Field 2 // v11 = load v2 - // v12 = load v10 + // v12 = load v11 // v13 = eq v12, Field 2 // constrain v11 == Field 2 // return @@ -959,7 +1022,7 @@ mod tests { let main = ssa.main(); assert_eq!(main.reachable_blocks().len(), 4); - // The store from the original SSA should remain + // The stores from the original SSA should remain assert_eq!(count_stores(main.entry_block(), &main.dfg), 2); assert_eq!(count_stores(b2, &main.dfg), 1); @@ -1006,4 +1069,160 @@ mod tests { let main = ssa.main(); assert_eq!(count_loads(main.entry_block(), &main.dfg), 1); } + + #[test] + fn remove_repeat_loads() { + // This test starts with two loads from the same address with an unknown value. + // Specifically you should look for `load v2` in `b3`. + // We should be able to remove the second repeated load. + let src = " + acir(inline) fn main f0 { + b0(): + v0 = allocate -> &mut Field + store Field 0 at v0 + v2 = allocate -> &mut &mut Field + store v0 at v2 + jmp b1(Field 0) + b1(v3: Field): + v4 = eq v3, Field 0 + jmpif v4 then: b2, else: b3 + b2(): + v5 = load v2 -> &mut Field + store Field 2 at v5 + v8 = add v3, Field 1 + jmp b1(v8) + b3(): + v9 = load v0 -> Field + v10 = eq v9, Field 2 + constrain v9 == Field 2 + v11 = load v2 -> &mut Field + v12 = load v2 -> &mut Field + v13 = load v12 -> Field + v14 = eq v13, Field 2 + constrain v13 == Field 2 + return + } + "; + + let ssa = Ssa::from_str(src).unwrap(); + + // The repeated load from v2 (v3 in the normalized output below) should be removed. + // b3 (b2 in the normalized output) should only have three loads now rather than four previously. + // + // All stores are expected to remain.
+ let expected = " + acir(inline) fn main f0 { + b0(): + v1 = allocate -> &mut Field + store Field 0 at v1 + v3 = allocate -> &mut &mut Field + store v1 at v3 + jmp b1(Field 0) + b1(v0: Field): + v4 = eq v0, Field 0 + jmpif v4 then: b3, else: b2 + b3(): + v11 = load v3 -> &mut Field + store Field 2 at v11 + v13 = add v0, Field 1 + jmp b1(v13) + b2(): + v5 = load v1 -> Field + v7 = eq v5, Field 2 + constrain v5 == Field 2 + v8 = load v3 -> &mut Field + v9 = load v8 -> Field + v10 = eq v9, Field 2 + constrain v9 == Field 2 + return + } + "; + + let ssa = ssa.mem2reg(); + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn keep_repeat_loads_passed_to_a_call() { + // This test is the same as `remove_repeat_loads` above, except for the call + // to `f1` between the repeated loads. + let src = " + acir(inline) fn main f0 { + b0(): + v1 = allocate -> &mut Field + store Field 0 at v1 + v3 = allocate -> &mut &mut Field + store v1 at v3 + jmp b1(Field 0) + b1(v0: Field): + v4 = eq v0, Field 0 + jmpif v4 then: b3, else: b2 + b3(): + v13 = load v3 -> &mut Field + store Field 2 at v13 + v15 = add v0, Field 1 + jmp b1(v15) + b2(): + v5 = load v1 -> Field + v7 = eq v5, Field 2 + constrain v5 == Field 2 + v8 = load v3 -> &mut Field + call f1(v3) + v10 = load v3 -> &mut Field + v11 = load v10 -> Field + v12 = eq v11, Field 2 + constrain v11 == Field 2 + return + } + acir(inline) fn foo f1 { + b0(v0: &mut Field): + return + } + "; + + let ssa = Ssa::from_str(src).unwrap(); + + let ssa = ssa.mem2reg(); + // We expect the program to be unchanged + assert_normalized_ssa_equals(ssa, src); + } + + #[test] + fn keep_repeat_loads_with_alias_store() { + // v1, v2, and v3 alias one another. We want to make sure that a repeat load from v1 with a store + // to its aliases in between the repeat loads does not remove those loads.
+ let src = " + acir(inline) fn main f0 { + b0(v0: u1): + jmpif v0 then: b2, else: b1 + b2(): + v6 = allocate -> &mut Field + store Field 0 at v6 + jmp b3(v6, v6, v6) + b3(v1: &mut Field, v2: &mut Field, v3: &mut Field): + v8 = load v1 -> Field + store Field 2 at v2 + v10 = load v1 -> Field + store Field 1 at v3 + v11 = load v1 -> Field + store Field 3 at v3 + v13 = load v1 -> Field + constrain v8 == Field 0 + constrain v10 == Field 2 + constrain v11 == Field 1 + constrain v13 == Field 3 + return + b1(): + v4 = allocate -> &mut Field + store Field 1 at v4 + jmp b3(v4, v4, v4) + } + "; + + let ssa = Ssa::from_str(src).unwrap(); + + let ssa = ssa.mem2reg(); + // We expect the program to be unchanged + assert_normalized_ssa_equals(ssa, src); + } } diff --git a/compiler/noirc_evaluator/src/ssa/opt/mem2reg/alias_set.rs b/compiler/noirc_evaluator/src/ssa/opt/mem2reg/alias_set.rs index 4d768caa36..e32eaa7018 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mem2reg/alias_set.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mem2reg/alias_set.rs @@ -24,6 +24,10 @@ impl AliasSet { Self { aliases: Some(aliases) } } + pub(super) fn known_multiple(values: BTreeSet) -> AliasSet { + Self { aliases: Some(values) } + } + /// In rare cases, such as when creating an empty array of references, the set of aliases for a /// particular value will be known to be zero, which is distinct from being unknown and /// possibly referring to any alias. 
diff --git a/compiler/noirc_evaluator/src/ssa/opt/mem2reg/block.rs b/compiler/noirc_evaluator/src/ssa/opt/mem2reg/block.rs index 532785d292..f4265b2466 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mem2reg/block.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mem2reg/block.rs @@ -34,6 +34,9 @@ pub(super) struct Block { /// The last instance of a `Store` instruction to each address in this block pub(super) last_stores: im::OrdMap, + + /// The last instance of a `Load` instruction to each address in this block + pub(super) last_loads: im::OrdMap, } /// An `Expression` here is used to represent a canonical key @@ -237,4 +240,14 @@ impl Block { Cow::Owned(AliasSet::unknown()) } + + pub(super) fn set_last_load(&mut self, address: ValueId, instruction: InstructionId) { + self.last_loads.insert(address, instruction); + } + + pub(super) fn keep_last_load_for(&mut self, address: ValueId, function: &Function) { + let address = function.dfg.resolve(address); + self.last_loads.remove(&address); + self.for_each_alias_of(address, |block, alias| block.last_loads.remove(&alias)); + } } diff --git a/tooling/debugger/tests/debug.rs b/tooling/debugger/tests/debug.rs index 2dca6b95f0..eb43cf9cc6 100644 --- a/tooling/debugger/tests/debug.rs +++ b/tooling/debugger/tests/debug.rs @@ -12,7 +12,7 @@ mod tests { let nargo_bin = cargo_bin("nargo").into_os_string().into_string().expect("Cannot parse nargo path"); - let timeout_seconds = 25; + let timeout_seconds = 30; let mut dbg_session = spawn_bash(Some(timeout_seconds * 1000)).expect("Could not start bash session");