From 461084e7359c0f05bba651b87d39375d96e9f90c Mon Sep 17 00:00:00 2001 From: sirasistant Date: Thu, 11 Apr 2024 10:57:07 +0000 Subject: [PATCH 1/7] feat: Unroll loops iterativelly --- compiler/noirc_evaluator/src/ssa.rs | 27 +++++++++++++- .../noirc_evaluator/src/ssa/opt/unrolling.rs | 36 +++++++++---------- .../execution_success/slice_loop/Nargo.toml | 6 ++++ .../execution_success/slice_loop/Prover.toml | 11 ++++++ .../execution_success/slice_loop/src/main.nr | 26 ++++++++++++++ 5 files changed, 85 insertions(+), 21 deletions(-) create mode 100644 test_programs/execution_success/slice_loop/Nargo.toml create mode 100644 test_programs/execution_success/slice_loop/Prover.toml create mode 100644 test_programs/execution_success/slice_loop/src/main.nr diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index fac7a7c0829..ab0ee68bf60 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -55,7 +55,7 @@ pub(crate) fn optimize_into_acir( .run_pass(Ssa::mem2reg, "After Mem2Reg:") .run_pass(Ssa::as_slice_optimization, "After `as_slice` optimization") .try_run_pass(Ssa::evaluate_assert_constant, "After Assert Constant:")? - .try_run_pass(Ssa::unroll_loops, "After Unrolling:")? + .try_run_pass(unroll_all_acir_loops, "After Unrolling:")? .run_pass(Ssa::simplify_cfg, "After Simplifying:") .run_pass(Ssa::flatten_cfg, "After Flattening:") .run_pass(Ssa::remove_bit_shifts, "After Removing Bit Shifts:") @@ -75,6 +75,31 @@ pub(crate) fn optimize_into_acir( time("SSA to ACIR", print_timings, || ssa.into_acir(&brillig, abi_distinctness)) } +/// Loop unrolling can return errors, since ACIR functions need to be fully unrolled. +/// This meta-pass will keep trying to unroll loops and simplifying the SSA until no more errors are found. +fn unroll_all_acir_loops(mut ssa: Ssa) -> Result { + // Try to unroll loops first: + let mut unroll_errors; + (ssa, unroll_errors) = ssa.try_to_unroll_loops(); + + // Keep unrolling until no more errors are found + while !unroll_errors.is_empty() { + let prev_unroll_err_count = unroll_errors.len(); + + // Simplify the SSA before retrying + ssa = ssa.simplify_cfg(); + ssa = ssa.mem2reg(); + + // Unroll again + (ssa, unroll_errors) = ssa.try_to_unroll_loops(); + // If we didn't manage to unroll any more loops, exit + if unroll_errors.len() == prev_unroll_err_count { + return Err(unroll_errors[0].clone()); + } + } + Ok(ssa) +} + // Helper to time SSA passes fn time(name: &str, print_timings: bool, f: impl FnOnce() -> T) -> T { let start_time = chrono::Utc::now().time(); diff --git a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs index 8110e3469f1..0950bb84741 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs @@ -34,10 +34,12 @@ use crate::{ use fxhash::FxHashMap as HashMap; impl Ssa { - /// Unroll all loops in each SSA function. + /// Tries to unroll all loops in each SSA function. /// If any loop cannot be unrolled, it is left as-is or in a partially unrolled state. + /// Returns the ssa along with all unrolling errors encountered #[tracing::instrument(level = "trace", skip(self))] - pub(crate) fn unroll_loops(mut self) -> Result { + pub(crate) fn try_to_unroll_loops(mut self) -> (Ssa, Vec) { + let mut errors = vec![]; for function in self.functions.values_mut() { // Loop unrolling in brillig can lead to a code explosion currently. This can // also be true for ACIR, but we have no alternative to unrolling in ACIR. @@ -46,12 +48,9 @@ impl Ssa { continue; } - // This check is always true with the addition of the above guard, but I'm - // keeping it in case the guard on brillig functions is ever removed. - let abort_on_error = matches!(function.runtime(), RuntimeType::Acir(_)); - find_all_loops(function).unroll_each_loop(function, abort_on_error)?; + errors.extend(find_all_loops(function).unroll_each_loop(function)); } - Ok(self) + (self, errors) } } @@ -115,34 +114,29 @@ fn find_all_loops(function: &Function) -> Loops { impl Loops { /// Unroll all loops within a given function. /// Any loops which fail to be unrolled (due to using non-constant indices) will be unmodified. - fn unroll_each_loop( - mut self, - function: &mut Function, - abort_on_error: bool, - ) -> Result<(), RuntimeError> { + fn unroll_each_loop(mut self, function: &mut Function) -> Vec { + let mut unroll_errors = vec![]; while let Some(next_loop) = self.yet_to_unroll.pop() { // If we've previously modified a block in this loop we need to refresh the context. // This happens any time we have nested loops. if next_loop.blocks.iter().any(|block| self.modified_blocks.contains(block)) { let mut new_context = find_all_loops(function); new_context.failed_to_unroll = self.failed_to_unroll; - return new_context.unroll_each_loop(function, abort_on_error); + return new_context.unroll_each_loop(function); } // Don't try to unroll the loop again if it is known to fail if !self.failed_to_unroll.contains(&next_loop.header) { match unroll_loop(function, &self.cfg, &next_loop) { Ok(_) => self.modified_blocks.extend(next_loop.blocks), - Err(call_stack) if abort_on_error => { - return Err(RuntimeError::UnknownLoopBound { call_stack }); - } - Err(_) => { + Err(call_stack) => { self.failed_to_unroll.insert(next_loop.header); + unroll_errors.push(RuntimeError::UnknownLoopBound { call_stack }); } } } } - Ok(()) + unroll_errors } } @@ -585,7 +579,8 @@ mod tests { // } // The final block count is not 1 because unrolling creates some unnecessary jmps. // If a simplify cfg pass is ran afterward, the expected block count will be 1. - let ssa = ssa.unroll_loops().expect("All loops should be unrolled"); + let (ssa, errors) = ssa.try_to_unroll_loops(); + assert_eq!(errors.len(), 0, "All loops should be unrolled"); assert_eq!(ssa.main().reachable_blocks().len(), 5); } @@ -634,6 +629,7 @@ mod tests { assert_eq!(ssa.main().reachable_blocks().len(), 4); // Expected that we failed to unroll the loop - assert!(ssa.unroll_loops().is_err()); + let (_, errors) = ssa.try_to_unroll_loops(); + assert_eq!(errors.len(), 1, "Expected to fail to unroll loop"); } } diff --git a/test_programs/execution_success/slice_loop/Nargo.toml b/test_programs/execution_success/slice_loop/Nargo.toml new file mode 100644 index 00000000000..09ad90c4187 --- /dev/null +++ b/test_programs/execution_success/slice_loop/Nargo.toml @@ -0,0 +1,6 @@ +[package] +name = "slice_loop" +type = "bin" +authors = [""] + +[dependencies] diff --git a/test_programs/execution_success/slice_loop/Prover.toml b/test_programs/execution_success/slice_loop/Prover.toml new file mode 100644 index 00000000000..089a1764b54 --- /dev/null +++ b/test_programs/execution_success/slice_loop/Prover.toml @@ -0,0 +1,11 @@ +[[points]] +x = "1" +y = "2" + +[[points]] +x = "3" +y = "4" + +[[points]] +x = "5" +y = "6" diff --git a/test_programs/execution_success/slice_loop/src/main.nr b/test_programs/execution_success/slice_loop/src/main.nr new file mode 100644 index 00000000000..b438e5ab95b --- /dev/null +++ b/test_programs/execution_success/slice_loop/src/main.nr @@ -0,0 +1,26 @@ +struct Point { + x: Field, + y: Field, +} + +impl Point { + fn serialize(self) -> [Field; 2] { + [self.x, self.y] + } +} + +fn sum(values: [Field]) -> Field { + let mut sum = 0; + for value in values { + sum = sum + value; + } + sum +} + +fn main(points: [Point; 3]) { + let mut serialized_points = &[]; + for point in points { + serialized_points = serialized_points.append(point.serialize().as_slice()); + } + assert_eq(sum(serialized_points), 21); +} From 0827c5c11df78410215473077dc50166a7aa07b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Rodr=C3=ADguez?= Date: Thu, 11 Apr 2024 16:28:49 +0200 Subject: [PATCH 2/7] Update compiler/noirc_evaluator/src/ssa.rs Co-authored-by: jfecher --- compiler/noirc_evaluator/src/ssa.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index ab0ee68bf60..45a027d6372 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -94,7 +94,7 @@ fn unroll_all_acir_loops(mut ssa: Ssa) -> Result { (ssa, unroll_errors) = ssa.try_to_unroll_loops(); // If we didn't manage to unroll any more loops, exit if unroll_errors.len() == prev_unroll_err_count { - return Err(unroll_errors[0].clone()); + return Err(unroll_errors.swap_remove(0)); } } Ok(ssa) From 7c39d023fa242031287d74803de034b268cff3ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Rodr=C3=ADguez?= Date: Thu, 11 Apr 2024 16:29:55 +0200 Subject: [PATCH 3/7] Update compiler/noirc_evaluator/src/ssa.rs Co-authored-by: jfecher --- compiler/noirc_evaluator/src/ssa.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index 45a027d6372..26021a97916 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -93,7 +93,7 @@ fn unroll_all_acir_loops(mut ssa: Ssa) -> Result { // Unroll again (ssa, unroll_errors) = ssa.try_to_unroll_loops(); // If we didn't manage to unroll any more loops, exit - if unroll_errors.len() == prev_unroll_err_count { + if unroll_errors.len() >= prev_unroll_err_count { return Err(unroll_errors.swap_remove(0)); } } From 57d4be9f43712ee090a3e204988296613fd53782 Mon Sep 17 00:00:00 2001 From: sirasistant Date: Thu, 11 Apr 2024 14:37:41 +0000 Subject: [PATCH 4/7] refactor: Address PR comments --- compiler/noirc_evaluator/src/ssa.rs | 27 +------------------ .../noirc_evaluator/src/ssa/opt/unrolling.rs | 25 +++++++++++++++++ 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index 26021a97916..ce4a9bbe9f1 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -55,7 +55,7 @@ pub(crate) fn optimize_into_acir( .run_pass(Ssa::mem2reg, "After Mem2Reg:") .run_pass(Ssa::as_slice_optimization, "After `as_slice` optimization") .try_run_pass(Ssa::evaluate_assert_constant, "After Assert Constant:")? - .try_run_pass(unroll_all_acir_loops, "After Unrolling:")? + .try_run_pass(Ssa::unroll_loops_iteratively, "After Unrolling:")? .run_pass(Ssa::simplify_cfg, "After Simplifying:") .run_pass(Ssa::flatten_cfg, "After Flattening:") .run_pass(Ssa::remove_bit_shifts, "After Removing Bit Shifts:") @@ -75,31 +75,6 @@ pub(crate) fn optimize_into_acir( time("SSA to ACIR", print_timings, || ssa.into_acir(&brillig, abi_distinctness)) } -/// Loop unrolling can return errors, since ACIR functions need to be fully unrolled. -/// This meta-pass will keep trying to unroll loops and simplifying the SSA until no more errors are found. -fn unroll_all_acir_loops(mut ssa: Ssa) -> Result { - // Try to unroll loops first: - let mut unroll_errors; - (ssa, unroll_errors) = ssa.try_to_unroll_loops(); - - // Keep unrolling until no more errors are found - while !unroll_errors.is_empty() { - let prev_unroll_err_count = unroll_errors.len(); - - // Simplify the SSA before retrying - ssa = ssa.simplify_cfg(); - ssa = ssa.mem2reg(); - - // Unroll again - (ssa, unroll_errors) = ssa.try_to_unroll_loops(); - // If we didn't manage to unroll any more loops, exit - if unroll_errors.len() >= prev_unroll_err_count { - return Err(unroll_errors.swap_remove(0)); - } - } - Ok(ssa) -} - // Helper to time SSA passes fn time(name: &str, print_timings: bool, f: impl FnOnce() -> T) -> T { let start_time = chrono::Utc::now().time(); diff --git a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs index 0950bb84741..d4a930bb0eb 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs @@ -34,6 +34,31 @@ use crate::{ use fxhash::FxHashMap as HashMap; impl Ssa { + /// Loop unrolling can return errors, since ACIR functions need to be fully unrolled. + /// This meta-pass will keep trying to unroll loops and simplifying the SSA until no more errors are found. + pub(crate) fn unroll_loops_iteratively(mut ssa: Ssa) -> Result { + // Try to unroll loops first: + let mut unroll_errors; + (ssa, unroll_errors) = ssa.try_to_unroll_loops(); + + // Keep unrolling until no more errors are found + while !unroll_errors.is_empty() { + let prev_unroll_err_count = unroll_errors.len(); + + // Simplify the SSA before retrying + ssa = ssa.simplify_cfg(); + ssa = ssa.mem2reg(); + + // Unroll again + (ssa, unroll_errors) = ssa.try_to_unroll_loops(); + // If we didn't manage to unroll any more loops, exit + if unroll_errors.len() >= prev_unroll_err_count { + return Err(unroll_errors.swap_remove(0)); + } + } + Ok(ssa) + } + /// Tries to unroll all loops in each SSA function. /// If any loop cannot be unrolled, it is left as-is or in a partially unrolled state. /// Returns the ssa along with all unrolling errors encountered From ccdce8530ede0b8b47be9fc5ac143675c5e18b9e Mon Sep 17 00:00:00 2001 From: sirasistant Date: Thu, 11 Apr 2024 14:50:52 +0000 Subject: [PATCH 5/7] fix: also handle compile-time if statements --- compiler/noirc_evaluator/src/ssa/opt/unrolling.rs | 4 +++- test_programs/execution_success/slice_loop/src/main.nr | 6 ++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs index d4a930bb0eb..c48a18e8227 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs @@ -46,8 +46,10 @@ impl Ssa { let prev_unroll_err_count = unroll_errors.len(); // Simplify the SSA before retrying - ssa = ssa.simplify_cfg(); + + // Do a mem2reg after the last unroll to aid simplify_cfg ssa = ssa.mem2reg(); + ssa = ssa.simplify_cfg(); // Unroll again (ssa, unroll_errors) = ssa.try_to_unroll_loops(); diff --git a/test_programs/execution_success/slice_loop/src/main.nr b/test_programs/execution_success/slice_loop/src/main.nr index b438e5ab95b..0a971cac45f 100644 --- a/test_programs/execution_success/slice_loop/src/main.nr +++ b/test_programs/execution_success/slice_loop/src/main.nr @@ -22,5 +22,11 @@ fn main(points: [Point; 3]) { for point in points { serialized_points = serialized_points.append(point.serialize().as_slice()); } + // Do a compile-time check that needs the previous loop to be unrolled + if points.len() > 5 { + let empty_point = Point { x: 0, y: 0 }; + serialized_points = serialized_points.append(empty_point.serialize().as_slice()); + } + // Do a sum that needs both the previous loop and the previous if to have been simplified assert_eq(sum(serialized_points), 21); } From fc0ac9cde48d93da1d04e9d03f43f94e753cea4a Mon Sep 17 00:00:00 2001 From: sirasistant Date: Thu, 11 Apr 2024 14:55:11 +0000 Subject: [PATCH 6/7] test: fix test --- test_programs/execution_success/slice_loop/src/main.nr | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test_programs/execution_success/slice_loop/src/main.nr b/test_programs/execution_success/slice_loop/src/main.nr index 0a971cac45f..4ff3e865b1f 100644 --- a/test_programs/execution_success/slice_loop/src/main.nr +++ b/test_programs/execution_success/slice_loop/src/main.nr @@ -23,7 +23,7 @@ fn main(points: [Point; 3]) { serialized_points = serialized_points.append(point.serialize().as_slice()); } // Do a compile-time check that needs the previous loop to be unrolled - if points.len() > 5 { + if serialized_points.len() > 5 { let empty_point = Point { x: 0, y: 0 }; serialized_points = serialized_points.append(empty_point.serialize().as_slice()); } From 355168b1bcddd81bdf2c26a47c7794a3cff542a3 Mon Sep 17 00:00:00 2001 From: sirasistant Date: Thu, 11 Apr 2024 14:58:27 +0000 Subject: [PATCH 7/7] fix: missing mem2reg --- compiler/noirc_evaluator/src/ssa/opt/unrolling.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs index c48a18e8227..c6bf7923fa8 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs @@ -50,6 +50,8 @@ impl Ssa { // Do a mem2reg after the last unroll to aid simplify_cfg ssa = ssa.mem2reg(); ssa = ssa.simplify_cfg(); + // Do another mem2reg after simplify_cfg to aid the next unroll + ssa = ssa.mem2reg(); // Unroll again (ssa, unroll_errors) = ssa.try_to_unroll_loops();