Skip to content

Commit

Permalink
feat(acvm): separate ACVM optimizations and transformations (#2979)
Browse files Browse the repository at this point in the history
  • Loading branch information
TomAFrench authored Oct 4, 2023
1 parent 1b5124b commit 5865d1a
Show file tree
Hide file tree
Showing 4 changed files with 255 additions and 210 deletions.
215 changes: 7 additions & 208 deletions acvm-repo/acvm/src/compiler/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
use acir::{
circuit::{
brillig::BrilligOutputs, directives::Directive, opcodes::UnsupportedMemoryOpcode, Circuit,
Opcode, OpcodeLocation,
},
native_types::{Expression, Witness},
BlackBoxFunc, FieldElement,
circuit::{opcodes::UnsupportedMemoryOpcode, Circuit, Opcode, OpcodeLocation},
BlackBoxFunc,
};
use indexmap::IndexMap;
use thiserror::Error;

use crate::Language;
Expand All @@ -15,8 +10,9 @@ use crate::Language;
mod optimizers;
mod transformers;

use optimizers::{GeneralOptimizer, RangeOptimizer};
use transformers::{CSatTransformer, FallbackTransformer, R1CSTransformer};
pub use optimizers::optimize;
pub use transformers::transform;
use transformers::transform_internal;

#[derive(PartialEq, Eq, Debug, Error)]
pub enum CompileError {
Expand Down Expand Up @@ -77,204 +73,7 @@ pub fn compile(
np_language: Language,
is_opcode_supported: impl Fn(&Opcode) -> bool,
) -> Result<(Circuit, AcirTransformationMap), CompileError> {
// Instantiate the optimizer.
// Currently the optimizer and reducer are one in the same
// for CSAT
let (acir, AcirTransformationMap { acir_opcode_positions }) = optimize(acir);

// Track original acir opcode positions throughout the transformation passes of the compilation
// by applying the modifications done to the circuit opcodes and also to the opcode_positions (delete and insert)
let acir_opcode_positions = acir.opcodes.iter().enumerate().map(|(i, _)| i).collect();

// Fallback transformer pass
let (acir, acir_opcode_positions) =
FallbackTransformer::transform(acir, is_opcode_supported, acir_opcode_positions)?;

// General optimizer pass
let mut opcodes: Vec<Opcode> = Vec::new();
for opcode in acir.opcodes {
match opcode {
Opcode::Arithmetic(arith_expr) => {
opcodes.push(Opcode::Arithmetic(GeneralOptimizer::optimize(arith_expr)));
}
other_opcode => opcodes.push(other_opcode),
};
}
let acir = Circuit { opcodes, ..acir };

// Range optimization pass
let range_optimizer = RangeOptimizer::new(acir);
let (mut acir, acir_opcode_positions) =
range_optimizer.replace_redundant_ranges(acir_opcode_positions);

let mut transformer = match &np_language {
crate::Language::R1CS => {
let transformation_map = AcirTransformationMap { acir_opcode_positions };
acir.assert_messages =
transform_assert_messages(acir.assert_messages, &transformation_map);
let transformer = R1CSTransformer::new(acir);
return Ok((transformer.transform(), transformation_map));
}
crate::Language::PLONKCSat { width } => {
let mut csat = CSatTransformer::new(*width);
for value in acir.circuit_arguments() {
csat.mark_solvable(value);
}
csat
}
};

// TODO: the code below is only for CSAT transformer
// TODO it may be possible to refactor it in a way that we do not need to return early from the r1cs
// TODO or at the very least, we could put all of it inside of CSatOptimizer pass

let mut new_acir_opcode_positions: Vec<usize> = Vec::with_capacity(acir_opcode_positions.len());
// Optimize the arithmetic gates by reducing them into the correct width and
// creating intermediate variables when necessary
let mut transformed_opcodes = Vec::new();

let mut next_witness_index = acir.current_witness_index + 1;
// maps a normalized expression to the intermediate variable which represents the expression, along with its 'norm'
// the 'norm' is simply the value of the first non zero coefficient in the expression, taken from the linear terms, or quadratic terms if there is none.
let mut intermediate_variables: IndexMap<Expression, (FieldElement, Witness)> = IndexMap::new();
for (index, opcode) in acir.opcodes.iter().enumerate() {
match opcode {
Opcode::Arithmetic(arith_expr) => {
let len = intermediate_variables.len();

let arith_expr = transformer.transform(
arith_expr.clone(),
&mut intermediate_variables,
&mut next_witness_index,
);

// Update next_witness counter
next_witness_index += (intermediate_variables.len() - len) as u32;
let mut new_opcodes = Vec::new();
for (g, (norm, w)) in intermediate_variables.iter().skip(len) {
// de-normalize
let mut intermediate_opcode = g * *norm;
// constrain the intermediate opcode to the intermediate variable
intermediate_opcode.linear_combinations.push((-FieldElement::one(), *w));
intermediate_opcode.sort();
new_opcodes.push(intermediate_opcode);
}
new_opcodes.push(arith_expr);
for opcode in new_opcodes {
new_acir_opcode_positions.push(acir_opcode_positions[index]);
transformed_opcodes.push(Opcode::Arithmetic(opcode));
}
}
Opcode::BlackBoxFuncCall(func) => {
match func {
acir::circuit::opcodes::BlackBoxFuncCall::AND { output, .. }
| acir::circuit::opcodes::BlackBoxFuncCall::XOR { output, .. } => {
transformer.mark_solvable(*output);
}
acir::circuit::opcodes::BlackBoxFuncCall::RANGE { .. } => (),
acir::circuit::opcodes::BlackBoxFuncCall::SHA256 { outputs, .. }
| acir::circuit::opcodes::BlackBoxFuncCall::Keccak256 { outputs, .. }
| acir::circuit::opcodes::BlackBoxFuncCall::Keccak256VariableLength {
outputs,
..
}
| acir::circuit::opcodes::BlackBoxFuncCall::RecursiveAggregation {
output_aggregation_object: outputs,
..
}
| acir::circuit::opcodes::BlackBoxFuncCall::Blake2s { outputs, .. } => {
for witness in outputs {
transformer.mark_solvable(*witness);
}
}
acir::circuit::opcodes::BlackBoxFuncCall::FixedBaseScalarMul {
outputs,
..
}
| acir::circuit::opcodes::BlackBoxFuncCall::Pedersen { outputs, .. } => {
transformer.mark_solvable(outputs.0);
transformer.mark_solvable(outputs.1);
}
acir::circuit::opcodes::BlackBoxFuncCall::HashToField128Security {
output,
..
}
| acir::circuit::opcodes::BlackBoxFuncCall::EcdsaSecp256k1 { output, .. }
| acir::circuit::opcodes::BlackBoxFuncCall::EcdsaSecp256r1 { output, .. }
| acir::circuit::opcodes::BlackBoxFuncCall::SchnorrVerify { output, .. } => {
transformer.mark_solvable(*output);
}
}

new_acir_opcode_positions.push(acir_opcode_positions[index]);
transformed_opcodes.push(opcode.clone());
}
Opcode::Directive(directive) => {
match directive {
Directive::Quotient(quotient_directive) => {
transformer.mark_solvable(quotient_directive.q);
transformer.mark_solvable(quotient_directive.r);
}
Directive::ToLeRadix { b, .. } => {
for witness in b {
transformer.mark_solvable(*witness);
}
}
Directive::PermutationSort { bits, .. } => {
for witness in bits {
transformer.mark_solvable(*witness);
}
}
}
new_acir_opcode_positions.push(acir_opcode_positions[index]);
transformed_opcodes.push(opcode.clone());
}
Opcode::MemoryInit { .. } => {
// `MemoryInit` does not write values to the `WitnessMap`
new_acir_opcode_positions.push(acir_opcode_positions[index]);
transformed_opcodes.push(opcode.clone());
}
Opcode::MemoryOp { op, .. } => {
for (_, witness1, witness2) in &op.value.mul_terms {
transformer.mark_solvable(*witness1);
transformer.mark_solvable(*witness2);
}
for (_, witness) in &op.value.linear_combinations {
transformer.mark_solvable(*witness);
}
new_acir_opcode_positions.push(acir_opcode_positions[index]);
transformed_opcodes.push(opcode.clone());
}
Opcode::Brillig(brillig) => {
for output in &brillig.outputs {
match output {
BrilligOutputs::Simple(w) => transformer.mark_solvable(*w),
BrilligOutputs::Array(v) => {
for witness in v {
transformer.mark_solvable(*witness);
}
}
}
}
new_acir_opcode_positions.push(acir_opcode_positions[index]);
transformed_opcodes.push(opcode.clone());
}
}
}

let current_witness_index = next_witness_index - 1;

let transformation_map =
AcirTransformationMap { acir_opcode_positions: new_acir_opcode_positions };

let acir = Circuit {
current_witness_index,
opcodes: transformed_opcodes,
// The optimizer does not add new public inputs
private_parameters: acir.private_parameters,
public_parameters: acir.public_parameters,
return_values: acir.return_values,
assert_messages: transform_assert_messages(acir.assert_messages, &transformation_map),
};

Ok((acir, transformation_map))
transform_internal(acir, np_language, is_opcode_supported, acir_opcode_positions)
}
32 changes: 32 additions & 0 deletions acvm-repo/acvm/src/compiler/optimizers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,37 @@
use acir::circuit::{Circuit, Opcode};

mod general;
mod redundant_range;

pub(crate) use general::GeneralOptimizer;
pub(crate) use redundant_range::RangeOptimizer;

use super::AcirTransformationMap;

/// Applies [`ProofSystemCompiler`][crate::ProofSystemCompiler] independent optimizations to a [`Circuit`].
pub fn optimize(acir: Circuit) -> (Circuit, AcirTransformationMap) {
// General optimizer pass
let mut opcodes: Vec<Opcode> = Vec::new();
for opcode in acir.opcodes {
match opcode {
Opcode::Arithmetic(arith_expr) => {
opcodes.push(Opcode::Arithmetic(GeneralOptimizer::optimize(arith_expr)));
}
other_opcode => opcodes.push(other_opcode),
};
}
let acir = Circuit { opcodes, ..acir };

// Track original acir opcode positions throughout the transformation passes of the compilation
// by applying the modifications done to the circuit opcodes and also to the opcode_positions (delete and insert)
let acir_opcode_positions = acir.opcodes.iter().enumerate().map(|(i, _)| i).collect();

// Range optimization pass
let range_optimizer = RangeOptimizer::new(acir);
let (acir, acir_opcode_positions) =
range_optimizer.replace_redundant_ranges(acir_opcode_positions);

let transformation_map = AcirTransformationMap { acir_opcode_positions };

(acir, transformation_map)
}
Loading

0 comments on commit 5865d1a

Please sign in to comment.