Skip to content

Commit

Permalink
fix: Do not merge expressions that contain output witnesses (#6757)
Browse files Browse the repository at this point in the history
  • Loading branch information
aakoshh authored Dec 10, 2024
1 parent 9d7aadc commit f9abf72
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 5 deletions.
5 changes: 5 additions & 0 deletions acvm-repo/acvm/src/compiler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ pub fn compile<F: AcirField>(
acir: Circuit<F>,
expression_width: ExpressionWidth,
) -> (Circuit<F>, AcirTransformationMap) {
if MAX_OPTIMIZER_PASSES == 0 {
let acir_opcode_positions = (0..acir.opcodes.len()).collect::<Vec<_>>();
let transformation_map = AcirTransformationMap::new(&acir_opcode_positions);
return (acir, transformation_map);
}
let mut pass = 0;
let mut prev_opcodes_hash = fxhash::hash64(&acir.opcodes);
let mut prev_acir = acir;
Expand Down
54 changes: 50 additions & 4 deletions acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,18 @@ impl<F: AcirField> MergeExpressionsOptimizer<F> {
self.resolved_blocks.clear();

// Keep track, for each witness, of the gates that use it
let circuit_inputs = circuit.circuit_arguments();
let circuit_io: BTreeSet<Witness> =
circuit.circuit_arguments().union(&circuit.public_inputs().0).cloned().collect();

let mut used_witness: BTreeMap<Witness, BTreeSet<usize>> = BTreeMap::new();
for (i, opcode) in circuit.opcodes.iter().enumerate() {
let witnesses = self.witness_inputs(opcode);
if let Opcode::MemoryInit { block_id, .. } = opcode {
self.resolved_blocks.insert(*block_id, witnesses.clone());
}
for w in witnesses {
// We do not simplify circuit inputs
if !circuit_inputs.contains(&w) {
// We do not simplify circuit inputs and outputs
if !circuit_io.contains(&w) {
used_witness.entry(w).or_default().insert(i);
}
}
Expand Down Expand Up @@ -102,7 +104,7 @@ impl<F: AcirField> MergeExpressionsOptimizer<F> {
let mut witness_list = CircuitSimulator::expr_wit(&expr_use);
witness_list.extend(CircuitSimulator::expr_wit(&expr_define));
for w2 in witness_list {
if !circuit_inputs.contains(&w2) {
if !circuit_io.contains(&w2) {
used_witness.entry(w2).and_modify(|v| {
v.insert(target);
v.remove(&source);
Expand Down Expand Up @@ -326,6 +328,50 @@ mod tests {
check_circuit(circuit);
}

#[test]
fn does_not_eliminate_witnesses_returned_from_circuit() {
let opcodes = vec![
Opcode::AssertZero(Expression {
mul_terms: vec![(FieldElement::from(-1i128), Witness(0), Witness(0))],
linear_combinations: vec![(FieldElement::from(1i128), Witness(1))],
q_c: FieldElement::zero(),
}),
Opcode::AssertZero(Expression {
mul_terms: Vec::new(),
linear_combinations: vec![
(FieldElement::from(-1i128), Witness(1)),
(FieldElement::from(1i128), Witness(2)),
],
q_c: FieldElement::zero(),
}),
];
// Witness(1) could be eliminated because it's only used by 2 opcodes.

let mut private_parameters = BTreeSet::new();
private_parameters.insert(Witness(0));

let mut return_values = BTreeSet::new();
return_values.insert(Witness(1));
return_values.insert(Witness(2));

let circuit = Circuit {
current_witness_index: 2,
expression_width: ExpressionWidth::Bounded { width: 4 },
opcodes,
private_parameters,
public_parameters: PublicInputs::default(),
return_values: PublicInputs(return_values),
assert_messages: Default::default(),
};

let mut merge_optimizer = MergeExpressionsOptimizer::new();
let acir_opcode_positions = vec![0; 20];
let (opcodes, _) =
merge_optimizer.eliminate_intermediate_variable(&circuit, acir_opcode_positions);

assert_eq!(opcodes.len(), 2);
}

#[test]
fn does_not_attempt_to_merge_into_previous_opcodes() {
let opcodes = vec![
Expand Down
1 change: 0 additions & 1 deletion acvm-repo/acvm/src/pwg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,6 @@ impl<'a, F: AcirField, B: BlackBoxFunctionSolver<F>> ACVM<'a, F, B> {

pub fn solve_opcode(&mut self) -> ACVMStatus<F> {
let opcode = &self.opcodes[self.instruction_pointer];

let resolution = match opcode {
Opcode::AssertZero(expr) => ExpressionSolver::solve(&mut self.witness_map, expr),
Opcode::BlackBoxFuncCall(bb_func) => blackbox::solve(
Expand Down

0 comments on commit f9abf72

Please sign in to comment.