Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: pull loop unrolling refactor from sync PR #9975

Merged
merged 4 commits into from
Nov 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 53 additions & 45 deletions noir/noir-repo/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,54 +40,47 @@ use fxhash::FxHashMap as HashMap;
impl Ssa {
/// Loop unrolling can return errors, since ACIR functions need to be fully unrolled.
/// This meta-pass will keep trying to unroll loops and simplifying the SSA until no more errors are found.
#[tracing::instrument(level = "trace", skip(ssa))]
pub(crate) fn unroll_loops_iteratively(mut ssa: Ssa) -> Result<Ssa, RuntimeError> {
// Try to unroll loops first:
let mut unroll_errors;
(ssa, unroll_errors) = ssa.try_to_unroll_loops();

// Keep unrolling until no more errors are found
while !unroll_errors.is_empty() {
let prev_unroll_err_count = unroll_errors.len();

// Simplify the SSA before retrying

// Do a mem2reg after the last unroll to aid simplify_cfg
ssa = ssa.mem2reg();
ssa = ssa.simplify_cfg();
// Do another mem2reg after simplify_cfg to aid the next unroll
ssa = ssa.mem2reg();

// Unroll again
(ssa, unroll_errors) = ssa.try_to_unroll_loops();
// If we didn't manage to unroll any more loops, exit
if unroll_errors.len() >= prev_unroll_err_count {
return Err(unroll_errors.swap_remove(0));
let acir_functions = ssa.functions.iter_mut().filter(|(_, func)| {
// Loop unrolling in brillig can lead to a code explosion currently. This can
// also be true for ACIR, but we have no alternative to unrolling in ACIR.
// Brillig also generally prefers smaller code rather than faster code.
!matches!(func.runtime(), RuntimeType::Brillig(_))
});

for (_, function) in acir_functions {
// Try to unroll loops first:
let mut unroll_errors = function.try_to_unroll_loops();

// Keep unrolling until no more errors are found
while !unroll_errors.is_empty() {
let prev_unroll_err_count = unroll_errors.len();

// Simplify the SSA before retrying

// Do a mem2reg after the last unroll to aid simplify_cfg
function.mem2reg();
function.simplify_function();
// Do another mem2reg after simplify_cfg to aid the next unroll
function.mem2reg();

// Unroll again
unroll_errors = function.try_to_unroll_loops();
// If we didn't manage to unroll any more loops, exit
if unroll_errors.len() >= prev_unroll_err_count {
return Err(unroll_errors.swap_remove(0));
}
}
}
Ok(ssa)
}

/// Tries to unroll all loops in each SSA function.
/// If any loop cannot be unrolled, it is left as-is or in a partially unrolled state.
/// Returns the ssa along with all unrolling errors encountered
#[tracing::instrument(level = "trace", skip(self))]
pub(crate) fn try_to_unroll_loops(mut self) -> (Ssa, Vec<RuntimeError>) {
let mut errors = vec![];
for function in self.functions.values_mut() {
function.try_to_unroll_loops(&mut errors);
}
(self, errors)
Ok(ssa)
}
}

impl Function {
pub(crate) fn try_to_unroll_loops(&mut self, errors: &mut Vec<RuntimeError>) {
// Loop unrolling in brillig can lead to a code explosion currently. This can
// also be true for ACIR, but we have no alternative to unrolling in ACIR.
// Brillig also generally prefers smaller code rather than faster code.
if !matches!(self.runtime(), RuntimeType::Brillig(_)) {
errors.extend(find_all_loops(self).unroll_each_loop(self));
}
fn try_to_unroll_loops(&mut self) -> Vec<RuntimeError> {
find_all_loops(self).unroll_each_loop(self)
}
}

Expand Down Expand Up @@ -507,11 +500,26 @@ impl<'f> LoopIteration<'f> {

#[cfg(test)]
mod tests {
use crate::ssa::{
function_builder::FunctionBuilder,
ir::{instruction::BinaryOp, map::Id, types::Type},
use crate::{
errors::RuntimeError,
ssa::{
function_builder::FunctionBuilder,
ir::{instruction::BinaryOp, map::Id, types::Type},
},
};

use super::Ssa;

/// Tries to unroll all loops in each SSA function.
/// If any loop cannot be unrolled, it is left as-is or in a partially unrolled state.
fn try_to_unroll_loops(mut ssa: Ssa) -> (Ssa, Vec<RuntimeError>) {
let mut errors = vec![];
for function in ssa.functions.values_mut() {
errors.extend(function.try_to_unroll_loops());
}
(ssa, errors)
}

#[test]
fn unroll_nested_loops() {
// fn main() {
Expand Down Expand Up @@ -630,7 +638,7 @@ mod tests {
// }
// The final block count is not 1 because unrolling creates some unnecessary jmps.
// If a simplify cfg pass is ran afterward, the expected block count will be 1.
let (ssa, errors) = ssa.try_to_unroll_loops();
let (ssa, errors) = try_to_unroll_loops(ssa);
assert_eq!(errors.len(), 0, "All loops should be unrolled");
assert_eq!(ssa.main().reachable_blocks().len(), 5);
}
Expand Down Expand Up @@ -680,7 +688,7 @@ mod tests {
assert_eq!(ssa.main().reachable_blocks().len(), 4);

// Expected that we failed to unroll the loop
let (_, errors) = ssa.try_to_unroll_loops();
let (_, errors) = try_to_unroll_loops(ssa);
assert_eq!(errors.len(), 1, "Expected to fail to unroll loop");
}
}
Loading