Skip to content

Commit

Permalink
fix: issue 4682 and add solver for unconstrained bigintegers (#4729)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Managing bigintegers as ID makes it difficult to handle them in
conditionals and also when passing them around unconstrained functions.
As a matter of fact these two use cases do not work, the first one being
documented in issue #4682, while the second one was never implemented.

## Summary\*

In order to solve this two issues, I modified the bigint representation
so that it stores its byte array instead of an id, allowing it to work
directly with conditional and unconstrained functions.


## Additional Context

For simplicity, I also made the byte arrays of hardcoded length 32.
Since the moduli we support are of this length and since custom
bigintegers will be implemented differently, I do not expect this length
to change.
The barrentenberg interface has not changed and we still use id for it.
The Id is constructed when required by doing a call to from_bytes.

## Documentation\*
The traits have not changed, I have simply added a from_byte_32 method
which takes a 32-bytes array as input.

Check one:
- [] No documentation needed.
- [X] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [X] I have tested the changes locally.
- [X] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.

---------

Co-authored-by: Tom French <tom@tomfren.ch>
  • Loading branch information
guipublic and TomAFrench authored Apr 17, 2024
1 parent 23e1f3b commit e4d33c1
Show file tree
Hide file tree
Showing 16 changed files with 484 additions and 158 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

74 changes: 8 additions & 66 deletions acvm-repo/acvm/src/pwg/blackbox/bigint.rs
Original file line number Diff line number Diff line change
@@ -1,53 +1,23 @@
use std::collections::HashMap;

use acir::{
circuit::opcodes::FunctionInput,
native_types::{Witness, WitnessMap},
BlackBoxFunc, FieldElement,
};

use num_bigint::BigUint;
use acvm_blackbox_solver::BigIntSolver;

use crate::pwg::OpcodeResolutionError;

/// Resolve BigInt opcodes by storing BigInt values (and their moduli) by their ID in a HashMap:
/// Resolve BigInt opcodes by storing BigInt values (and their moduli) by their ID in the BigIntSolver
/// - When it encounters a bigint operation opcode, it performs the operation on the stored values
/// and store the result using the provided ID.
/// - When it gets a to_bytes opcode, it simply looks up the value and resolves the output witness accordingly.
#[derive(Default)]
pub(crate) struct BigIntSolver {
bigint_id_to_value: HashMap<u32, BigUint>,
bigint_id_to_modulus: HashMap<u32, BigUint>,
pub(crate) struct AcvmBigIntSolver {
bigint_solver: BigIntSolver,
}

impl BigIntSolver {
pub(crate) fn get_bigint(
&self,
id: u32,
func: BlackBoxFunc,
) -> Result<BigUint, OpcodeResolutionError> {
self.bigint_id_to_value
.get(&id)
.ok_or(OpcodeResolutionError::BlackBoxFunctionFailed(
func,
format!("could not find bigint of id {id}"),
))
.cloned()
}

pub(crate) fn get_modulus(
&self,
id: u32,
func: BlackBoxFunc,
) -> Result<BigUint, OpcodeResolutionError> {
self.bigint_id_to_modulus
.get(&id)
.ok_or(OpcodeResolutionError::BlackBoxFunctionFailed(
func,
format!("could not find bigint of id {id}"),
))
.cloned()
}
impl AcvmBigIntSolver {
pub(crate) fn bigint_from_bytes(
&mut self,
inputs: &[FunctionInput],
Expand All @@ -59,10 +29,7 @@ impl BigIntSolver {
.iter()
.map(|input| initial_witness.get(&input.witness).unwrap().to_u128() as u8)
.collect::<Vec<u8>>();
let bigint = BigUint::from_bytes_le(&bytes);
self.bigint_id_to_value.insert(output, bigint);
let modulus = BigUint::from_bytes_le(modulus);
self.bigint_id_to_modulus.insert(output, modulus);
self.bigint_solver.bigint_from_bytes(&bytes, modulus, output)?;
Ok(())
}

Expand All @@ -72,9 +39,7 @@ impl BigIntSolver {
outputs: &[Witness],
initial_witness: &mut WitnessMap,
) -> Result<(), OpcodeResolutionError> {
let bigint = self.get_bigint(input, BlackBoxFunc::BigIntToLeBytes)?;

let mut bytes = bigint.to_bytes_le();
let mut bytes = self.bigint_solver.bigint_to_bytes(input)?;
while bytes.len() < outputs.len() {
bytes.push(0);
}
Expand All @@ -91,30 +56,7 @@ impl BigIntSolver {
output: u32,
func: BlackBoxFunc,
) -> Result<(), OpcodeResolutionError> {
let modulus = self.get_modulus(lhs, func)?;
let lhs = self.get_bigint(lhs, func)?;
let rhs = self.get_bigint(rhs, func)?;
let mut result = match func {
BlackBoxFunc::BigIntAdd => lhs + rhs,
BlackBoxFunc::BigIntSub => {
if lhs >= rhs {
&lhs - &rhs
} else {
&lhs + &modulus - &rhs
}
}
BlackBoxFunc::BigIntMul => lhs * rhs,
BlackBoxFunc::BigIntDiv => {
lhs * rhs.modpow(&(&modulus - BigUint::from(2_u32)), &modulus)
} //TODO ensure that modulus is prime
_ => unreachable!("ICE - bigint_op must be called for an operation"),
};
if result > modulus {
let q = &result / &modulus;
result -= q * &modulus;
}
self.bigint_id_to_value.insert(output, result);
self.bigint_id_to_modulus.insert(output, modulus);
self.bigint_solver.bigint_op(lhs, rhs, output, func)?;
Ok(())
}
}
4 changes: 2 additions & 2 deletions acvm-repo/acvm/src/pwg/blackbox/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use acir::{
use acvm_blackbox_solver::{blake2s, blake3, keccak256, keccakf1600, sha256};

use self::{
bigint::BigIntSolver, hash::solve_poseidon2_permutation_opcode, pedersen::pedersen_hash,
bigint::AcvmBigIntSolver, hash::solve_poseidon2_permutation_opcode, pedersen::pedersen_hash,
};

use super::{insert_value, OpcodeNotSolvable, OpcodeResolutionError};
Expand Down Expand Up @@ -56,7 +56,7 @@ pub(crate) fn solve(
backend: &impl BlackBoxFunctionSolver,
initial_witness: &mut WitnessMap,
bb_func: &BlackBoxFuncCall,
bigint_solver: &mut BigIntSolver,
bigint_solver: &mut AcvmBigIntSolver,
) -> Result<(), OpcodeResolutionError> {
let inputs = bb_func.get_inputs_vec();
if !contains_all_inputs(initial_witness, &inputs) {
Expand Down
6 changes: 3 additions & 3 deletions acvm-repo/acvm/src/pwg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use acir::{
use acvm_blackbox_solver::BlackBoxResolutionError;

use self::{
arithmetic::ExpressionSolver, blackbox::bigint::BigIntSolver, directives::solve_directives,
arithmetic::ExpressionSolver, blackbox::bigint::AcvmBigIntSolver, directives::solve_directives,
memory_op::MemoryOpSolver,
};
use crate::BlackBoxFunctionSolver;
Expand Down Expand Up @@ -148,7 +148,7 @@ pub struct ACVM<'a, B: BlackBoxFunctionSolver> {
/// Stores the solver for memory operations acting on blocks of memory disambiguated by [block][`BlockId`].
block_solvers: HashMap<BlockId, MemoryOpSolver>,

bigint_solver: BigIntSolver,
bigint_solver: AcvmBigIntSolver,

/// A list of opcodes which are to be executed by the ACVM.
opcodes: &'a [Opcode],
Expand All @@ -174,7 +174,7 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
status,
backend,
block_solvers: HashMap::default(),
bigint_solver: BigIntSolver::default(),
bigint_solver: AcvmBigIntSolver::default(),
opcodes,
instruction_pointer: 0,
witness_map: initial_witness,
Expand Down
1 change: 1 addition & 0 deletions acvm-repo/blackbox_solver/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ repository.workspace = true
[dependencies]
acir.workspace = true
thiserror.workspace = true
num-bigint = "0.4"

blake2 = "0.10.6"
blake3 = "1.5.0"
Expand Down
99 changes: 99 additions & 0 deletions acvm-repo/blackbox_solver/src/bigint.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
use std::collections::HashMap;

use acir::BlackBoxFunc;

use num_bigint::BigUint;

use crate::BlackBoxResolutionError;

/// Resolve BigInt opcodes by storing BigInt values (and their moduli) by their ID in a HashMap:
/// - When it encounters a bigint operation opcode, it performs the operation on the stored values
/// and store the result using the provided ID.
/// - When it gets a to_bytes opcode, it simply looks up the value and resolves the output witness accordingly.
#[derive(Default, Debug, Clone, PartialEq, Eq)]

pub struct BigIntSolver {
bigint_id_to_value: HashMap<u32, BigUint>,
bigint_id_to_modulus: HashMap<u32, BigUint>,
}

impl BigIntSolver {
pub(crate) fn get_bigint(
&self,
id: u32,
func: BlackBoxFunc,
) -> Result<BigUint, BlackBoxResolutionError> {
self.bigint_id_to_value
.get(&id)
.ok_or(BlackBoxResolutionError::Failed(
func,
format!("could not find bigint of id {id}"),
))
.cloned()
}

pub(crate) fn get_modulus(
&self,
id: u32,
func: BlackBoxFunc,
) -> Result<BigUint, BlackBoxResolutionError> {
self.bigint_id_to_modulus
.get(&id)
.ok_or(BlackBoxResolutionError::Failed(
func,
format!("could not find bigint of id {id}"),
))
.cloned()
}
pub fn bigint_from_bytes(
&mut self,
inputs: &[u8],
modulus: &[u8],
output: u32,
) -> Result<(), BlackBoxResolutionError> {
let bigint = BigUint::from_bytes_le(inputs);
self.bigint_id_to_value.insert(output, bigint);
let modulus = BigUint::from_bytes_le(modulus);
self.bigint_id_to_modulus.insert(output, modulus);
Ok(())
}

pub fn bigint_to_bytes(&self, input: u32) -> Result<Vec<u8>, BlackBoxResolutionError> {
let bigint = self.get_bigint(input, BlackBoxFunc::BigIntToLeBytes)?;
Ok(bigint.to_bytes_le())
}

pub fn bigint_op(
&mut self,
lhs: u32,
rhs: u32,
output: u32,
func: BlackBoxFunc,
) -> Result<(), BlackBoxResolutionError> {
let modulus = self.get_modulus(lhs, func)?;
let lhs = self.get_bigint(lhs, func)?;
let rhs = self.get_bigint(rhs, func)?;
let mut result = match func {
BlackBoxFunc::BigIntAdd => lhs + rhs,
BlackBoxFunc::BigIntSub => {
if lhs >= rhs {
&lhs - &rhs
} else {
&lhs + &modulus - &rhs
}
}
BlackBoxFunc::BigIntMul => lhs * rhs,
BlackBoxFunc::BigIntDiv => {
lhs * rhs.modpow(&(&modulus - BigUint::from(2_u32)), &modulus)
} //TODO ensure that modulus is prime
_ => unreachable!("ICE - bigint_op must be called for an operation"),
};
if result > modulus {
let q = &result / &modulus;
result -= q * &modulus;
}
self.bigint_id_to_value.insert(output, result);
self.bigint_id_to_modulus.insert(output, modulus);
Ok(())
}
}
2 changes: 2 additions & 0 deletions acvm-repo/blackbox_solver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
use acir::BlackBoxFunc;
use thiserror::Error;

mod bigint;
mod curve_specific_solver;
mod ecdsa;
mod hash;

pub use bigint::BigIntSolver;
pub use curve_specific_solver::{BlackBoxFunctionSolver, StubbedBlackBoxSolver};
pub use ecdsa::{ecdsa_secp256k1_verify, ecdsa_secp256r1_verify};
pub use hash::{blake2s, blake3, keccak256, keccakf1600, sha256, sha256compression};
Expand Down
64 changes: 56 additions & 8 deletions acvm-repo/brillig_vm/src/black_box.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use acir::brillig::{BlackBoxOp, HeapArray, HeapVector};
use acir::{BlackBoxFunc, FieldElement};
use acvm_blackbox_solver::BigIntSolver;
use acvm_blackbox_solver::{
blake2s, blake3, ecdsa_secp256k1_verify, ecdsa_secp256r1_verify, keccak256, keccakf1600,
sha256, sha256compression, BlackBoxFunctionSolver, BlackBoxResolutionError,
Expand Down Expand Up @@ -34,6 +35,7 @@ pub(crate) fn evaluate_black_box<Solver: BlackBoxFunctionSolver>(
op: &BlackBoxOp,
solver: &Solver,
memory: &mut Memory,
bigint_solver: &mut BigIntSolver,
) -> Result<(), BlackBoxResolutionError> {
match op {
BlackBoxOp::Sha256 { message, output } => {
Expand Down Expand Up @@ -177,12 +179,57 @@ pub(crate) fn evaluate_black_box<Solver: BlackBoxFunctionSolver>(
memory.write(*output, hash.into());
Ok(())
}
BlackBoxOp::BigIntAdd { .. } => todo!(),
BlackBoxOp::BigIntSub { .. } => todo!(),
BlackBoxOp::BigIntMul { .. } => todo!(),
BlackBoxOp::BigIntDiv { .. } => todo!(),
BlackBoxOp::BigIntFromLeBytes { .. } => todo!(),
BlackBoxOp::BigIntToLeBytes { .. } => todo!(),
BlackBoxOp::BigIntAdd { lhs, rhs, output } => {
let lhs = memory.read(*lhs).try_into().unwrap();
let rhs = memory.read(*rhs).try_into().unwrap();
let output = memory.read(*output).try_into().unwrap();
bigint_solver.bigint_op(lhs, rhs, output, BlackBoxFunc::BigIntAdd)?;
Ok(())
}
BlackBoxOp::BigIntSub { lhs, rhs, output } => {
let lhs = memory.read(*lhs).try_into().unwrap();
let rhs = memory.read(*rhs).try_into().unwrap();
let output = memory.read(*output).try_into().unwrap();
bigint_solver.bigint_op(lhs, rhs, output, BlackBoxFunc::BigIntSub)?;
Ok(())
}
BlackBoxOp::BigIntMul { lhs, rhs, output } => {
let lhs = memory.read(*lhs).try_into().unwrap();
let rhs = memory.read(*rhs).try_into().unwrap();
let output = memory.read(*output).try_into().unwrap();
bigint_solver.bigint_op(lhs, rhs, output, BlackBoxFunc::BigIntMul)?;
Ok(())
}
BlackBoxOp::BigIntDiv { lhs, rhs, output } => {
let lhs = memory.read(*lhs).try_into().unwrap();
let rhs = memory.read(*rhs).try_into().unwrap();
let output = memory.read(*output).try_into().unwrap();
bigint_solver.bigint_op(lhs, rhs, output, BlackBoxFunc::BigIntDiv)?;
Ok(())
}
BlackBoxOp::BigIntFromLeBytes { inputs, modulus, output } => {
let input = read_heap_vector(memory, inputs);
let input: Vec<u8> = input.iter().map(|x| x.try_into().unwrap()).collect();
let modulus = read_heap_vector(memory, modulus);
let modulus: Vec<u8> = modulus.iter().map(|x| x.try_into().unwrap()).collect();
let output = memory.read(*output).try_into().unwrap();
bigint_solver.bigint_from_bytes(&input, &modulus, output)?;
Ok(())
}
BlackBoxOp::BigIntToLeBytes { input, output } => {
let input: u32 = memory.read(*input).try_into().unwrap();
let bytes = bigint_solver.bigint_to_bytes(input)?;
let mut values = Vec::new();
for i in 0..32 {
if i < bytes.len() {
values.push(bytes[i].into());
} else {
values.push(0_u8.into());
}
}
memory.write_slice(memory.read_ref(output.pointer), &values);
Ok(())
}
BlackBoxOp::Poseidon2Permutation { message, output, len } => {
let input = read_heap_vector(memory, message);
let input: Vec<FieldElement> = input.iter().map(|x| x.try_into().unwrap()).collect();
Expand Down Expand Up @@ -256,7 +303,7 @@ fn black_box_function_from_op(op: &BlackBoxOp) -> BlackBoxFunc {
#[cfg(test)]
mod test {
use acir::brillig::{BlackBoxOp, MemoryAddress};
use acvm_blackbox_solver::StubbedBlackBoxSolver;
use acvm_blackbox_solver::{BigIntSolver, StubbedBlackBoxSolver};

use crate::{
black_box::{evaluate_black_box, to_u8_vec, to_value_vec},
Expand All @@ -281,7 +328,8 @@ mod test {
output: HeapArray { pointer: 2.into(), size: 32 },
};

evaluate_black_box(&op, &StubbedBlackBoxSolver, &mut memory).unwrap();
evaluate_black_box(&op, &StubbedBlackBoxSolver, &mut memory, &mut BigIntSolver::default())
.unwrap();

let result = memory.read_slice(MemoryAddress(result_pointer), 32);

Expand Down
Loading

0 comments on commit e4d33c1

Please sign in to comment.