Unroll small Brillig loops
aakoshh committed Nov 13, 2024
1 parent f41ec36 commit b35fa25
Showing 1 changed file with 186 additions and 7 deletions.
193 changes: 186 additions & 7 deletions compiler/noirc_evaluator/src/ssa/opt/unrolling.rs
@@ -209,8 +209,9 @@ impl Loops {
let mut unroll_errors = vec![];
while let Some(next_loop) = self.yet_to_unroll.pop() {
if function.runtime().is_brillig() {
// TODO (#6470): Decide whether to unroll this loop.
continue;
if !next_loop.is_small_loop(function, &self.cfg) {
continue;
}
}
// If we've previously modified a block in this loop we need to refresh the context.
// This happens any time we have nested loops.
@@ -593,21 +594,54 @@ impl Loop {
/// of unrolled instructions times the number of iterations would result in smaller bytecode
/// than if we keep the loops with their overheads.
fn is_small_loop(&self, function: &Function, cfg: &ControlFlowGraph) -> bool {
self.boilerplate_stats(function, cfg).map(|s| s.is_small()).unwrap_or_default()
}

/// Collect boilerplate stats if we can figure out the upper and lower bounds of the loop.
fn boilerplate_stats(
&self,
function: &Function,
cfg: &ControlFlowGraph,
) -> Option<BoilerplateStats> {
let Ok(Some((lower, upper))) = self.get_const_bounds(function, cfg) else {
return false;
return None;
};
let Some(lower) = lower.try_to_u64() else {
return false;
return None;
};
let Some(upper) = upper.try_to_u64() else {
return false;
return None;
};
let num_iterations = (upper - lower) as usize;
let refs = self.find_pre_header_reference_values(function, cfg);
let (loads, stores) = self.count_loads_and_stores(function, &refs);
let all_instructions = self.count_all_instructions(function);
let useful_instructions = all_instructions - loads - stores - LOOP_BOILERPLATE_COUNT;
useful_instructions * num_iterations < all_instructions
Some(BoilerplateStats {
iterations: (upper - lower) as usize,
loads,
stores,
all_instructions,
useful_instructions,
})
}
}

#[derive(Debug)]
struct BoilerplateStats {
iterations: usize,
loads: usize,
stores: usize,
all_instructions: usize,
useful_instructions: usize,
}

impl BoilerplateStats {
/// A loop is considered small if unrolling it into the pre-header, taking the
/// number of iterations into account, still results in smaller bytecode than
/// leaving the blocks intact, with all the boilerplate involved in jumping and
/// the extra reference access instructions.
fn is_small(&self) -> bool {
self.useful_instructions * self.iterations < self.all_instructions
}
}
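
For intuition, here is a self-contained sketch of the size heuristic above. The struct and `is_small` mirror the `BoilerplateStats` added in this diff, but the instruction counts (and the implied boilerplate constant) are purely hypothetical:

```rust
// Sketch of the "small loop" heuristic with made-up numbers; the real counts
// come from count_loads_and_stores, count_all_instructions and
// LOOP_BOILERPLATE_COUNT in unrolling.rs.
#[allow(dead_code)]
struct BoilerplateStats {
    iterations: usize,
    loads: usize,
    stores: usize,
    all_instructions: usize,
    useful_instructions: usize,
}

impl BoilerplateStats {
    fn is_small(&self) -> bool {
        // Unrolling copies only the useful body once per iteration, while the
        // rolled loop keeps every instruction (loads, stores and jump
        // boilerplate included) in the bytecode exactly once.
        self.useful_instructions * self.iterations < self.all_instructions
    }
}

fn main() {
    // Hypothetical loop: 11 instructions in total, of which 2 are loads, 2 are
    // stores and 5 are boilerplate, leaving 2 that do real work per iteration.
    let stats = BoilerplateStats {
        iterations: 4,
        loads: 2,
        stores: 2,
        all_instructions: 11,
        useful_instructions: 11 - 2 - 2 - 5,
    };
    // 2 * 4 = 8 < 11, so unrolling is expected to shrink the bytecode.
    assert!(stats.is_small());
}
```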

@@ -1014,6 +1048,105 @@ mod tests {
assert!(loop0.is_small_loop(function, &loops.cfg));
}

#[test]
fn test_brillig_unroll_small_loop() {
let ssa = brillig_unroll_test_case();

// Example taken from an equivalent ACIR program (i.e. remove the `unconstrained`) and run
// `cargo run -q -p nargo_cli -- --program-dir . compile --show-ssa`
let expected = "
brillig(inline) fn main f0 {
b0(v0: u32):
v1 = allocate -> &mut u32
store u32 0 at v1
v3 = load v1 -> u32
store v3 at v1
v4 = load v1 -> u32
v6 = add v4, u32 1
store v6 at v1
v7 = load v1 -> u32
v9 = add v7, u32 2
store v9 at v1
v10 = load v1 -> u32
v12 = add v10, u32 3
store v12 at v1
jmp b1()
b1():
v13 = load v1 -> u32
v14 = eq v13, v0
constrain v13 == v0
return
}
";

let (ssa, errors) = ssa.try_unroll_loops();
assert_eq!(errors.len(), 0, "Unroll should have no errors");
assert_eq!(ssa.main().reachable_blocks().len(), 2, "The loop should be unrolled");

assert_normalized_ssa_equals(ssa, expected);
}

#[test]
fn test_brillig_unroll_6470_small() {
// Few enough iterations so that we can perform the unroll.
let ssa = brillig_unroll_test_case_6470(3);
let (ssa, errors) = ssa.try_unroll_loops();
assert_eq!(errors.len(), 0, "Unroll should have no errors");
assert_eq!(ssa.main().reachable_blocks().len(), 2, "The loop should be unrolled");

// The IDs are shifted by one compared to what the ACIR version printed.
let expected = "
brillig(inline) fn __validate_gt_remainder f0 {
b0(v0: [u64; 6]):
inc_rc v0
inc_rc [u64 0, u64 0, u64 0, u64 0, u64 0, u64 0] of u64
v3 = allocate -> &mut [u64; 6]
store [u64 0, u64 0, u64 0, u64 0, u64 0, u64 0] of u64 at v3
v5 = load v3 -> [u64; 6]
v7 = array_get v0, index u32 0 -> u64
v9 = add v7, u64 1
v10 = array_set v5, index u32 0, value v9
store v10 at v3
v11 = load v3 -> [u64; 6]
v13 = array_get v0, index u32 1 -> u64
v14 = add v13, u64 1
v15 = array_set v11, index u32 1, value v14
store v15 at v3
v16 = load v3 -> [u64; 6]
v18 = array_get v0, index u32 2 -> u64
v19 = add v18, u64 1
v20 = array_set v16, index u32 2, value v19
store v20 at v3
jmp b1()
b1():
v21 = load v3 -> [u64; 6]
dec_rc v0
return v21
}
";
assert_normalized_ssa_equals(ssa, expected);
}

#[test]
fn test_brillig_unroll_6470_large() {
// More iterations than it is worth unrolling.
let ssa = brillig_unroll_test_case_6470(6);

let function = ssa.main();
let mut loops = Loops::find_all(function);
let loop0 = loops.yet_to_unroll.pop().unwrap();
let stats = loop0.boilerplate_stats(function, &loops.cfg).unwrap();
assert_eq!(stats.is_small(), false);

let (ssa, errors) = ssa.try_unroll_loops();
assert_eq!(errors.len(), 0, "Unroll should have no errors");
assert_eq!(
ssa.main().reachable_blocks().len(),
4,
"The loop should be considered too costly to unroll"
);
}
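
The two #6470 tests above exercise the same loop at different trip counts; the product check in `is_small` flips once the iteration count crosses roughly `all_instructions / useful_instructions`. A minimal sketch of that threshold behaviour, with hypothetical counts rather than the real ones from the test's SSA:

```rust
// is_small reduces to: useful * iterations < all.
fn is_small(useful: usize, iterations: usize, all: usize) -> bool {
    useful * iterations < all
}

fn main() {
    // Suppose one unrolled iteration costs 7 useful instructions and the
    // rolled loop totals 25 instructions including boilerplate.
    assert!(is_small(7, 3, 25)); // 21 < 25: unroll (the "small" variant)
    assert!(!is_small(7, 6, 25)); // 42 >= 25: keep the loop (the "large" variant)
}
```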

/// Simple test loop:
/// ```text
/// unconstrained fn main(sum: u32) {
@@ -1054,4 +1187,50 @@ mod tests {
";
Ssa::from_str(src).unwrap()
}

/// Test case from #6470:
/// ```text
/// unconstrained fn __validate_gt_remainder(a_u60: [u64; 6]) -> [u64; 6] {
/// let mut result_u60: [u64; 6] = [0; 6];
///
/// for i in 0..6 {
/// result_u60[i] = a_u60[i] + 1;
/// }
///
/// result_u60
/// }
/// ```
/// The `num_iterations` parameter can be used to make the loop more costly to unroll.
fn brillig_unroll_test_case_6470(num_iterations: usize) -> Ssa {
let src = format!(
"
// After `static_assert` and `assert_constant`:
brillig(inline) fn __validate_gt_remainder f0 {{
b0(v0: [u64; 6]):
inc_rc v0
inc_rc [u64 0, u64 0, u64 0, u64 0, u64 0, u64 0] of u64
v4 = allocate -> &mut [u64; 6]
store [u64 0, u64 0, u64 0, u64 0, u64 0, u64 0] of u64 at v4
jmp b1(u32 0)
b1(v1: u32):
v7 = lt v1, u32 {num_iterations}
jmpif v7 then: b3, else: b2
b3():
v9 = load v4 -> [u64; 6]
v10 = array_get v0, index v1 -> u64
v12 = add v10, u64 1
v13 = array_set v9, index v1, value v12
v15 = add v1, u32 1
store v13 at v4
v16 = add v1, u32 1
jmp b1(v16)
b2():
v8 = load v4 -> [u64; 6]
dec_rc v0
return v8
}}
"
);
Ssa::from_str(&src).unwrap()
}
}
