Skip to content

Commit

Permalink
chore: optimise older opcodes in reverse order (#6476)
Browse files Browse the repository at this point in the history
Co-authored-by: Tom French <15848336+TomAFrench@users.noreply.github.com>
Co-authored-by: TomAFrench <tom@tomfren.ch>
  • Loading branch information
3 people authored Dec 6, 2024
1 parent 1cd2b4d commit da9a74a
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 64 deletions.
152 changes: 89 additions & 63 deletions acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,36 @@ use acir::{

use crate::compiler::CircuitSimulator;

pub(crate) struct MergeExpressionsOptimizer {
pub(crate) struct MergeExpressionsOptimizer<F> {
resolved_blocks: HashMap<BlockId, BTreeSet<Witness>>,
modified_gates: HashMap<usize, Opcode<F>>,
deleted_gates: BTreeSet<usize>,
}

impl MergeExpressionsOptimizer {
impl<F: AcirField> MergeExpressionsOptimizer<F> {
pub(crate) fn new() -> Self {
MergeExpressionsOptimizer { resolved_blocks: HashMap::new() }
MergeExpressionsOptimizer {
resolved_blocks: HashMap::new(),
modified_gates: HashMap::new(),
deleted_gates: BTreeSet::new(),
}
}
/// This pass analyzes the circuit and identifies intermediate variables that are
/// only used in two gates. It then merges the gate that produces the
/// intermediate variable into the second one that uses it
/// Note: This pass is only relevant for backends that can handle unlimited width
pub(crate) fn eliminate_intermediate_variable<F: AcirField>(
pub(crate) fn eliminate_intermediate_variable(
&mut self,
circuit: &Circuit<F>,
acir_opcode_positions: Vec<usize>,
) -> (Vec<Opcode<F>>, Vec<usize>) {
// Initialization
self.modified_gates.clear();
self.deleted_gates.clear();
self.resolved_blocks.clear();

// Keep track, for each witness, of the gates that use it
let circuit_inputs = circuit.circuit_arguments();
self.resolved_blocks = HashMap::new();
let mut used_witness: BTreeMap<Witness, BTreeSet<usize>> = BTreeMap::new();
for (i, opcode) in circuit.opcodes.iter().enumerate() {
let witnesses = self.witness_inputs(opcode);
Expand All @@ -46,80 +56,89 @@ impl MergeExpressionsOptimizer {
}
}

let mut modified_gates: HashMap<usize, Opcode<F>> = HashMap::new();
let mut new_circuit = Vec::new();
let mut new_acir_opcode_positions = Vec::new();
// For each opcode, try to get a target opcode to merge with
for (i, (opcode, opcode_position)) in
circuit.opcodes.iter().zip(acir_opcode_positions).enumerate()
{
for (i, opcode) in circuit.opcodes.iter().enumerate() {
if !matches!(opcode, Opcode::AssertZero(_)) {
new_circuit.push(opcode.clone());
new_acir_opcode_positions.push(opcode_position);
continue;
}
let opcode = modified_gates.get(&i).unwrap_or(opcode).clone();
let mut to_keep = true;
let input_witnesses = self.witness_inputs(&opcode);
for w in input_witnesses {
let Some(gates_using_w) = used_witness.get(&w) else {
continue;
};
// We only consider witness which are used in exactly two arithmetic gates
if gates_using_w.len() == 2 {
let first = *gates_using_w.first().expect("gates_using_w.len == 2");
let second = *gates_using_w.last().expect("gates_using_w.len == 2");
let b = if second == i {
first
} else {
// sanity check
assert!(i == first);
second
if let Some(opcode) = self.get_opcode(i, circuit) {
let input_witnesses = self.witness_inputs(&opcode);
for w in input_witnesses {
let Some(gates_using_w) = used_witness.get(&w) else {
continue;
};

let second_gate = modified_gates.get(&b).unwrap_or(&circuit.opcodes[b]);
if let (Opcode::AssertZero(expr_define), Opcode::AssertZero(expr_use)) =
(&opcode, second_gate)
{
// We cannot merge an expression into an earlier opcode, because this
// would break the 'execution ordering' of the opcodes
// This case can happen because a previous merge would change an opcode
// and eliminate a witness from it, giving new opportunities for this
// witness to be used in only two expressions
// TODO: the missed optimization for the i>b case can be handled by
// - doing this pass again until there is no change, or
// - merging 'b' into 'i' instead
if i < b {
if let Some(expr) = Self::merge(expr_use, expr_define, w) {
modified_gates.insert(b, Opcode::AssertZero(expr));
to_keep = false;
// Update the 'used_witness' map to account for the merge.
for w2 in CircuitSimulator::expr_wit(expr_define) {
if !circuit_inputs.contains(&w2) {
let v = used_witness.entry(w2).or_default();
v.insert(b);
v.remove(&i);
// We only consider witness which are used in exactly two arithmetic gates
if gates_using_w.len() == 2 {
let first = *gates_using_w.first().expect("gates_using_w.len == 2");
let second = *gates_using_w.last().expect("gates_using_w.len == 2");
let b = if second == i {
first
} else {
// sanity check
assert!(i == first);
second
};
// Merge the opcode with smaller index into the other one
// by updating modified_gates/deleted_gates/used_witness
// returns false if it could not merge them
let mut merge_opcodes = |op1, op2| -> bool {
if op1 == op2 {
return false;
}
let (source, target) = if op1 < op2 { (op1, op2) } else { (op2, op1) };
let source_opcode = self.get_opcode(source, circuit);
let target_opcode = self.get_opcode(target, circuit);
if let (
Some(Opcode::AssertZero(expr_use)),
Some(Opcode::AssertZero(expr_define)),
) = (target_opcode, source_opcode)
{
if let Some(expr) =
Self::merge_expression(&expr_use, &expr_define, w)
{
self.modified_gates.insert(target, Opcode::AssertZero(expr));
self.deleted_gates.insert(source);
// Update the 'used_witness' map to account for the merge.
let mut witness_list = CircuitSimulator::expr_wit(&expr_use);
witness_list.extend(CircuitSimulator::expr_wit(&expr_define));
for w2 in witness_list {
if !circuit_inputs.contains(&w2) {
used_witness.entry(w2).and_modify(|v| {
v.insert(target);
v.remove(&source);
});
}
}
return true;
}
// We need to stop here and continue with the next opcode
// because the merge invalidates the current opcode.
break;
}
false
};

if merge_opcodes(b, i) {
// We need to stop here and continue with the next opcode
// because the merge invalidates the current opcode.
break;
}
}
}
}
}

// Construct the new circuit from modified/deleted gates
let mut new_circuit = Vec::new();
let mut new_acir_opcode_positions = Vec::new();

if to_keep {
let opcode = modified_gates.get(&i).cloned().unwrap_or(opcode);
new_circuit.push(opcode);
new_acir_opcode_positions.push(opcode_position);
for (i, opcode_position) in acir_opcode_positions.iter().enumerate() {
if let Some(op) = self.get_opcode(i, circuit) {
new_circuit.push(op);
new_acir_opcode_positions.push(*opcode_position);
}
}
(new_circuit, new_acir_opcode_positions)
}

fn brillig_input_wit<F>(&self, input: &BrilligInputs<F>) -> BTreeSet<Witness> {
fn brillig_input_wit(&self, input: &BrilligInputs<F>) -> BTreeSet<Witness> {
let mut result = BTreeSet::new();
match input {
BrilligInputs::Single(expr) => {
Expand Down Expand Up @@ -152,7 +171,7 @@ impl MergeExpressionsOptimizer {
}

// Returns the input witnesses used by the opcode
fn witness_inputs<F: AcirField>(&self, opcode: &Opcode<F>) -> BTreeSet<Witness> {
fn witness_inputs(&self, opcode: &Opcode<F>) -> BTreeSet<Witness> {
match opcode {
Opcode::AssertZero(expr) => CircuitSimulator::expr_wit(expr),
Opcode::BlackBoxFuncCall(bb_func) => {
Expand Down Expand Up @@ -198,7 +217,7 @@ impl MergeExpressionsOptimizer {

// Merge 'expr' into 'target' via Gaussian elimination on 'w'
// Returns None if the expressions cannot be merged
fn merge<F: AcirField>(
fn merge_expression(
target: &Expression<F>,
expr: &Expression<F>,
w: Witness,
Expand Down Expand Up @@ -226,6 +245,13 @@ impl MergeExpressionsOptimizer {
}
None
}

fn get_opcode(&self, g: usize, circuit: &Circuit<F>) -> Option<Opcode<F>> {
if self.deleted_gates.contains(&g) {
return None;
}
self.modified_gates.get(&g).or(circuit.opcodes.get(g)).cloned()
}
}

#[cfg(test)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ fn main(x: Field) {
value += term2;
value.assert_max_bit_size::<1>();

// Regression test for Aztec Packages issue #6451
// Regression test for #6447 (Aztec Packages issue #9488)
let y = unsafe { empty(x + 1) };
let z = y + x + 1;
let z1 = z + y;
Expand Down

0 comments on commit da9a74a

Please sign in to comment.