diff --git a/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs b/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs index 6848f84bb7b..50c97b765bb 100644 --- a/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs +++ b/crates/noirc_evaluator/src/ssa_refactor/opt/inlining.rs @@ -79,10 +79,6 @@ struct PerFunctionContext<'function> { /// Maps InstructionIds from the function being inlined to the function being inlined into. instructions: HashMap, - /// The TerminatorInstruction::Return in the source_function will be mapped to a jmp to - /// this block in the destination function instead. - return_destination: BasicBlockId, - /// True if we're currently working on the main function. inlining_main: bool, } @@ -124,7 +120,12 @@ impl InlineContext { /// Inlines a function into the current function and returns the translated return values /// of the inlined function. - fn inline_function(&mut self, ssa: &Ssa, id: FunctionId, arguments: &[ValueId]) -> &[ValueId] { + fn inline_function( + &mut self, + ssa: &Ssa, + id: FunctionId, + arguments: &[ValueId], + ) -> Vec { self.recursion_level += 1; if self.recursion_level > RECURSION_LIMIT { @@ -143,9 +144,7 @@ impl InlineContext { let current_block = context.context.builder.current_block(); context.blocks.insert(source_function.entry_block(), current_block); - context.inline_blocks(ssa); - let return_destination = context.return_destination; - self.builder.block_parameters(return_destination) + context.inline_blocks(ssa) } /// Finish inlining and return the new Ssa struct with the inlined version of main. @@ -175,10 +174,7 @@ impl<'function> PerFunctionContext<'function> { /// for containing the mapping between parameters in the source_function and /// the arguments of the destination function. fn new(context: &'function mut InlineContext, source_function: &'function Function) -> Self { - // Create the block to return to but don't insert its parameters until we - // have the types of the actual return values later. Self { - return_destination: context.builder.insert_block(), context, source_function, blocks: HashMap::new(), @@ -265,20 +261,60 @@ impl<'function> PerFunctionContext<'function> { } /// Inline all reachable blocks within the source_function into the destination function. - fn inline_blocks(&mut self, ssa: &Ssa) { + fn inline_blocks(&mut self, ssa: &Ssa) -> Vec { let mut seen_blocks = HashSet::new(); let mut block_queue = vec![self.source_function.entry_block()]; + // This Vec will contain each block with a Return instruction along with the + // returned values of that block. + let mut function_returns = vec![]; + while let Some(source_block_id) = block_queue.pop() { let translated_block_id = self.translate_block(source_block_id, &mut block_queue); self.context.builder.switch_to_block(translated_block_id); seen_blocks.insert(source_block_id); self.inline_block(ssa, source_block_id); - self.handle_terminator_instruction(source_block_id, &mut block_queue); + + if let Some((block, values)) = + self.handle_terminator_instruction(source_block_id, &mut block_queue) + { + function_returns.push((block, values)); + } } - self.context.builder.switch_to_block(self.return_destination); + self.handle_function_returns(function_returns) + } + + /// Handle inlining a function's possibly multiple return instructions. + /// If there is only 1 return we can just continue inserting into that block. + /// If there are multiple, we'll need to create a join block to jump to with each value. + fn handle_function_returns( + &mut self, + mut returns: Vec<(BasicBlockId, Vec)>, + ) -> Vec { + // Clippy complains if this were written as an if statement + match returns.len() { + 1 => { + let (return_block, return_values) = returns.remove(0); + self.context.builder.switch_to_block(return_block); + return_values + } + n if n > 1 => { + // If there is more than 1 return instruction we'll need to create a single block we + // can return to and continue inserting in afterwards. + let return_block = self.context.builder.insert_block(); + + for (block, return_values) in returns { + self.context.builder.switch_to_block(block); + self.context.builder.terminate_with_jmp(return_block, return_values); + } + + self.context.builder.switch_to_block(return_block); + self.context.builder.block_parameters(return_block).to_vec() + } + _ => unreachable!("Inlined function had no return values"), + } } /// Inline each instruction in the given block into the function being inlined into. @@ -307,7 +343,7 @@ impl<'function> PerFunctionContext<'function> { let old_results = self.source_function.dfg.instruction_results(call_id); let arguments = vecmap(arguments, |arg| self.translate_value(*arg)); let new_results = self.context.inline_function(ssa, function, &arguments); - Self::insert_new_instruction_results(&mut self.values, old_results, new_results); + Self::insert_new_instruction_results(&mut self.values, old_results, &new_results); } /// Push the given instruction from the source_function into the current block of the @@ -340,16 +376,20 @@ impl<'function> PerFunctionContext<'function> { /// Handle the given terminator instruction from the given source function block. /// This will push any new blocks to the destination function as needed, add them /// to the block queue, and set the terminator instruction for the current block. + /// + /// If the terminator instruction was a Return, this will return the block this instruction + /// was in as well as the values that were returned. fn handle_terminator_instruction( &mut self, block_id: BasicBlockId, block_queue: &mut Vec, - ) { + ) -> Option<(BasicBlockId, Vec)> { match self.source_function.dfg[block_id].terminator() { Some(TerminatorInstruction::Jmp { destination, arguments }) => { let destination = self.translate_block(*destination, block_queue); let arguments = vecmap(arguments, |arg| self.translate_value(*arg)); self.context.builder.terminate_with_jmp(destination, arguments); + None } Some(TerminatorInstruction::JmpIf { condition, @@ -360,21 +400,15 @@ impl<'function> PerFunctionContext<'function> { let then_block = self.translate_block(*then_destination, block_queue); let else_block = self.translate_block(*else_destination, block_queue); self.context.builder.terminate_with_jmpif(condition, then_block, else_block); + None } Some(TerminatorInstruction::Return { return_values }) => { let return_values = vecmap(return_values, |value| self.translate_value(*value)); - if self.inlining_main { - self.context.builder.terminate_with_return(return_values); - } else { - for value in &return_values { - // Add the block parameters for the return block here since we don't do - // it when inserting the block in PerFunctionContext::new - let typ = self.context.builder.current_function.dfg.type_of_value(*value); - self.context.builder.add_block_parameter(self.return_destination, typ); - } - self.context.builder.terminate_with_jmp(self.return_destination, return_values); + self.context.builder.terminate_with_return(return_values.clone()); } + let block_id = self.translate_block(block_id, block_queue); + Some((block_id, return_values)) } None => unreachable!("Block has no terminator instruction"), } @@ -384,7 +418,7 @@ impl<'function> PerFunctionContext<'function> { #[cfg(test)] mod test { use crate::ssa_refactor::{ - ir::{map::Id, types::Type}, + ir::{instruction::BinaryOp, map::Id, types::Type}, ssa_builder::FunctionBuilder, }; @@ -418,4 +452,70 @@ mod test { let inlined = ssa.inline_functions(); assert_eq!(inlined.functions.len(), 1); } + + #[test] + fn complex_inlining() { + // This SSA is from issue #1327 which previously failed to inline properly + // + // fn main f0 { + // b0(v0: Field): + // v7 = call f2(f1) + // v13 = call f3(v7) + // v16 = call v13(v0) + // return v16 + // } + // fn square f1 { + // b0(v0: Field): + // v2 = mul v0, v0 + // return v2 + // } + // fn id1 f2 { + // b0(v0: function): + // return v0 + // } + // fn id2 f3 { + // b0(v0: function): + // return v0 + // } + let main_id = Id::test_new(0); + let square_id = Id::test_new(1); + let id1_id = Id::test_new(2); + let id2_id = Id::test_new(3); + + // Compiling main + let mut builder = FunctionBuilder::new("main".into(), main_id); + let main_v0 = builder.add_parameter(Type::field()); + + let main_f1 = builder.import_function(square_id); + let main_f2 = builder.import_function(id1_id); + let main_f3 = builder.import_function(id2_id); + + let main_v7 = builder.insert_call(main_f2, vec![main_f1], vec![Type::Function])[0]; + let main_v13 = builder.insert_call(main_f3, vec![main_v7], vec![Type::Function])[0]; + let main_v16 = builder.insert_call(main_v13, vec![main_v0], vec![Type::field()])[0]; + builder.terminate_with_return(vec![main_v16]); + + // Compiling square f1 + builder.new_function("square".into(), square_id); + let square_v0 = builder.add_parameter(Type::field()); + let square_v2 = builder.insert_binary(square_v0, BinaryOp::Mul, square_v0); + builder.terminate_with_return(vec![square_v2]); + + // Compiling id1 f2 + builder.new_function("id1".into(), id1_id); + let id1_v0 = builder.add_parameter(Type::Function); + builder.terminate_with_return(vec![id1_v0]); + + // Compiling id2 f3 + builder.new_function("id2".into(), id2_id); + let id2_v0 = builder.add_parameter(Type::Function); + builder.terminate_with_return(vec![id2_v0]); + + // Done, now we test that we can successfully inline all functions. + let ssa = builder.finish(); + assert_eq!(ssa.functions.len(), 4); + + let inlined = ssa.inline_functions(); + assert_eq!(inlined.functions.len(), 1); + } }