From b35fa25d03164a2cad907c8e8c7b5d8a3775ee98 Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Wed, 13 Nov 2024 16:44:45 +0000 Subject: [PATCH] Unroll small Brillig loops --- .../noirc_evaluator/src/ssa/opt/unrolling.rs | 193 +++++++++++++++++- 1 file changed, 186 insertions(+), 7 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs index 1077347628f..ca90b6d101f 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs @@ -209,8 +209,9 @@ impl Loops { let mut unroll_errors = vec![]; while let Some(next_loop) = self.yet_to_unroll.pop() { if function.runtime().is_brillig() { - // TODO (#6470): Decide whether to unroll this loop. - continue; + if !next_loop.is_small_loop(function, &self.cfg) { + continue; + } } // If we've previously modified a block in this loop we need to refresh the context. // This happens any time we have nested loops. @@ -593,21 +594,54 @@ impl Loop { /// of unrolled instructions times the number of iterations would result in smaller bytecode /// than if we keep the loops with their overheads. fn is_small_loop(&self, function: &Function, cfg: &ControlFlowGraph) -> bool { + self.boilerplate_stats(function, cfg).map(|s| s.is_small()).unwrap_or_default() + } + + /// Collect boilerplate stats if we can figure out the upper and lower bounds of the loop. 
+ fn boilerplate_stats( + &self, + function: &Function, + cfg: &ControlFlowGraph, + ) -> Option<BoilerplateStats> { let Ok(Some((lower, upper))) = self.get_const_bounds(function, cfg) else { - return false; + return None; }; let Some(lower) = lower.try_to_u64() else { - return false; + return None; }; let Some(upper) = upper.try_to_u64() else { - return false; + return None; }; - let num_iterations = (upper - lower) as usize; let refs = self.find_pre_header_reference_values(function, cfg); let (loads, stores) = self.count_loads_and_stores(function, &refs); let all_instructions = self.count_all_instructions(function); let useful_instructions = all_instructions - loads - stores - LOOP_BOILERPLATE_COUNT; - useful_instructions * num_iterations < all_instructions + Some(BoilerplateStats { + iterations: (upper - lower) as usize, + loads, + stores, + all_instructions, + useful_instructions, + }) + } +} + +#[derive(Debug)] +struct BoilerplateStats { + iterations: usize, + loads: usize, + stores: usize, + all_instructions: usize, + useful_instructions: usize, +} + +impl BoilerplateStats { + /// A small loop is where if we unroll it into the pre-header then considering the + /// number of iterations we still end up with a smaller bytecode than if we leave + /// the blocks intact with all the boilerplate involved in jumping, and the extra + /// reference access instructions. + fn is_small(&self) -> bool { + self.useful_instructions * self.iterations < self.all_instructions } } @@ -1014,6 +1048,105 @@ mod tests { assert!(loop0.is_small_loop(function, &loops.cfg)); } + #[test] + fn test_brillig_unroll_small_loop() { + let ssa = brillig_unroll_test_case(); + + // Example taken from an equivalent ACIR program (ie. remove the `unconstrained`) and run + // `cargo run -q -p nargo_cli -- --program-dir . 
compile --show-ssa` + let expected = " + brillig(inline) fn main f0 { + b0(v0: u32): + v1 = allocate -> &mut u32 + store u32 0 at v1 + v3 = load v1 -> u32 + store v3 at v1 + v4 = load v1 -> u32 + v6 = add v4, u32 1 + store v6 at v1 + v7 = load v1 -> u32 + v9 = add v7, u32 2 + store v9 at v1 + v10 = load v1 -> u32 + v12 = add v10, u32 3 + store v12 at v1 + jmp b1() + b1(): + v13 = load v1 -> u32 + v14 = eq v13, v0 + constrain v13 == v0 + return + } + "; + + let (ssa, errors) = ssa.try_unroll_loops(); + assert_eq!(errors.len(), 0, "Unroll should have no errors"); + assert_eq!(ssa.main().reachable_blocks().len(), 2, "The loop should be unrolled"); + + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn test_brillig_unroll_6470_small() { + // Few enough iterations so that we can perform the unroll. + let ssa = brillig_unroll_test_case_6470(3); + let (ssa, errors) = ssa.try_unroll_loops(); + assert_eq!(errors.len(), 0, "Unroll should have no errors"); + assert_eq!(ssa.main().reachable_blocks().len(), 2, "The loop should be unrolled"); + + // The IDs are shifted by one compared to what the ACIR version printed. 
+ let expected = " + brillig(inline) fn __validate_gt_remainder f0 { + b0(v0: [u64; 6]): + inc_rc v0 + inc_rc [u64 0, u64 0, u64 0, u64 0, u64 0, u64 0] of u64 + v3 = allocate -> &mut [u64; 6] + store [u64 0, u64 0, u64 0, u64 0, u64 0, u64 0] of u64 at v3 + v5 = load v3 -> [u64; 6] + v7 = array_get v0, index u32 0 -> u64 + v9 = add v7, u64 1 + v10 = array_set v5, index u32 0, value v9 + store v10 at v3 + v11 = load v3 -> [u64; 6] + v13 = array_get v0, index u32 1 -> u64 + v14 = add v13, u64 1 + v15 = array_set v11, index u32 1, value v14 + store v15 at v3 + v16 = load v3 -> [u64; 6] + v18 = array_get v0, index u32 2 -> u64 + v19 = add v18, u64 1 + v20 = array_set v16, index u32 2, value v19 + store v20 at v3 + jmp b1() + b1(): + v21 = load v3 -> [u64; 6] + dec_rc v0 + return v21 + } + "; + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn test_brillig_unroll_6470_large() { + // More iterations than it can unroll + let ssa = brillig_unroll_test_case_6470(6); + + let function = ssa.main(); + let mut loops = Loops::find_all(function); + let loop0 = loops.yet_to_unroll.pop().unwrap(); + let stats = loop0.boilerplate_stats(function, &loops.cfg).unwrap(); + assert_eq!(stats.is_small(), false); + + let (ssa, errors) = ssa.try_unroll_loops(); + assert_eq!(errors.len(), 0, "Unroll should have no errors"); + assert_eq!( + ssa.main().reachable_blocks().len(), + 4, + "The loop should be considered too costly to unroll" + ); + } + /// Simple test loop: /// ```text /// unconstrained fn main(sum: u32) { @@ -1054,4 +1187,50 @@ mod tests { "; Ssa::from_str(src).unwrap() } + + /// Test case from #6470: + /// ```text + /// unconstrained fn __validate_gt_remainder(a_u60: [u64; 6]) -> [u64; 6] { + /// let mut result_u60: [u64; 6] = [0; 6]; + /// + /// for i in 0..6 { + /// result_u60[i] = a_u60[i] + 1; + /// } + /// + /// result_u60 + /// } + /// ``` + /// The `num_iterations` parameter can be used to make it more costly to inline. 
+ fn brillig_unroll_test_case_6470(num_iterations: usize) -> Ssa { + let src = format!( + " + // After `static_assert` and `assert_constant`: + brillig(inline) fn __validate_gt_remainder f0 {{ + b0(v0: [u64; 6]): + inc_rc v0 + inc_rc [u64 0, u64 0, u64 0, u64 0, u64 0, u64 0] of u64 + v4 = allocate -> &mut [u64; 6] + store [u64 0, u64 0, u64 0, u64 0, u64 0, u64 0] of u64 at v4 + jmp b1(u32 0) + b1(v1: u32): + v7 = lt v1, u32 {num_iterations} + jmpif v7 then: b3, else: b2 + b3(): + v9 = load v4 -> [u64; 6] + v10 = array_get v0, index v1 -> u64 + v12 = add v10, u64 1 + v13 = array_set v9, index v1, value v12 + v15 = add v1, u32 1 + store v13 at v4 + v16 = add v1, u32 1 + jmp b1(v16) + b2(): + v8 = load v4 -> [u64; 6] + dec_rc v0 + return v8 + }} + " + ); + Ssa::from_str(&src).unwrap() + } }