Skip to content

Commit

Permalink
feat(ssa): Option to set the maximum acceptable Brillig bytecode incr…
Browse files Browse the repository at this point in the history
…ease in unrolling (#6641)
  • Loading branch information
aakoshh authored Nov 29, 2024
1 parent 594aad2 commit 4ff3081
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 33 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 10 additions & 1 deletion compiler/noirc_driver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,19 @@ pub struct CompileOptions {
#[arg(long)]
pub skip_underconstrained_check: bool,

/// Setting to decide on an inlining strategy for brillig functions.
/// Setting to decide on an inlining strategy for Brillig functions.
/// A more aggressive inliner should generate larger but more optimized programs.
/// A less aggressive inliner should generate smaller programs.
#[arg(long, hide = true, allow_hyphen_values = true, default_value_t = i64::MAX)]
pub inliner_aggressiveness: i64,

/// Setting the maximum acceptable increase in Brillig bytecode size due to
/// unrolling small loops. When left empty, any change is accepted as long
/// as it requires fewer SSA instructions.
/// A higher value results in fewer jumps but a larger program.
/// A lower value keeps the original program if it was smaller, even if it has more jumps.
#[arg(long, hide = true, allow_hyphen_values = true)]
pub max_bytecode_increase_percent: Option<i32>,
}

pub fn parse_expression_width(input: &str) -> Result<ExpressionWidth, std::io::Error> {
Expand Down Expand Up @@ -589,6 +597,7 @@ pub fn compile_no_check(
emit_ssa: if options.emit_ssa { Some(context.package_build_path.clone()) } else { None },
skip_underconstrained_check: options.skip_underconstrained_check,
inliner_aggressiveness: options.inliner_aggressiveness,
max_bytecode_increase_percent: options.max_bytecode_increase_percent,
};

let SsaProgramArtifact { program, debug, warnings, names, brillig_names, error_types, .. } =
Expand Down
1 change: 1 addition & 0 deletions compiler/noirc_evaluator/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ cfg-if.workspace = true
proptest.workspace = true
similar-asserts.workspace = true
num-traits.workspace = true
test-case.workspace = true

[features]
bn254 = ["noirc_frontend/bn254"]
Expand Down
8 changes: 3 additions & 5 deletions compiler/noirc_evaluator/src/brillig/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use self::{
},
};
use crate::ssa::{
ir::function::{Function, FunctionId, RuntimeType},
ir::function::{Function, FunctionId},
ssa_gen::Ssa,
};
use fxhash::FxHashMap as HashMap;
Expand Down Expand Up @@ -59,17 +59,15 @@ impl std::ops::Index<FunctionId> for Brillig {
}

impl Ssa {
/// Compile to brillig brillig functions and ACIR functions reachable from them
/// Compile Brillig functions and ACIR functions reachable from them
#[tracing::instrument(level = "trace", skip_all)]
pub(crate) fn to_brillig(&self, enable_debug_trace: bool) -> Brillig {
// Collect all the function ids that are reachable from brillig
// That means all the functions marked as brillig and ACIR functions called by them
let brillig_reachable_function_ids = self
.functions
.iter()
.filter_map(|(id, func)| {
matches!(func.runtime(), RuntimeType::Brillig(_)).then_some(*id)
})
.filter_map(|(id, func)| func.runtime().is_brillig().then_some(*id))
.collect::<BTreeSet<_>>();

let mut brillig = Brillig::default();
Expand Down
19 changes: 13 additions & 6 deletions compiler/noirc_evaluator/src/ssa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ pub struct SsaEvaluatorOptions {

/// The higher the value, the more inlined brillig functions will be.
pub inliner_aggressiveness: i64,

/// Maximum accepted percentage increase in the Brillig bytecode size after unrolling loops.
/// When `None` the size increase check is skipped altogether and any decrease in the SSA
/// instruction count is accepted.
pub max_bytecode_increase_percent: Option<i32>,
}

pub(crate) struct ArtifactsAndWarnings(Artifacts, Vec<SsaReport>);
Expand Down Expand Up @@ -104,7 +109,10 @@ pub(crate) fn optimize_into_acir(
"After `static_assert` and `assert_constant`:",
)?
.run_pass(Ssa::loop_invariant_code_motion, "After Loop Invariant Code Motion:")
.try_run_pass(Ssa::unroll_loops_iteratively, "After Unrolling:")?
.try_run_pass(
|ssa| ssa.unroll_loops_iteratively(options.max_bytecode_increase_percent),
"After Unrolling:",
)?
.run_pass(Ssa::simplify_cfg, "After Simplifying (2nd):")
.run_pass(Ssa::flatten_cfg, "After Flattening:")
.run_pass(Ssa::remove_bit_shifts, "After Removing Bit Shifts:")
Expand Down Expand Up @@ -450,11 +458,10 @@ impl SsaBuilder {
}

/// The same as `run_pass` but for passes that may fail
fn try_run_pass(
mut self,
pass: fn(Ssa) -> Result<Ssa, RuntimeError>,
msg: &str,
) -> Result<Self, RuntimeError> {
fn try_run_pass<F>(mut self, pass: F, msg: &str) -> Result<Self, RuntimeError>
where
F: FnOnce(Ssa) -> Result<Ssa, RuntimeError>,
{
self.ssa = time(msg, self.print_codegen_timings, || pass(self.ssa))?;
Ok(self.print(msg))
}
Expand Down
6 changes: 6 additions & 0 deletions compiler/noirc_evaluator/src/ssa/ir/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,12 @@ impl Function {
}
}

impl Clone for Function {
fn clone(&self) -> Self {
Function::clone_with_id(self.id(), self)
}
}

impl std::fmt::Display for RuntimeType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Expand Down
137 changes: 116 additions & 21 deletions compiler/noirc_evaluator/src/ssa/opt/unrolling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
//! When unrolling ACIR code, we remove reference count instructions because they are
//! only used by Brillig bytecode.
use acvm::{acir::AcirField, FieldElement};
use im::HashSet;

use crate::{
brillig::brillig_gen::convert_ssa_function,
errors::RuntimeError,
ssa::{
ir::{
Expand All @@ -37,38 +39,60 @@ use crate::{
ssa_gen::Ssa,
},
};
use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet};
use fxhash::FxHashMap as HashMap;

impl Ssa {
/// Loop unrolling can return errors, since ACIR functions need to be fully unrolled.
/// This meta-pass will keep trying to unroll loops and simplifying the SSA until no more errors are found.
#[tracing::instrument(level = "trace", skip(ssa))]
pub(crate) fn unroll_loops_iteratively(mut ssa: Ssa) -> Result<Ssa, RuntimeError> {
for (_, function) in ssa.functions.iter_mut() {
///
/// The `max_bytecode_increase_percent`, when given, is used to limit the growth of the Brillig bytecode
/// size of a function after unrolling its small loops, expressed as a percentage of the original size.
/// For example a value of 150 means the unrolled function can be up to 150% larger than the original
/// (ie. at most 2.5 times its size). It will still contain fewer SSA instructions, but that can still
/// result in more Brillig opcodes.
#[tracing::instrument(level = "trace", skip(self))]
pub(crate) fn unroll_loops_iteratively(
mut self: Ssa,
max_bytecode_increase_percent: Option<i32>,
) -> Result<Ssa, RuntimeError> {
for (_, function) in self.functions.iter_mut() {
// Take a snapshot of the function to compare byte size increase,
// but only if the setting indicates we have to, otherwise skip it.
let orig_func_and_max_incr_pct = max_bytecode_increase_percent
.filter(|_| function.runtime().is_brillig())
.map(|max_incr_pct| (function.clone(), max_incr_pct));

// Try to unroll loops first:
let mut unroll_errors = function.try_unroll_loops();
let (mut has_unrolled, mut unroll_errors) = function.try_unroll_loops();

// Keep unrolling until no more errors are found
while !unroll_errors.is_empty() {
let prev_unroll_err_count = unroll_errors.len();

// Simplify the SSA before retrying

// Do a mem2reg after the last unroll to aid simplify_cfg
function.mem2reg();
function.simplify_function();
// Do another mem2reg after simplify_cfg to aid the next unroll
function.mem2reg();
simplify_between_unrolls(function);

// Unroll again
unroll_errors = function.try_unroll_loops();
let (new_unrolled, new_errors) = function.try_unroll_loops();
unroll_errors = new_errors;
has_unrolled |= new_unrolled;

// If we didn't manage to unroll any more loops, exit
if unroll_errors.len() >= prev_unroll_err_count {
return Err(unroll_errors.swap_remove(0));
}
}

if has_unrolled {
if let Some((orig_function, max_incr_pct)) = orig_func_and_max_incr_pct {
let new_size = brillig_bytecode_size(function);
let orig_size = brillig_bytecode_size(&orig_function);
if !is_new_size_ok(orig_size, new_size, max_incr_pct) {
*function = orig_function;
}
}
}
}
Ok(ssa)
Ok(self)
}
}

Expand All @@ -77,7 +101,7 @@ impl Function {
// This can also be true for ACIR, but we have no alternative to unrolling in ACIR.
// Brillig also generally prefers smaller code rather than faster code,
// so we only attempt to unroll small loops, which we decide on a case-by-case basis.
fn try_unroll_loops(&mut self) -> Vec<RuntimeError> {
fn try_unroll_loops(&mut self) -> (bool, Vec<RuntimeError>) {
Loops::find_all(self).unroll_each(self)
}
}
Expand Down Expand Up @@ -170,8 +194,10 @@ impl Loops {

/// Unroll all loops within a given function.
/// Any loops which fail to be unrolled (due to using non-constant indices) will be unmodified.
fn unroll_each(mut self, function: &mut Function) -> Vec<RuntimeError> {
/// Returns whether any blocks have been modified
fn unroll_each(mut self, function: &mut Function) -> (bool, Vec<RuntimeError>) {
let mut unroll_errors = vec![];
let mut has_unrolled = false;
while let Some(next_loop) = self.yet_to_unroll.pop() {
if function.runtime().is_brillig() && !next_loop.is_small_loop(function, &self.cfg) {
continue;
Expand All @@ -181,21 +207,25 @@ impl Loops {
if next_loop.blocks.iter().any(|block| self.modified_blocks.contains(block)) {
let mut new_loops = Self::find_all(function);
new_loops.failed_to_unroll = self.failed_to_unroll;
return unroll_errors.into_iter().chain(new_loops.unroll_each(function)).collect();
let (new_unrolled, new_errors) = new_loops.unroll_each(function);
return (has_unrolled || new_unrolled, [unroll_errors, new_errors].concat());
}

// Don't try to unroll the loop again if it is known to fail
if !self.failed_to_unroll.contains(&next_loop.header) {
match next_loop.unroll(function, &self.cfg) {
Ok(_) => self.modified_blocks.extend(next_loop.blocks),
Ok(_) => {
has_unrolled = true;
self.modified_blocks.extend(next_loop.blocks);
}
Err(call_stack) => {
self.failed_to_unroll.insert(next_loop.header);
unroll_errors.push(RuntimeError::UnknownLoopBound { call_stack });
}
}
}
}
unroll_errors
(has_unrolled, unroll_errors)
}
}

Expand Down Expand Up @@ -947,21 +977,59 @@ impl<'f> LoopIteration<'f> {
}
}

/// Clean-up passes run on a function between successive unrolling attempts.
///
/// Unrolling can leave duplicate instructions behind, which these passes can potentially remove.
fn simplify_between_unrolls(function: &mut Function) {
    // First mem2reg: run after the last unroll to aid simplify_cfg.
    function.mem2reg();
    function.simplify_function();
    // Second mem2reg: run after simplify_cfg to aid the next unroll.
    function.mem2reg();
}

/// Compile `function` to Brillig bytecode and report how many opcodes it produced.
///
/// The input is never modified: all preparatory passes run on a private copy.
fn brillig_bytecode_size(function: &Function) -> usize {
    // The conversion cannot go ahead on raw SSA: without some passes first we can hit
    // `unreachable!()` instructions in `convert_ssa_instruction`. Work on a clone so
    // that the caller's function stays untouched.
    let mut scratch = function.clone();

    // Give the copy the best chance of producing small bytecode.
    simplify_between_unrolls(&mut scratch);

    // Dead instruction elimination is an attempt to avoid hitting an ICE in the conversion.
    scratch.dead_instruction_elimination(false);

    let artifact = convert_ssa_function(&scratch, false);
    artifact.byte_code.len()
}

/// Decide if the new bytecode size is acceptable, compared to the original.
///
/// * `orig_size` - bytecode size before unrolling.
/// * `new_size` - bytecode size after unrolling.
/// * `max_incr_pct` - maximum accepted increase, as a percentage of `orig_size`.
///
/// Returns `true` when `new_size <= orig_size * (100 + max_incr_pct) / 100`.
///
/// The maximum increase can be expressed as a negative value if we demand a decrease.
/// (Values -100 and under mean the new size should be 0).
fn is_new_size_ok(orig_size: usize, new_size: usize, max_incr_pct: i32) -> bool {
    // Clamp at 0%, so that demanding more than a 100% decrease simply requires a size of 0.
    let max_size_pct = u128::from(100i32.saturating_add(max_incr_pct).max(0).unsigned_abs());

    // Compare in `u128`: a `usize` times a percentage that fits in `i32` cannot overflow it.
    // Saturating `usize` arithmetic would clamp both sides to `usize::MAX` for very large
    // sizes and wrongly accept the new size.
    let max_size = (orig_size as u128) * max_size_pct;
    (new_size as u128) * 100 <= max_size
}

#[cfg(test)]
mod tests {
use acvm::FieldElement;
use test_case::test_case;

use crate::errors::RuntimeError;
use crate::ssa::{ir::value::ValueId, opt::assert_normalized_ssa_equals, Ssa};

use super::{BoilerplateStats, Loops};
use super::{is_new_size_ok, BoilerplateStats, Loops};

/// Tries to unroll all loops in each SSA function.
/// Tries to unroll all loops in each SSA function once, calling the `Function` directly,
/// bypassing the iterative loop done by the SSA which does further optimisations.
///
/// If any loop cannot be unrolled, it is left as-is or in a partially unrolled state.
fn try_unroll_loops(mut ssa: Ssa) -> (Ssa, Vec<RuntimeError>) {
let mut errors = vec![];
for function in ssa.functions.values_mut() {
errors.extend(function.try_unroll_loops());
errors.extend(function.try_unroll_loops().1);
}
(ssa, errors)
}
Expand Down Expand Up @@ -1221,9 +1289,26 @@ mod tests {

let (ssa, errors) = try_unroll_loops(ssa);
assert_eq!(errors.len(), 0, "Unroll should have no errors");
// Check that it's still the original
assert_normalized_ssa_equals(ssa, parse_ssa().to_string().as_str());
}

#[test]
fn test_brillig_unroll_iteratively_respects_max_increase() {
    // A limit of -90 demands that the bytecode shrinks to at most 10% of its original size.
    let unrolled = brillig_unroll_test_case().unroll_loops_iteratively(Some(-90)).unwrap();
    // The limit was not met, so the result must still be the original program.
    let expected = brillig_unroll_test_case().to_string();
    assert_normalized_ssa_equals(unrolled, expected.as_str());
}

#[test]
fn test_brillig_unroll_iteratively_with_large_max_increase() {
    // A limit of 50 accepts bytecode that is up to 50% larger than the original.
    let unrolled = brillig_unroll_test_case().unroll_loops_iteratively(Some(50)).unwrap();
    // With the loop unrolled, only two blocks should remain reachable in `main`.
    let block_count = unrolled.main().reachable_blocks().len();
    assert_eq!(block_count, 2, "The loop should be unrolled");
}

/// Test that `break` and `continue` stop unrolling without any panic.
#[test]
fn test_brillig_unroll_break_and_continue() {
Expand Down Expand Up @@ -1377,4 +1462,14 @@ mod tests {
let loop0 = loops.yet_to_unroll.pop().expect("there should be a loop");
loop0.boilerplate_stats(function, &loops.cfg).expect("there should be stats")
}

// Each case lists: original size, new size, maximum increase in percent, expected verdict.
#[test_case(1000, 700, 50, true; "size decreased")]
#[test_case(1000, 1500, 50, true; "size increased just by the max")]
#[test_case(1000, 1501, 50, false; "size increased over the max")]
#[test_case(1000, 700, -50, false; "size decreased but not enough")]
#[test_case(1000, 250, -50, true; "size decreased over expectations")]
#[test_case(1000, 250, -1250, false; "demanding more than minus 100 is handled")]
fn test_is_new_size_ok(orig_size: usize, new_size: usize, max_incr_pct: i32, expected: bool) {
    let verdict = is_new_size_ok(orig_size, new_size, max_incr_pct);
    assert_eq!(verdict, expected);
}
}
5 changes: 5 additions & 0 deletions tooling/nargo_cli/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,13 @@ fn test_{test_name}(force_brillig: ForceBrillig, inliner_aggressiveness: Inliner
nargo.arg("--program-dir").arg(test_program_dir);
nargo.arg("{test_command}").arg("--force");
nargo.arg("--inliner-aggressiveness").arg(inliner_aggressiveness.0.to_string());
if force_brillig.0 {{
nargo.arg("--force-brillig");
// Set the maximum increase so that part of the optimization is exercised (it might fail).
nargo.arg("--max-bytecode-increase-percent");
nargo.arg("50");
}}
{test_content}
Expand Down

0 comments on commit 4ff3081

Please sign in to comment.