Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ensure that generated ACIR is solvable #6415

Merged
merged 14 commits into from
Nov 1, 2024
2 changes: 2 additions & 0 deletions acvm-repo/acvm/src/compiler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ use acir::{

// The various passes that we can use over ACIR
mod optimizers;
mod simulator;
mod transformers;

pub use optimizers::optimize;
use optimizers::optimize_internal;
pub use simulator::CircuitSimulator;
use transformers::transform_internal;
pub use transformers::{transform, MIN_EXPRESSION_WIDTH};

Expand Down
27 changes: 11 additions & 16 deletions acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use acir::{
AcirField,
};

use crate::compiler::CircuitSimulator;

pub(crate) struct MergeExpressionsOptimizer {
resolved_blocks: HashMap<BlockId, BTreeSet<Witness>>,
}
Expand Down Expand Up @@ -76,7 +78,7 @@ impl MergeExpressionsOptimizer {
modified_gates.insert(b, Opcode::AssertZero(expr));
to_keep = false;
// Update the 'used_witness' map to account for the merge.
for w2 in Self::expr_wit(&expr_define) {
for w2 in CircuitSimulator::expr_wit(&expr_define) {
if !circuit_inputs.contains(&w2) {
let mut v = used_witness[&w2].clone();
v.insert(b);
Expand Down Expand Up @@ -104,22 +106,15 @@ impl MergeExpressionsOptimizer {
(new_circuit, new_acir_opcode_positions)
}

fn expr_wit<F>(expr: &Expression<F>) -> BTreeSet<Witness> {
let mut result = BTreeSet::new();
result.extend(expr.mul_terms.iter().flat_map(|i| vec![i.1, i.2]));
result.extend(expr.linear_combinations.iter().map(|i| i.1));
result
}

fn brillig_input_wit<F>(&self, input: &BrilligInputs<F>) -> BTreeSet<Witness> {
let mut result = BTreeSet::new();
match input {
BrilligInputs::Single(expr) => {
result.extend(Self::expr_wit(expr));
result.extend(CircuitSimulator::expr_wit(expr));
}
BrilligInputs::Array(exprs) => {
for expr in exprs {
result.extend(Self::expr_wit(expr));
result.extend(CircuitSimulator::expr_wit(expr));
}
}
BrilligInputs::MemoryArray(block_id) => {
Expand All @@ -134,16 +129,16 @@ impl MergeExpressionsOptimizer {
fn witness_inputs<F: AcirField>(&self, opcode: &Opcode<F>) -> BTreeSet<Witness> {
let mut witnesses = BTreeSet::new();
match opcode {
Opcode::AssertZero(expr) => Self::expr_wit(expr),
Opcode::AssertZero(expr) => CircuitSimulator::expr_wit(expr),
Opcode::BlackBoxFuncCall(bb_func) => bb_func.get_input_witnesses(),
Opcode::Directive(Directive::ToLeRadix { a, .. }) => Self::expr_wit(a),
Opcode::Directive(Directive::ToLeRadix { a, .. }) => CircuitSimulator::expr_wit(a),
Opcode::MemoryOp { block_id: _, op, predicate } => {
//index et value, et predicate
let mut witnesses = BTreeSet::new();
witnesses.extend(Self::expr_wit(&op.index));
witnesses.extend(Self::expr_wit(&op.value));
witnesses.extend(CircuitSimulator::expr_wit(&op.index));
witnesses.extend(CircuitSimulator::expr_wit(&op.value));
if let Some(p) = predicate {
witnesses.extend(Self::expr_wit(p));
witnesses.extend(CircuitSimulator::expr_wit(p));
}
witnesses
}
Expand All @@ -162,7 +157,7 @@ impl MergeExpressionsOptimizer {
witnesses.insert(*i);
}
if let Some(p) = predicate {
witnesses.extend(Self::expr_wit(p));
witnesses.extend(CircuitSimulator::expr_wit(p));
}
witnesses
}
Expand Down
221 changes: 221 additions & 0 deletions acvm-repo/acvm/src/compiler/simulator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
use acir::{
circuit::{
brillig::{BrilligInputs, BrilligOutputs},
directives::Directive,
opcodes::{BlockId, FunctionInput},
Circuit, Opcode,
},
native_types::{Expression, Witness},
AcirField,
};
use std::collections::{BTreeSet, HashMap, HashSet};

#[derive(PartialEq)]
enum BlockStatus {
Initialized,
Used,
}

/// Simulate a symbolic solve for a circuit
#[derive(Default)]
pub struct CircuitSimulator {
/// Track the witnesses that can be solved
solvable_witness: HashSet<Witness>,

/// Tells whether a Memory Block is:
/// - Not initialized if not in the map
/// - Initialized if its status is Initialized in the Map
/// - Used, indicating that the block cannot be written anymore.
resolved_blocks: HashMap<BlockId, BlockStatus>,
}

impl CircuitSimulator {
// Simulate a symbolic solve for a circuit by keeping track of the witnesses that can be solved.
// Returns false if the circuit cannot be solved
pub fn check_circuit<F: AcirField>(&mut self, circuit: &Circuit<F>) -> bool {
let circuit_inputs = circuit.circuit_arguments();
self.solvable_witness.extend(circuit_inputs.iter());
for op in &circuit.opcodes {
if !self.try_solve(op) {
return false;
}
}
true
}

/// Check if the Opcode can be solved, and if yes, add the solved witness to set of solvable witness
fn try_solve<F: AcirField>(&mut self, opcode: &Opcode<F>) -> bool {
let mut unresolved = HashSet::new();
match opcode {
Opcode::AssertZero(expr) => {
for (_, w1, w2) in &expr.mul_terms {
if !self.solvable_witness.contains(w1) {
if !self.solvable_witness.contains(w2) {
return false;
}
unresolved.insert(*w1);
}
if !self.solvable_witness.contains(w2) && w1 != w2 {
unresolved.insert(*w2);
}
}
for (_, w) in &expr.linear_combinations {
if !self.solvable_witness.contains(w) {
unresolved.insert(*w);
}
}
if unresolved.len() == 1 {
self.mark_solvable(*unresolved.iter().next().unwrap());
return true;
}
unresolved.is_empty()
}
Opcode::BlackBoxFuncCall(black_box_func_call) => {
let inputs = black_box_func_call.get_inputs_vec();
for input in inputs {
if !self.can_solve_function_input(&input) {
return false;
}
}
let outputs = black_box_func_call.get_outputs_vec();
for output in outputs {
self.mark_solvable(output);
}
true
}
Opcode::Directive(directive) => match directive {
Directive::ToLeRadix { a, b, .. } => {
if !self.can_solve_expression(a) {
return false;
}
for w in b {
self.mark_solvable(*w);
}
true
}
},
Opcode::MemoryOp { block_id, op, predicate } => {
if !self.can_solve_expression(&op.index) {
return false;
}
if let Some(predicate) = predicate {
if !self.can_solve_expression(predicate) {
return false;
}
}
if op.operation.is_zero() {
let w = op.value.to_witness().unwrap();
self.mark_solvable(w);
true
} else {
if let Some(BlockStatus::Used) = self.resolved_blocks.get(block_id) {
// Writing after having used the block should not be allowed
return false;
}
self.try_solve(&Opcode::AssertZero(op.value.clone()))
}
}
Opcode::MemoryInit { block_id, init, .. } => {
for w in init {
if !self.solvable_witness.contains(w) {
return false;
}
}
self.resolved_blocks.insert(*block_id, BlockStatus::Initialized);
true
}
Opcode::BrilligCall { id: _, inputs, outputs, predicate } => {
for input in inputs {
if !self.can_solve_brillig_input(input) {
return false;
}
}
if let Some(predicate) = predicate {
if !self.can_solve_expression(predicate) {
return false;
}
}
for output in outputs {
match output {
BrilligOutputs::Simple(w) => self.mark_solvable(*w),
BrilligOutputs::Array(arr) => {
for w in arr {
self.mark_solvable(*w);
}
}
}
}
true
}
Opcode::Call { id: _, inputs, outputs, predicate } => {
for w in inputs {
if !self.solvable_witness.contains(w) {
return false;
}
}
if let Some(predicate) = predicate {
if !self.can_solve_expression(predicate) {
return false;
}
}
for w in outputs {
self.mark_solvable(*w);
}
true
}
}
}

/// Adds the witness to set of solvable witness
pub(crate) fn mark_solvable(&mut self, witness: Witness) {
self.solvable_witness.insert(witness);
}

pub fn can_solve_function_input<F: AcirField>(&self, input: &FunctionInput<F>) -> bool {
if !input.is_constant() {
return self.solvable_witness.contains(&input.to_witness());
}
true
}
fn can_solve_expression<F>(&self, expr: &Expression<F>) -> bool {
for w in Self::expr_wit(expr) {
if !self.solvable_witness.contains(&w) {
return false;
}
}
true
}
fn can_solve_brillig_input<F>(&mut self, input: &BrilligInputs<F>) -> bool {
match input {
BrilligInputs::Single(expr) => self.can_solve_expression(expr),
BrilligInputs::Array(exprs) => {
for expr in exprs {
if !self.can_solve_expression(expr) {
return false;
}
}
true
}

BrilligInputs::MemoryArray(block_id) => match self.resolved_blocks.entry(*block_id) {
std::collections::hash_map::Entry::Vacant(_) => false,
std::collections::hash_map::Entry::Occupied(entry)
if *entry.get() == BlockStatus::Used =>
{
true
}
std::collections::hash_map::Entry::Occupied(mut entry) => {
entry.insert(BlockStatus::Used);
true
}
},
}
}

pub(crate) fn expr_wit<F>(expr: &Expression<F>) -> BTreeSet<Witness> {
let mut result = BTreeSet::new();
result.extend(expr.mul_terms.iter().flat_map(|i| vec![i.1, i.2]));
result.extend(expr.linear_combinations.iter().map(|i| i.1));
result
}
}
10 changes: 1 addition & 9 deletions compiler/noirc_evaluator/src/ssa/ir/instruction/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,15 +281,7 @@ impl Binary {
let zero = dfg.make_constant(FieldElement::zero(), operand_type);
return SimplifyResult::SimplifiedTo(zero);
}

// `two_pow_rhs` is limited to be at most `2 ^ {operand_bitsize - 1}` so it fits in `operand_type`.
let two_pow_rhs = FieldElement::from(2u128).pow(&rhs_const);
let two_pow_rhs = dfg.make_constant(two_pow_rhs, operand_type);
return SimplifyResult::SimplifiedToInstruction(Instruction::binary(
BinaryOp::Div,
self.lhs,
two_pow_rhs,
));
return SimplifyResult::None;
}
}
};
Expand Down
33 changes: 23 additions & 10 deletions compiler/noirc_evaluator/src/ssa/opt/remove_bit_shifts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ impl Context<'_> {
}

/// Insert ssa instructions which computes lhs >> rhs by doing lhs/2^rhs
/// For negative signed integers, we do the division on the 1-complement representation of lhs,
/// before converting back the result to the 2-complement representation.
pub(crate) fn insert_shift_right(
&mut self,
lhs: ValueId,
Expand All @@ -153,16 +155,27 @@ impl Context<'_> {
) -> ValueId {
let lhs_typ = self.function.dfg.type_of_value(lhs);
let base = self.field_constant(FieldElement::from(2_u128));
// we can safely cast to unsigned because overflow_checks prevent bit-shift with a negative value
let rhs_unsigned = self.insert_cast(rhs, Type::unsigned(bit_size));
let pow = self.pow(base, rhs_unsigned);
// We need at least one more bit for the case where rhs == bit_size
let div_type = Type::unsigned(bit_size + 1);
let casted_lhs = self.insert_cast(lhs, div_type.clone());
let casted_pow = self.insert_cast(pow, div_type);
let div_result = self.insert_binary(casted_lhs, BinaryOp::Div, casted_pow);
// We have to cast back to the original type
self.insert_cast(div_result, lhs_typ)
let pow = self.pow(base, rhs);
if lhs_typ.is_unsigned() {
// unsigned right bit shift is just a normal division
self.insert_binary(lhs, BinaryOp::Div, pow)
} else {
// Get the sign of the operand; positive signed operand will just do a division as well
let zero = self.numeric_constant(FieldElement::zero(), Type::signed(bit_size));
let lhs_sign = self.insert_binary(lhs, BinaryOp::Lt, zero);
let lhs_sign_as_field = self.insert_cast(lhs_sign, Type::field());
let lhs_as_field = self.insert_cast(lhs, Type::field());
// For negative numbers, convert to 1-complement using wrapping addition of a + 1
let one_complement = self.insert_binary(lhs_sign_as_field, BinaryOp::Add, lhs_as_field);
let one_complement = self.insert_truncate(one_complement, bit_size, bit_size + 1);
let one_complement = self.insert_cast(one_complement, Type::signed(bit_size));
// Performs the division on the 1-complement (or the operand if positive)
let shifted_complement = self.insert_binary(one_complement, BinaryOp::Div, pow);
// Convert back to 2-complement representation if operand is negative
let lhs_sign_as_int = self.insert_cast(lhs_sign, lhs_typ);
let shifted = self.insert_binary(shifted_complement, BinaryOp::Sub, lhs_sign_as_int);
self.insert_truncate(shifted, bit_size, bit_size + 1)
}
}

/// Computes lhs^rhs via square&multiply, using the bits decomposition of rhs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ fn main(x: u64) {
assert(x << 63 == 0);

assert_eq((1 as u64) << 32, 0x0100000000);

//regression for 6201
let a: i16 = -769;
assert_eq(a >> 3, -97);
}

fn regression_2250() {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
x = 64
y = 1
y = 1
z = "-769"
Loading
Loading