From e1907e7fddb1968700778a81bf819b294fed17ef Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Tue, 10 Dec 2024 07:58:59 -0600 Subject: [PATCH 1/3] map values mut --- .../src/ssa/function_builder/data_bus.rs | 14 ++++ .../src/ssa/ir/function_inserter.rs | 13 ++-- .../noirc_evaluator/src/ssa/ir/instruction.rs | 68 ++++++++++++++++++- .../src/ssa/opt/constant_folding.rs | 7 +- .../src/ssa/opt/flatten_cfg.rs | 21 +++--- .../noirc_evaluator/src/ssa/opt/mem2reg.rs | 5 +- .../src/ssa/opt/normalize_value_ids.rs | 7 +- .../noirc_evaluator/src/ssa/opt/unrolling.rs | 7 +- 8 files changed, 110 insertions(+), 32 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs b/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs index bd2585a3bfa..97e2c838f35 100644 --- a/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs +++ b/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs @@ -90,6 +90,20 @@ impl DataBus { DataBus { call_data, return_data: self.return_data.map(&mut f) } } + /// Updates the databus values in place with the provided function + pub(crate) fn map_values_mut(&mut self, mut f: impl FnMut(ValueId) -> ValueId) { + for cd in self.call_data.iter_mut() { + cd.array_id = f(cd.array_id); + + // Can't mutate a hashmap's keys so we need to collect into a new one. + cd.index_map = cd.index_map.iter().map(|(k, v)| (f(*k), *v)).collect(); + } + + if let Some(data) = self.return_data.as_mut() { + *data = f(*data); + } + } + pub(crate) fn call_data_array(&self) -> Vec<(u32, ValueId)> { self.call_data.iter().map(|cd| (cd.call_data_id, cd.array_id)).collect() } diff --git a/compiler/noirc_evaluator/src/ssa/ir/function_inserter.rs b/compiler/noirc_evaluator/src/ssa/ir/function_inserter.rs index 6ebd2aa1105..9ae0839c75c 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/function_inserter.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/function_inserter.rs @@ -73,25 +73,24 @@ impl<'f> FunctionInserter<'f> { /// Get an instruction and make sure all the values in it are freshly resolved. pub(crate) fn map_instruction(&mut self, id: InstructionId) -> (Instruction, CallStack) { - ( - self.function.dfg[id].clone().map_values(|id| self.resolve(id)), - self.function.dfg.get_call_stack(id), - ) + let mut instruction = self.function.dfg[id].clone(); + instruction.map_values_mut(|id| self.resolve(id)); + (instruction, self.function.dfg.get_call_stack(id)) } /// Maps a terminator in place, replacing any ValueId in the terminator with the /// resolved version of that value id from this FunctionInserter's internal value mapping. pub(crate) fn map_terminator_in_place(&mut self, block: BasicBlockId) { let mut terminator = self.function.dfg[block].take_terminator(); - terminator.mutate_values(|value| self.resolve(value)); + terminator.map_values_mut(|value| self.resolve(value)); self.function.dfg[block].set_terminator(terminator); } /// Maps the data bus in place, replacing any ValueId in the data bus with the /// resolved version of that value id from this FunctionInserter's internal value mapping. pub(crate) fn map_data_bus_in_place(&mut self) { - let data_bus = self.function.dfg.data_bus.clone(); - let data_bus = data_bus.map_values(|value| self.resolve(value)); + let mut data_bus = self.function.dfg.data_bus.clone(); + data_bus.map_values_mut(|value| self.resolve(value)); self.function.dfg.data_bus = data_bus; } diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs index ba212fdad96..cc4f5e357c1 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs @@ -671,6 +671,72 @@ impl Instruction { } } + /// Maps each ValueId inside this instruction to a new ValueId in place. + pub(crate) fn map_values_mut(&mut self, mut f: impl FnMut(ValueId) -> ValueId) { + match self { + Instruction::Binary(binary) => { + binary.lhs = f(binary.lhs); + binary.rhs = f(binary.rhs); + } + Instruction::Cast(value, _) => *value = f(*value), + Instruction::Not(value) => *value = f(*value), + Instruction::Truncate { value, bit_size: _, max_bit_size: _ } => { + *value = f(*value); + } + Instruction::Constrain(lhs, rhs, assert_message) => { + *lhs = f(*lhs); + *rhs = f(*rhs); + if let Some(error) = assert_message.as_mut() { + if let ConstrainError::Dynamic(_, _, payload_values) = error { + for value in payload_values { + *value = f(*value); + } + } + } + } + Instruction::Call { func, arguments } => { + *func = f(*func); + for argument in arguments { + *argument = f(*argument); + } + } + Instruction::Allocate => (), + Instruction::Load { address } => *address = f(*address), + Instruction::Store { address, value } => { + *address = f(*address); + *value = f(*value); + } + Instruction::EnableSideEffectsIf { condition } => { + *condition = f(*condition); + } + Instruction::ArrayGet { array, index } => { + *array = f(*array); + *index = f(*index); + } + Instruction::ArraySet { array, index, value, mutable: _ } => { + *array = f(*array); + *index = f(*index); + *value = f(*value); + } + Instruction::IncrementRc { value } => *value = f(*value), + Instruction::DecrementRc { value } => *value = f(*value), + Instruction::RangeCheck { value, max_bit_size: _, assert_message: _ } => { + *value = f(*value); + } + Instruction::IfElse { then_condition, then_value, else_condition, else_value } => { + *then_condition = f(*then_condition); + *then_value = f(*then_value); + *else_condition = f(*else_condition); + *else_value = f(*else_value); + } + Instruction::MakeArray { elements, typ: _ } => { + for element in elements.iter_mut() { + *element = f(*element); + } + } + } + } + /// Applies a function to each input value this instruction holds. pub(crate) fn for_each_value(&self, mut f: impl FnMut(ValueId) -> T) { match self { @@ -1195,7 +1261,7 @@ impl TerminatorInstruction { } /// Mutate each ValueId to a new ValueId using the given mapping function - pub(crate) fn mutate_values(&mut self, mut f: impl FnMut(ValueId) -> ValueId) { + pub(crate) fn map_values_mut(&mut self, mut f: impl FnMut(ValueId) -> ValueId) { use TerminatorInstruction::*; match self { JmpIf { condition, .. } => { diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index 56029a8fbd4..a5a69133297 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -372,7 +372,7 @@ impl<'brillig> Context<'brillig> { dom: &mut DominatorTree, constraint_simplification_mapping: &HashMap, ) -> Instruction { - let instruction = dfg[instruction_id].clone(); + let mut instruction = dfg[instruction_id].clone(); // Alternate between resolving `value_id` in the `dfg` and checking to see if the resolved value // has been constrained to be equal to some simpler value in the current block. @@ -400,9 +400,10 @@ impl<'brillig> Context<'brillig> { } // Resolve any inputs to ensure that we're comparing like-for-like instructions. - instruction.map_values(|value_id| { + instruction.map_values_mut(|value_id| { resolve_cache(block, dfg, dom, constraint_simplification_mapping, value_id) - }) + }); + instruction } /// Pushes a new [`Instruction`] into the [`DataFlowGraph`] which applies any optimizations diff --git a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs index 3fbccf93ec9..c52da70fff3 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs @@ -1218,20 +1218,17 @@ mod test { ) -> Vec { match dfg[value] { Value::Instruction { instruction, .. } => { - let mut values = vec![]; - dfg[instruction].map_values(|value| { - values.push(value); - value - }); + let mut constants = vec![]; - let mut values: Vec<_> = values - .into_iter() - .flat_map(|value| get_all_constants_reachable_from_instruction(dfg, value)) - .collect(); + dfg[instruction].for_each_value(|value| { + for constant in get_all_constants_reachable_from_instruction(dfg, value) { + constants.push(constant); + } + }); - values.sort(); - values.dedup(); - values + constants.sort(); + constants.dedup(); + constants } Value::NumericConstant { constant, .. } => vec![constant.to_u128()], _ => Vec::new(), diff --git a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs index 77ad53df9cf..1e5cd8bdfbd 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mem2reg.rs @@ -599,8 +599,9 @@ impl<'f> PerFunctionContext<'f> { } fn update_data_bus(&mut self) { - let databus = self.inserter.function.dfg.data_bus.clone(); - self.inserter.function.dfg.data_bus = databus.map_values(|t| self.inserter.resolve(t)); + let mut databus = self.inserter.function.dfg.data_bus.clone(); + databus.map_values_mut(|t| self.inserter.resolve(t)); + self.inserter.function.dfg.data_bus = databus; } fn handle_terminator(&mut self, block: BasicBlockId, references: &mut Block) { diff --git a/compiler/noirc_evaluator/src/ssa/opt/normalize_value_ids.rs b/compiler/noirc_evaluator/src/ssa/opt/normalize_value_ids.rs index a5b60fb5fcd..9485f89450b 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/normalize_value_ids.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/normalize_value_ids.rs @@ -109,9 +109,10 @@ impl Context { } let old_block = &mut old_function.dfg[old_block_id]; - let mut terminator = old_block - .take_terminator() - .map_values(|value| self.new_ids.map_value(new_function, old_function, value)); + let mut terminator = old_block.take_terminator(); + terminator + .map_values_mut(|value| self.new_ids.map_value(new_function, old_function, value)); + terminator.mutate_blocks(|old_block| self.new_ids.blocks[&old_block]); new_function.dfg.set_block_terminator(new_block_id, terminator); } diff --git a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs index 142447c83a5..a64ddbc67ff 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs @@ -938,10 +938,9 @@ impl<'f> LoopIteration<'f> { } self.inserter.push_instruction(instruction, self.insert_block); } - let mut terminator = self.dfg()[self.source_block] - .unwrap_terminator() - .clone() - .map_values(|value| self.inserter.resolve(value)); + let mut terminator = self.dfg()[self.source_block].unwrap_terminator().clone(); + + terminator.map_values_mut(|value| self.inserter.resolve(value)); // Replace the blocks in the terminator with fresh one with the same parameters, // while remembering which were the original block IDs. From 291419e69a50d8405d15d21b686b54ab2966d9f8 Mon Sep 17 00:00:00 2001 From: Jake Fecher Date: Tue, 10 Dec 2024 08:02:24 -0600 Subject: [PATCH 2/3] Clippy --- compiler/noirc_evaluator/src/ssa/ir/instruction.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs index cc4f5e357c1..5e189921116 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs @@ -686,11 +686,9 @@ impl Instruction { Instruction::Constrain(lhs, rhs, assert_message) => { *lhs = f(*lhs); *rhs = f(*rhs); - if let Some(error) = assert_message.as_mut() { - if let ConstrainError::Dynamic(_, _, payload_values) = error { - for value in payload_values { - *value = f(*value); - } + if let Some(ConstrainError::Dynamic(_, _, payload_values)) = assert_message { + for value in payload_values { + *value = f(*value); } } } From 93a21fd77386a5b3ebf4c0d053384bc396437cd9 Mon Sep 17 00:00:00 2001 From: jfecher Date: Fri, 13 Dec 2024 07:48:17 -0600 Subject: [PATCH 3/3] Update compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs Co-authored-by: Ary Borenszweig --- compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs index a8e94f239de..fa2600a3356 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs @@ -1210,9 +1210,7 @@ mod test { let mut constants = vec![]; dfg[instruction].for_each_value(|value| { - for constant in get_all_constants_reachable_from_instruction(dfg, value) { - constants.push(constant); - } + constants.extend(get_all_constants_reachable_from_instruction(dfg, value)); }); constants.sort();