From 13e62bda371b90371ac8fc9d7683425d3da563b9 Mon Sep 17 00:00:00 2001 From: Tom French Date: Wed, 13 Nov 2024 17:53:54 +0000 Subject: [PATCH] feat: avoid unnecessary ssa passes while loop unrolling --- .../noirc_evaluator/src/ssa/opt/unrolling.rs | 88 ++++++++++--------- 1 file changed, 48 insertions(+), 40 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs index 661109c1786..efd83775175 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs @@ -40,53 +40,46 @@ use fxhash::FxHashMap as HashMap; impl Ssa { /// Loop unrolling can return errors, since ACIR functions need to be fully unrolled. /// This meta-pass will keep trying to unroll loops and simplifying the SSA until no more errors are found. + #[tracing::instrument(level = "trace", skip(ssa))] pub(crate) fn unroll_loops_iteratively(mut ssa: Ssa) -> Result { - // Try to unroll loops first: - let mut unroll_errors; - (ssa, unroll_errors) = ssa.try_to_unroll_loops(); - - // Keep unrolling until no more errors are found - while !unroll_errors.is_empty() { - let prev_unroll_err_count = unroll_errors.len(); - - // Simplify the SSA before retrying - - // Do a mem2reg after the last unroll to aid simplify_cfg - ssa = ssa.mem2reg(); - ssa = ssa.simplify_cfg(); - // Do another mem2reg after simplify_cfg to aid the next unroll - ssa = ssa.mem2reg(); - - // Unroll again - (ssa, unroll_errors) = ssa.try_to_unroll_loops(); - // If we didn't manage to unroll any more loops, exit - if unroll_errors.len() >= prev_unroll_err_count { - return Err(unroll_errors.swap_remove(0)); + for (_, function) in ssa.functions.iter_mut() { + // Try to unroll loops first: + let unroll_errors = function.try_to_unroll_loops(); + + // Keep unrolling until no more errors are found + while !unroll_errors.is_empty() { + let prev_unroll_err_count = unroll_errors.len(); + + // Simplify the SSA before retrying + + // Do a mem2reg after the last unroll to aid simplify_cfg + function.mem2reg(); + function.simplify_function(); + // Do another mem2reg after simplify_cfg to aid the next unroll + function.mem2reg(); + + // Unroll again + let mut unroll_errors = function.try_to_unroll_loops(); + // If we didn't manage to unroll any more loops, exit + if unroll_errors.len() >= prev_unroll_err_count { + return Err(unroll_errors.swap_remove(0)); + } } } - Ok(ssa) - } - /// Tries to unroll all loops in each SSA function. - /// If any loop cannot be unrolled, it is left as-is or in a partially unrolled state. - /// Returns the ssa along with all unrolling errors encountered - #[tracing::instrument(level = "trace", skip(self))] - pub(crate) fn try_to_unroll_loops(mut self) -> (Ssa, Vec) { - let mut errors = vec![]; - for function in self.functions.values_mut() { - function.try_to_unroll_loops(&mut errors); - } - (self, errors) + Ok(ssa) } } impl Function { - pub(crate) fn try_to_unroll_loops(&mut self, errors: &mut Vec) { + fn try_to_unroll_loops(&mut self) -> Vec { // Loop unrolling in brillig can lead to a code explosion currently. This can // also be true for ACIR, but we have no alternative to unrolling in ACIR. // Brillig also generally prefers smaller code rather than faster code. if !matches!(self.runtime(), RuntimeType::Brillig(_)) { - errors.extend(find_all_loops(self).unroll_each_loop(self)); + find_all_loops(self).unroll_each_loop(self) + } else { + Vec::new() } } } @@ -507,11 +500,26 @@ impl<'f> LoopIteration<'f> { #[cfg(test)] mod tests { - use crate::ssa::{ - function_builder::FunctionBuilder, - ir::{instruction::BinaryOp, map::Id, types::Type}, + use crate::{ + errors::RuntimeError, + ssa::{ + function_builder::FunctionBuilder, + ir::{instruction::BinaryOp, map::Id, types::Type}, + }, }; + use super::Ssa; + + /// Tries to unroll all loops in each SSA function. + /// If any loop cannot be unrolled, it is left as-is or in a partially unrolled state. + fn try_to_unroll_loops(mut ssa: Ssa) -> (Ssa, Vec) { + let mut errors = vec![]; + for function in ssa.functions.values_mut() { + function.try_to_unroll_loops(&mut errors); + } + (ssa, errors) + } + #[test] fn unroll_nested_loops() { // fn main() { @@ -630,7 +638,7 @@ mod tests { // } // The final block count is not 1 because unrolling creates some unnecessary jmps. // If a simplify cfg pass is ran afterward, the expected block count will be 1. - let (ssa, errors) = ssa.try_to_unroll_loops(); + let (ssa, errors) = try_to_unroll_loops(ssa); assert_eq!(errors.len(), 0, "All loops should be unrolled"); assert_eq!(ssa.main().reachable_blocks().len(), 5); } @@ -680,7 +688,7 @@ mod tests { assert_eq!(ssa.main().reachable_blocks().len(), 4); // Expected that we failed to unroll the loop - let (_, errors) = ssa.try_to_unroll_loops(); + let (_, errors) = try_to_unroll_loops(ssa); assert_eq!(errors.len(), 1, "Expected to fail to unroll loop"); } }