From 4ff308128755c95b4d461bbcb7e3a49f16145585 Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Fri, 29 Nov 2024 21:27:10 +0100 Subject: [PATCH] feat(ssa): Option to set the maximum acceptable Brillig bytecode increase in unrolling (#6641) --- Cargo.lock | 1 + compiler/noirc_driver/src/lib.rs | 11 +- compiler/noirc_evaluator/Cargo.toml | 1 + compiler/noirc_evaluator/src/brillig/mod.rs | 8 +- compiler/noirc_evaluator/src/ssa.rs | 19 ++- .../noirc_evaluator/src/ssa/ir/function.rs | 6 + .../noirc_evaluator/src/ssa/opt/unrolling.rs | 137 +++++++++++++++--- tooling/nargo_cli/build.rs | 5 + 8 files changed, 155 insertions(+), 33 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 94a84b89d05..af91bafef52 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3151,6 +3151,7 @@ dependencies = [ "serde_json", "serde_with", "similar-asserts", + "test-case", "thiserror", "tracing", ] diff --git a/compiler/noirc_driver/src/lib.rs b/compiler/noirc_driver/src/lib.rs index 72ea464805f..a7cd9ff90ac 100644 --- a/compiler/noirc_driver/src/lib.rs +++ b/compiler/noirc_driver/src/lib.rs @@ -126,11 +126,19 @@ pub struct CompileOptions { #[arg(long)] pub skip_underconstrained_check: bool, - /// Setting to decide on an inlining strategy for brillig functions. + /// Setting to decide on an inlining strategy for Brillig functions. /// A more aggressive inliner should generate larger programs but more optimized /// A less aggressive inliner should generate smaller programs #[arg(long, hide = true, allow_hyphen_values = true, default_value_t = i64::MAX)] pub inliner_aggressiveness: i64, + + /// Setting the maximum acceptable increase in Brillig bytecode size due to + /// unrolling small loops. When left empty, any change is accepted as long + /// as it required fewer SSA instructions. + /// A higher value results in fewer jumps but a larger program. + /// A lower value keeps the original program if it was smaller, even if it has more jumps. + #[arg(long, hide = true, allow_hyphen_values = true)] + pub max_bytecode_increase_percent: Option, } pub fn parse_expression_width(input: &str) -> Result { @@ -589,6 +597,7 @@ pub fn compile_no_check( emit_ssa: if options.emit_ssa { Some(context.package_build_path.clone()) } else { None }, skip_underconstrained_check: options.skip_underconstrained_check, inliner_aggressiveness: options.inliner_aggressiveness, + max_bytecode_increase_percent: options.max_bytecode_increase_percent, }; let SsaProgramArtifact { program, debug, warnings, names, brillig_names, error_types, .. } = diff --git a/compiler/noirc_evaluator/Cargo.toml b/compiler/noirc_evaluator/Cargo.toml index e25b5bf855a..bb8c62cfd95 100644 --- a/compiler/noirc_evaluator/Cargo.toml +++ b/compiler/noirc_evaluator/Cargo.toml @@ -33,6 +33,7 @@ cfg-if.workspace = true proptest.workspace = true similar-asserts.workspace = true num-traits.workspace = true +test-case.workspace = true [features] bn254 = ["noirc_frontend/bn254"] diff --git a/compiler/noirc_evaluator/src/brillig/mod.rs b/compiler/noirc_evaluator/src/brillig/mod.rs index 1b61ae1a864..cb8c35cd8e0 100644 --- a/compiler/noirc_evaluator/src/brillig/mod.rs +++ b/compiler/noirc_evaluator/src/brillig/mod.rs @@ -12,7 +12,7 @@ use self::{ }, }; use crate::ssa::{ - ir::function::{Function, FunctionId, RuntimeType}, + ir::function::{Function, FunctionId}, ssa_gen::Ssa, }; use fxhash::FxHashMap as HashMap; @@ -59,7 +59,7 @@ impl std::ops::Index for Brillig { } impl Ssa { - /// Compile to brillig brillig functions and ACIR functions reachable from them + /// Compile Brillig functions and ACIR functions reachable from them #[tracing::instrument(level = "trace", skip_all)] pub(crate) fn to_brillig(&self, enable_debug_trace: bool) -> Brillig { // Collect all the function ids that are reachable from brillig @@ -67,9 +67,7 @@ impl Ssa { let brillig_reachable_function_ids = self .functions .iter() - .filter_map(|(id, func)| { - matches!(func.runtime(), RuntimeType::Brillig(_)).then_some(*id) - }) + .filter_map(|(id, func)| func.runtime().is_brillig().then_some(*id)) .collect::>(); let mut brillig = Brillig::default(); diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index 97c1760d87c..80514b2f2cf 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -67,6 +67,11 @@ pub struct SsaEvaluatorOptions { /// The higher the value, the more inlined brillig functions will be. pub inliner_aggressiveness: i64, + + /// Maximum accepted percentage increase in the Brillig bytecode size after unrolling loops. + /// When `None` the size increase check is skipped altogether and any decrease in the SSA + /// instruction count is accepted. + pub max_bytecode_increase_percent: Option, } pub(crate) struct ArtifactsAndWarnings(Artifacts, Vec); @@ -104,7 +109,10 @@ pub(crate) fn optimize_into_acir( "After `static_assert` and `assert_constant`:", )? .run_pass(Ssa::loop_invariant_code_motion, "After Loop Invariant Code Motion:") - .try_run_pass(Ssa::unroll_loops_iteratively, "After Unrolling:")? + .try_run_pass( + |ssa| ssa.unroll_loops_iteratively(options.max_bytecode_increase_percent), + "After Unrolling:", + )? .run_pass(Ssa::simplify_cfg, "After Simplifying (2nd):") .run_pass(Ssa::flatten_cfg, "After Flattening:") .run_pass(Ssa::remove_bit_shifts, "After Removing Bit Shifts:") @@ -450,11 +458,10 @@ impl SsaBuilder { } /// The same as `run_pass` but for passes that may fail - fn try_run_pass( - mut self, - pass: fn(Ssa) -> Result, - msg: &str, - ) -> Result { + fn try_run_pass(mut self, pass: F, msg: &str) -> Result + where + F: FnOnce(Ssa) -> Result, + { self.ssa = time(msg, self.print_codegen_timings, || pass(self.ssa))?; Ok(self.print(msg)) } diff --git a/compiler/noirc_evaluator/src/ssa/ir/function.rs b/compiler/noirc_evaluator/src/ssa/ir/function.rs index b1233e3063e..6413107c04a 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/function.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/function.rs @@ -197,6 +197,12 @@ impl Function { } } +impl Clone for Function { + fn clone(&self) -> Self { + Function::clone_with_id(self.id(), self) + } +} + impl std::fmt::Display for RuntimeType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { diff --git a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs index 777c16dacd1..5883ce25936 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs @@ -19,8 +19,10 @@ //! When unrolling ACIR code, we remove reference count instructions because they are //! only used by Brillig bytecode. use acvm::{acir::AcirField, FieldElement}; +use im::HashSet; use crate::{ + brillig::brillig_gen::convert_ssa_function, errors::RuntimeError, ssa::{ ir::{ @@ -37,38 +39,60 @@ use crate::{ ssa_gen::Ssa, }, }; -use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; +use fxhash::FxHashMap as HashMap; impl Ssa { /// Loop unrolling can return errors, since ACIR functions need to be fully unrolled. /// This meta-pass will keep trying to unroll loops and simplifying the SSA until no more errors are found. - #[tracing::instrument(level = "trace", skip(ssa))] - pub(crate) fn unroll_loops_iteratively(mut ssa: Ssa) -> Result { - for (_, function) in ssa.functions.iter_mut() { + /// + /// The `max_bytecode_incr_pct`, when given, is used to limit the growth of the Brillig bytecode size + /// after unrolling small loops to some percentage of the original loop. For example a value of 150 would + /// mean the new loop can be 150% (ie. 2.5 times) larger than the original loop. It will still contain + /// fewer SSA instructions, but that can still result in more Brillig opcodes. + #[tracing::instrument(level = "trace", skip(self))] + pub(crate) fn unroll_loops_iteratively( + mut self: Ssa, + max_bytecode_increase_percent: Option, + ) -> Result { + for (_, function) in self.functions.iter_mut() { + // Take a snapshot of the function to compare byte size increase, + // but only if the setting indicates we have to, otherwise skip it. + let orig_func_and_max_incr_pct = max_bytecode_increase_percent + .filter(|_| function.runtime().is_brillig()) + .map(|max_incr_pct| (function.clone(), max_incr_pct)); + // Try to unroll loops first: - let mut unroll_errors = function.try_unroll_loops(); + let (mut has_unrolled, mut unroll_errors) = function.try_unroll_loops(); // Keep unrolling until no more errors are found while !unroll_errors.is_empty() { let prev_unroll_err_count = unroll_errors.len(); // Simplify the SSA before retrying - - // Do a mem2reg after the last unroll to aid simplify_cfg - function.mem2reg(); - function.simplify_function(); - // Do another mem2reg after simplify_cfg to aid the next unroll - function.mem2reg(); + simplify_between_unrolls(function); // Unroll again - unroll_errors = function.try_unroll_loops(); + let (new_unrolled, new_errors) = function.try_unroll_loops(); + unroll_errors = new_errors; + has_unrolled |= new_unrolled; + // If we didn't manage to unroll any more loops, exit if unroll_errors.len() >= prev_unroll_err_count { return Err(unroll_errors.swap_remove(0)); } } + + if has_unrolled { + if let Some((orig_function, max_incr_pct)) = orig_func_and_max_incr_pct { + let new_size = brillig_bytecode_size(function); + let orig_size = brillig_bytecode_size(&orig_function); + if !is_new_size_ok(orig_size, new_size, max_incr_pct) { + *function = orig_function; + } + } + } } - Ok(ssa) + Ok(self) } } @@ -77,7 +101,7 @@ impl Function { // This can also be true for ACIR, but we have no alternative to unrolling in ACIR. // Brillig also generally prefers smaller code rather than faster code, // so we only attempt to unroll small loops, which we decide on a case-by-case basis. - fn try_unroll_loops(&mut self) -> Vec { + fn try_unroll_loops(&mut self) -> (bool, Vec) { Loops::find_all(self).unroll_each(self) } } @@ -170,8 +194,10 @@ impl Loops { /// Unroll all loops within a given function. /// Any loops which fail to be unrolled (due to using non-constant indices) will be unmodified. - fn unroll_each(mut self, function: &mut Function) -> Vec { + /// Returns whether any blocks have been modified + fn unroll_each(mut self, function: &mut Function) -> (bool, Vec) { let mut unroll_errors = vec![]; + let mut has_unrolled = false; while let Some(next_loop) = self.yet_to_unroll.pop() { if function.runtime().is_brillig() && !next_loop.is_small_loop(function, &self.cfg) { continue; @@ -181,13 +207,17 @@ impl Loops { if next_loop.blocks.iter().any(|block| self.modified_blocks.contains(block)) { let mut new_loops = Self::find_all(function); new_loops.failed_to_unroll = self.failed_to_unroll; - return unroll_errors.into_iter().chain(new_loops.unroll_each(function)).collect(); + let (new_unrolled, new_errors) = new_loops.unroll_each(function); + return (has_unrolled || new_unrolled, [unroll_errors, new_errors].concat()); } // Don't try to unroll the loop again if it is known to fail if !self.failed_to_unroll.contains(&next_loop.header) { match next_loop.unroll(function, &self.cfg) { - Ok(_) => self.modified_blocks.extend(next_loop.blocks), + Ok(_) => { + has_unrolled = true; + self.modified_blocks.extend(next_loop.blocks); + } Err(call_stack) => { self.failed_to_unroll.insert(next_loop.header); unroll_errors.push(RuntimeError::UnknownLoopBound { call_stack }); @@ -195,7 +225,7 @@ impl Loops { } } } - unroll_errors + (has_unrolled, unroll_errors) } } @@ -947,21 +977,59 @@ impl<'f> LoopIteration<'f> { } } +/// Unrolling leaves some duplicate instructions which can potentially be removed. +fn simplify_between_unrolls(function: &mut Function) { + // Do a mem2reg after the last unroll to aid simplify_cfg + function.mem2reg(); + function.simplify_function(); + // Do another mem2reg after simplify_cfg to aid the next unroll + function.mem2reg(); +} + +/// Convert the function to Brillig bytecode and return the resulting size. +fn brillig_bytecode_size(function: &Function) -> usize { + // We need to do some SSA passes in order for the conversion to be able to go ahead, + // otherwise we can hit `unreachable!()` instructions in `convert_ssa_instruction`. + // Creating a clone so as not to modify the originals. + let mut temp = function.clone(); + + // Might as well give it the best chance. + simplify_between_unrolls(&mut temp); + + // This is to try to prevent hitting ICE. + temp.dead_instruction_elimination(false); + + convert_ssa_function(&temp, false).byte_code.len() +} + +/// Decide if the new bytecode size is acceptable, compared to the original. +/// +/// The maximum increase can be expressed as a negative value if we demand a decrease. +/// (Values -100 and under mean the new size should be 0). +fn is_new_size_ok(orig_size: usize, new_size: usize, max_incr_pct: i32) -> bool { + let max_size_pct = 100i32.saturating_add(max_incr_pct).max(0) as usize; + let max_size = orig_size.saturating_mul(max_size_pct); + new_size.saturating_mul(100) <= max_size +} + #[cfg(test)] mod tests { use acvm::FieldElement; + use test_case::test_case; use crate::errors::RuntimeError; use crate::ssa::{ir::value::ValueId, opt::assert_normalized_ssa_equals, Ssa}; - use super::{BoilerplateStats, Loops}; + use super::{is_new_size_ok, BoilerplateStats, Loops}; - /// Tries to unroll all loops in each SSA function. + /// Tries to unroll all loops in each SSA function once, calling the `Function` directly, + /// bypassing the iterative loop done by the SSA which does further optimisations. + /// /// If any loop cannot be unrolled, it is left as-is or in a partially unrolled state. fn try_unroll_loops(mut ssa: Ssa) -> (Ssa, Vec) { let mut errors = vec![]; for function in ssa.functions.values_mut() { - errors.extend(function.try_unroll_loops()); + errors.extend(function.try_unroll_loops().1); } (ssa, errors) } @@ -1221,9 +1289,26 @@ mod tests { let (ssa, errors) = try_unroll_loops(ssa); assert_eq!(errors.len(), 0, "Unroll should have no errors"); + // Check that it's still the original assert_normalized_ssa_equals(ssa, parse_ssa().to_string().as_str()); } + #[test] + fn test_brillig_unroll_iteratively_respects_max_increase() { + let ssa = brillig_unroll_test_case(); + let ssa = ssa.unroll_loops_iteratively(Some(-90)).unwrap(); + // Check that it's still the original + assert_normalized_ssa_equals(ssa, brillig_unroll_test_case().to_string().as_str()); + } + + #[test] + fn test_brillig_unroll_iteratively_with_large_max_increase() { + let ssa = brillig_unroll_test_case(); + let ssa = ssa.unroll_loops_iteratively(Some(50)).unwrap(); + // Check that it did the unroll + assert_eq!(ssa.main().reachable_blocks().len(), 2, "The loop should be unrolled"); + } + /// Test that `break` and `continue` stop unrolling without any panic. #[test] fn test_brillig_unroll_break_and_continue() { @@ -1377,4 +1462,14 @@ mod tests { let loop0 = loops.yet_to_unroll.pop().expect("there should be a loop"); loop0.boilerplate_stats(function, &loops.cfg).expect("there should be stats") } + + #[test_case(1000, 700, 50, true; "size decreased")] + #[test_case(1000, 1500, 50, true; "size increased just by the max")] + #[test_case(1000, 1501, 50, false; "size increased over the max")] + #[test_case(1000, 700, -50, false; "size decreased but not enough")] + #[test_case(1000, 250, -50, true; "size decreased over expectations")] + #[test_case(1000, 250, -1250, false; "demanding more than minus 100 is handled")] + fn test_is_new_size_ok(old: usize, new: usize, max: i32, ok: bool) { + assert_eq!(is_new_size_ok(old, new, max), ok); + } } diff --git a/tooling/nargo_cli/build.rs b/tooling/nargo_cli/build.rs index 740e5ed2052..f0334eaf713 100644 --- a/tooling/nargo_cli/build.rs +++ b/tooling/nargo_cli/build.rs @@ -213,8 +213,13 @@ fn test_{test_name}(force_brillig: ForceBrillig, inliner_aggressiveness: Inliner nargo.arg("--program-dir").arg(test_program_dir); nargo.arg("{test_command}").arg("--force"); nargo.arg("--inliner-aggressiveness").arg(inliner_aggressiveness.0.to_string()); + if force_brillig.0 {{ nargo.arg("--force-brillig"); + + // Set the maximum increase so that part of the optimization is exercised (it might fail). + nargo.arg("--max-bytecode-increase-percent"); + nargo.arg("50"); }} {test_content}