Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ensure that generated ACIR is solvable #6415

Merged
merged 14 commits into from
Nov 1, 2024
2 changes: 2 additions & 0 deletions acvm-repo/acvm/src/compiler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ use acir::{

// The various passes that we can use over ACIR
mod optimizers;
mod simulator;
mod transformers;

pub use optimizers::optimize;
use optimizers::optimize_internal;
pub use simulator::CircuitSimulator;
use transformers::transform_internal;
pub use transformers::{transform, MIN_EXPRESSION_WIDTH};

Expand Down
27 changes: 11 additions & 16 deletions acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use acir::{
AcirField,
};

use crate::compiler::CircuitSimulator;

pub(crate) struct MergeExpressionsOptimizer {
resolved_blocks: HashMap<BlockId, BTreeSet<Witness>>,
}
Expand Down Expand Up @@ -76,7 +78,7 @@ impl MergeExpressionsOptimizer {
modified_gates.insert(b, Opcode::AssertZero(expr));
to_keep = false;
// Update the 'used_witness' map to account for the merge.
for w2 in Self::expr_wit(&expr_define) {
for w2 in CircuitSimulator::expr_wit(&expr_define) {
if !circuit_inputs.contains(&w2) {
let mut v = used_witness[&w2].clone();
v.insert(b);
Expand Down Expand Up @@ -104,22 +106,15 @@ impl MergeExpressionsOptimizer {
(new_circuit, new_acir_opcode_positions)
}

fn expr_wit<F>(expr: &Expression<F>) -> BTreeSet<Witness> {
let mut result = BTreeSet::new();
result.extend(expr.mul_terms.iter().flat_map(|i| vec![i.1, i.2]));
result.extend(expr.linear_combinations.iter().map(|i| i.1));
result
}

fn brillig_input_wit<F>(&self, input: &BrilligInputs<F>) -> BTreeSet<Witness> {
let mut result = BTreeSet::new();
match input {
BrilligInputs::Single(expr) => {
result.extend(Self::expr_wit(expr));
result.extend(CircuitSimulator::expr_wit(expr));
}
BrilligInputs::Array(exprs) => {
for expr in exprs {
result.extend(Self::expr_wit(expr));
result.extend(CircuitSimulator::expr_wit(expr));
}
}
BrilligInputs::MemoryArray(block_id) => {
Expand All @@ -134,16 +129,16 @@ impl MergeExpressionsOptimizer {
fn witness_inputs<F: AcirField>(&self, opcode: &Opcode<F>) -> BTreeSet<Witness> {
let mut witnesses = BTreeSet::new();
match opcode {
Opcode::AssertZero(expr) => Self::expr_wit(expr),
Opcode::AssertZero(expr) => CircuitSimulator::expr_wit(expr),
Opcode::BlackBoxFuncCall(bb_func) => bb_func.get_input_witnesses(),
Opcode::Directive(Directive::ToLeRadix { a, .. }) => Self::expr_wit(a),
Opcode::Directive(Directive::ToLeRadix { a, .. }) => CircuitSimulator::expr_wit(a),
Opcode::MemoryOp { block_id: _, op, predicate } => {
//index et value, et predicate
let mut witnesses = BTreeSet::new();
witnesses.extend(Self::expr_wit(&op.index));
witnesses.extend(Self::expr_wit(&op.value));
witnesses.extend(CircuitSimulator::expr_wit(&op.index));
witnesses.extend(CircuitSimulator::expr_wit(&op.value));
if let Some(p) = predicate {
witnesses.extend(Self::expr_wit(p));
witnesses.extend(CircuitSimulator::expr_wit(p));
}
witnesses
}
Expand All @@ -162,7 +157,7 @@ impl MergeExpressionsOptimizer {
witnesses.insert(*i);
}
if let Some(p) = predicate {
witnesses.extend(Self::expr_wit(p));
witnesses.extend(CircuitSimulator::expr_wit(p));
}
witnesses
}
Expand Down
305 changes: 305 additions & 0 deletions acvm-repo/acvm/src/compiler/simulator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,305 @@
use acir::{
circuit::{
brillig::{BrilligInputs, BrilligOutputs},
directives::Directive,
opcodes::{BlockId, FunctionInput},
Circuit, Opcode,
},
native_types::{Expression, Witness},
AcirField,
};
use std::collections::{BTreeSet, HashMap, HashSet};

#[derive(PartialEq)]
enum BlockStatus {
Initialized,
Used,
}

/// Simulate a symbolic solve for a circuit
#[derive(Default)]
pub struct CircuitSimulator {
/// Track the witnesses that can be solved
solvable_witness: HashSet<Witness>,

/// Tells whether a Memory Block is:
/// - Not initialized if not in the map
/// - Initialized if its status is Initialized in the Map
/// - Used, indicating that the block cannot be written anymore.
resolved_blocks: HashMap<BlockId, BlockStatus>,
}

impl CircuitSimulator {
/// Simulate a symbolic solve for a circuit by keeping track of the witnesses that can be solved.
/// Returns false if the circuit cannot be solved
#[tracing::instrument(level = "trace", skip_all)]
pub fn check_circuit<F: AcirField>(&mut self, circuit: &Circuit<F>) -> bool {
let circuit_inputs = circuit.circuit_arguments();
self.solvable_witness.extend(circuit_inputs.iter());
for op in &circuit.opcodes {
if !self.try_solve(op) {
return false;
}
}
true
}

/// Check if the Opcode can be solved, and if yes, add the solved witness to set of solvable witness
fn try_solve<F: AcirField>(&mut self, opcode: &Opcode<F>) -> bool {
let mut unresolved = HashSet::new();
match opcode {
Opcode::AssertZero(expr) => {
for (_, w1, w2) in &expr.mul_terms {
if !self.solvable_witness.contains(w1) {
if !self.solvable_witness.contains(w2) {
return false;
}
unresolved.insert(*w1);
}
if !self.solvable_witness.contains(w2) && w1 != w2 {
unresolved.insert(*w2);
}
}
for (_, w) in &expr.linear_combinations {
if !self.solvable_witness.contains(w) {
unresolved.insert(*w);
}
}
if unresolved.len() == 1 {
self.mark_solvable(*unresolved.iter().next().unwrap());
return true;
}
unresolved.is_empty()
}
Opcode::BlackBoxFuncCall(black_box_func_call) => {
let inputs = black_box_func_call.get_inputs_vec();
for input in inputs {
if !self.can_solve_function_input(&input) {
return false;
}
}
let outputs = black_box_func_call.get_outputs_vec();
for output in outputs {
self.mark_solvable(output);
}
true
}
Opcode::Directive(directive) => match directive {
Directive::ToLeRadix { a, b, .. } => {
if !self.can_solve_expression(a) {
return false;
}
for w in b {
self.mark_solvable(*w);
}
true
}
},
Opcode::MemoryOp { block_id, op, predicate } => {
if !self.can_solve_expression(&op.index) {
return false;
}
if let Some(predicate) = predicate {
if !self.can_solve_expression(predicate) {
return false;
}
}
if op.operation.is_zero() {
let w = op.value.to_witness().unwrap();
self.mark_solvable(w);
true
} else {
if let Some(BlockStatus::Used) = self.resolved_blocks.get(block_id) {
// Writing after having used the block should not be allowed
return false;
}
self.try_solve(&Opcode::AssertZero(op.value.clone()))
}
}
Opcode::MemoryInit { block_id, init, .. } => {
for w in init {
if !self.solvable_witness.contains(w) {
return false;
}
}
self.resolved_blocks.insert(*block_id, BlockStatus::Initialized);
true
}
Opcode::BrilligCall { id: _, inputs, outputs, predicate } => {
for input in inputs {
if !self.can_solve_brillig_input(input) {
return false;
}
}
if let Some(predicate) = predicate {
if !self.can_solve_expression(predicate) {
return false;
}
}
for output in outputs {
match output {
BrilligOutputs::Simple(w) => self.mark_solvable(*w),
BrilligOutputs::Array(arr) => {
for w in arr {
self.mark_solvable(*w);
}
}
}
}
true
}
Opcode::Call { id: _, inputs, outputs, predicate } => {
for w in inputs {
if !self.solvable_witness.contains(w) {
return false;
}
}
if let Some(predicate) = predicate {
if !self.can_solve_expression(predicate) {
return false;
}
}
for w in outputs {
self.mark_solvable(*w);
}
true
}
}
}

/// Adds the witness to set of solvable witness
pub(crate) fn mark_solvable(&mut self, witness: Witness) {
self.solvable_witness.insert(witness);
}

pub fn can_solve_function_input<F: AcirField>(&self, input: &FunctionInput<F>) -> bool {
if !input.is_constant() {
return self.solvable_witness.contains(&input.to_witness());
}
true
}
fn can_solve_expression<F>(&self, expr: &Expression<F>) -> bool {
for w in Self::expr_wit(expr) {
if !self.solvable_witness.contains(&w) {
return false;
}
}
true
}
fn can_solve_brillig_input<F>(&mut self, input: &BrilligInputs<F>) -> bool {
match input {
BrilligInputs::Single(expr) => self.can_solve_expression(expr),
BrilligInputs::Array(exprs) => {
for expr in exprs {
if !self.can_solve_expression(expr) {
return false;
}
}
true
}

BrilligInputs::MemoryArray(block_id) => match self.resolved_blocks.entry(*block_id) {
std::collections::hash_map::Entry::Vacant(_) => false,
std::collections::hash_map::Entry::Occupied(entry)
if *entry.get() == BlockStatus::Used =>
{
true
}
std::collections::hash_map::Entry::Occupied(mut entry) => {
entry.insert(BlockStatus::Used);
true
}
},
}
}

pub(crate) fn expr_wit<F>(expr: &Expression<F>) -> BTreeSet<Witness> {
let mut result = BTreeSet::new();
result.extend(expr.mul_terms.iter().flat_map(|i| vec![i.1, i.2]));
result.extend(expr.linear_combinations.iter().map(|i| i.1));
result
}
}

#[cfg(test)]
mod tests {
use std::collections::BTreeSet;

use crate::compiler::CircuitSimulator;
use acir::{
acir_field::AcirField,
circuit::{Circuit, ExpressionWidth, Opcode, PublicInputs},
native_types::{Expression, Witness},
FieldElement,
};

fn test_circuit(
opcodes: Vec<Opcode<FieldElement>>,
private_parameters: BTreeSet<Witness>,
public_parameters: PublicInputs,
) -> Circuit<FieldElement> {
Circuit {
current_witness_index: 1,
expression_width: ExpressionWidth::Bounded { width: 4 },
opcodes,
private_parameters,
public_parameters,
return_values: PublicInputs::default(),
assert_messages: Default::default(),
recursive: false,
}
}

#[test]
fn reports_true_for_empty_circuit() {
let empty_circuit = test_circuit(vec![], BTreeSet::default(), PublicInputs::default());

assert!(CircuitSimulator::default().check_circuit(&empty_circuit));
}

#[test]
fn reports_true_for_connected_circuit() {
let connected_circuit = test_circuit(
vec![Opcode::AssertZero(Expression {
mul_terms: Vec::new(),
linear_combinations: vec![
(FieldElement::one(), Witness(1)),
(-FieldElement::one(), Witness(2)),
],
q_c: FieldElement::zero(),
})],
BTreeSet::from([Witness(1)]),
PublicInputs::default(),
);

assert!(CircuitSimulator::default().check_circuit(&connected_circuit));
}

#[test]
fn reports_false_for_disconnected_circuit() {
let disconnected_circuit = test_circuit(
vec![
Opcode::AssertZero(Expression {
mul_terms: Vec::new(),
linear_combinations: vec![
(FieldElement::one(), Witness(1)),
(-FieldElement::one(), Witness(2)),
],
q_c: FieldElement::zero(),
}),
Opcode::AssertZero(Expression {
mul_terms: Vec::new(),
linear_combinations: vec![
(FieldElement::one(), Witness(3)),
(-FieldElement::one(), Witness(4)),
],
q_c: FieldElement::zero(),
}),
],
BTreeSet::from([Witness(1)]),
PublicInputs::default(),
);

assert!(!CircuitSimulator::default().check_circuit(&disconnected_circuit));
}
}
Loading
Loading