Skip to content

Commit

Permalink
chore: add bigint solver in ACVM and add a unit test for bigints in N…
Browse files Browse the repository at this point in the history
…oir (AztecProtocol#4415)

The PR enable bigints in ACVM, i.e you can now compile and execute Noir
programs using bigint opcodes.
This is demonstrated in the added unit test.

---------

Co-authored-by: kevaundray <kevtheappdev@gmail.com>
  • Loading branch information
guipublic and kevaundray authored Feb 5, 2024
1 parent f09e8fc commit e4a2fe9
Show file tree
Hide file tree
Showing 9 changed files with 213 additions and 18 deletions.
120 changes: 120 additions & 0 deletions noir/acvm-repo/acvm/src/pwg/blackbox/bigint.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
use std::collections::HashMap;

use acir::{
circuit::opcodes::FunctionInput,
native_types::{Witness, WitnessMap},
BlackBoxFunc, FieldElement,
};

use num_bigint::BigUint;

use crate::pwg::OpcodeResolutionError;

/// Resolve BigInt opcodes by storing BigInt values (and their moduli) by their ID in a HashMap:
/// - When it encounters a bigint operation opcode, it performs the operation on the stored values
/// and store the result using the provided ID.
/// - When it gets a to_bytes opcode, it simply looks up the value and resolves the output witness accordingly.
#[derive(Default)]
pub(crate) struct BigIntSolver {
bigint_id_to_value: HashMap<u32, BigUint>,
bigint_id_to_modulus: HashMap<u32, BigUint>,
}

impl BigIntSolver {
pub(crate) fn get_bigint(
&self,
id: u32,
func: BlackBoxFunc,
) -> Result<BigUint, OpcodeResolutionError> {
self.bigint_id_to_value
.get(&id)
.ok_or(OpcodeResolutionError::BlackBoxFunctionFailed(
func,
format!("could not find bigint of id {id}"),
))
.cloned()
}

pub(crate) fn get_modulus(
&self,
id: u32,
func: BlackBoxFunc,
) -> Result<BigUint, OpcodeResolutionError> {
self.bigint_id_to_modulus
.get(&id)
.ok_or(OpcodeResolutionError::BlackBoxFunctionFailed(
func,
format!("could not find bigint of id {id}"),
))
.cloned()
}
pub(crate) fn bigint_from_bytes(
&mut self,
inputs: &[FunctionInput],
modulus: &[u8],
output: u32,
initial_witness: &mut WitnessMap,
) -> Result<(), OpcodeResolutionError> {
let bytes = inputs
.iter()
.map(|input| initial_witness.get(&input.witness).unwrap().to_u128() as u8)
.collect::<Vec<u8>>();
let bigint = BigUint::from_bytes_le(&bytes);
self.bigint_id_to_value.insert(output, bigint);
let modulus = BigUint::from_bytes_le(modulus);
self.bigint_id_to_modulus.insert(output, modulus);
Ok(())
}

pub(crate) fn bigint_to_bytes(
&self,
input: u32,
outputs: &Vec<Witness>,
initial_witness: &mut WitnessMap,
) -> Result<(), OpcodeResolutionError> {
let bigint = self.get_bigint(input, BlackBoxFunc::BigIntToLeBytes)?;

let mut bytes = bigint.to_bytes_le();
while bytes.len() < outputs.len() {
bytes.push(0);
}
bytes.iter().zip(outputs.iter()).for_each(|(byte, output)| {
initial_witness.insert(*output, FieldElement::from(*byte as u128));
});
Ok(())
}

pub(crate) fn bigint_op(
&mut self,
lhs: u32,
rhs: u32,
output: u32,
func: BlackBoxFunc,
) -> Result<(), OpcodeResolutionError> {
let modulus = self.get_modulus(lhs, func)?;
let lhs = self.get_bigint(lhs, func)?;
let rhs = self.get_bigint(rhs, func)?;
let mut result = match func {
BlackBoxFunc::BigIntAdd => lhs + rhs,
BlackBoxFunc::BigIntNeg => {
if lhs >= rhs {
&lhs - &rhs
} else {
&lhs + &modulus - &rhs
}
}
BlackBoxFunc::BigIntMul => lhs * rhs,
BlackBoxFunc::BigIntDiv => {
lhs * rhs.modpow(&(&modulus - BigUint::from(1_u32)), &modulus)
} //TODO ensure that modulus is prime
_ => unreachable!("ICE - bigint_op must be called for an operation"),
};
if result > modulus {
let q = &result / &modulus;
result -= q * &modulus;
}
self.bigint_id_to_value.insert(output, result);
self.bigint_id_to_modulus.insert(output, modulus);
Ok(())
}
}
22 changes: 15 additions & 7 deletions noir/acvm-repo/acvm/src/pwg/blackbox/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ use acir::{
};
use acvm_blackbox_solver::{blake2s, blake3, keccak256, keccakf1600, sha256};

use self::pedersen::pedersen_hash;
use self::{bigint::BigIntSolver, pedersen::pedersen_hash};

use super::{insert_value, OpcodeNotSolvable, OpcodeResolutionError};
use crate::{pwg::witness_to_value, BlackBoxFunctionSolver};

pub(crate) mod bigint;
mod fixed_base_scalar_mul;
mod hash;
mod logic;
Expand Down Expand Up @@ -53,6 +54,7 @@ pub(crate) fn solve(
backend: &impl BlackBoxFunctionSolver,
initial_witness: &mut WitnessMap,
bb_func: &BlackBoxFuncCall,
bigint_solver: &mut BigIntSolver,
) -> Result<(), OpcodeResolutionError> {
let inputs = bb_func.get_inputs_vec();
if !contains_all_inputs(initial_witness, &inputs) {
Expand Down Expand Up @@ -190,12 +192,18 @@ pub(crate) fn solve(
}
// Recursive aggregation will be entirely handled by the backend and is not solved by the ACVM
BlackBoxFuncCall::RecursiveAggregation { .. } => Ok(()),
BlackBoxFuncCall::BigIntAdd { .. } => todo!(),
BlackBoxFuncCall::BigIntNeg { .. } => todo!(),
BlackBoxFuncCall::BigIntMul { .. } => todo!(),
BlackBoxFuncCall::BigIntDiv { .. } => todo!(),
BlackBoxFuncCall::BigIntFromLeBytes { .. } => todo!(),
BlackBoxFuncCall::BigIntToLeBytes { .. } => todo!(),
BlackBoxFuncCall::BigIntAdd { lhs, rhs, output }
| BlackBoxFuncCall::BigIntNeg { lhs, rhs, output }
| BlackBoxFuncCall::BigIntMul { lhs, rhs, output }
| BlackBoxFuncCall::BigIntDiv { lhs, rhs, output } => {
bigint_solver.bigint_op(*lhs, *rhs, *output, bb_func.get_black_box_func())
}
BlackBoxFuncCall::BigIntFromLeBytes { inputs, modulus, output } => {
bigint_solver.bigint_from_bytes(inputs, modulus, *output, initial_witness)
}
BlackBoxFuncCall::BigIntToLeBytes { input, outputs } => {
bigint_solver.bigint_to_bytes(*input, outputs, initial_witness)
}
BlackBoxFuncCall::Poseidon2Permutation { .. } => todo!(),
BlackBoxFuncCall::Sha256Compression { .. } => todo!(),
}
Expand Down
17 changes: 13 additions & 4 deletions noir/acvm-repo/acvm/src/pwg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ use acir::{
};
use acvm_blackbox_solver::BlackBoxResolutionError;

use self::{arithmetic::ExpressionSolver, directives::solve_directives, memory_op::MemoryOpSolver};
use self::{
arithmetic::ExpressionSolver, blackbox::bigint::BigIntSolver, directives::solve_directives,
memory_op::MemoryOpSolver,
};
use crate::BlackBoxFunctionSolver;

use thiserror::Error;
Expand Down Expand Up @@ -132,6 +135,8 @@ pub struct ACVM<'a, B: BlackBoxFunctionSolver> {
/// Stores the solver for memory operations acting on blocks of memory disambiguated by [block][`BlockId`].
block_solvers: HashMap<BlockId, MemoryOpSolver>,

bigint_solver: BigIntSolver,

/// A list of opcodes which are to be executed by the ACVM.
opcodes: &'a [Opcode],
/// Index of the next opcode to be executed.
Expand All @@ -149,6 +154,7 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {
status,
backend,
block_solvers: HashMap::default(),
bigint_solver: BigIntSolver::default(),
opcodes,
instruction_pointer: 0,
witness_map: initial_witness,
Expand Down Expand Up @@ -254,9 +260,12 @@ impl<'a, B: BlackBoxFunctionSolver> ACVM<'a, B> {

let resolution = match opcode {
Opcode::AssertZero(expr) => ExpressionSolver::solve(&mut self.witness_map, expr),
Opcode::BlackBoxFuncCall(bb_func) => {
blackbox::solve(self.backend, &mut self.witness_map, bb_func)
}
Opcode::BlackBoxFuncCall(bb_func) => blackbox::solve(
self.backend,
&mut self.witness_map,
bb_func,
&mut self.bigint_solver,
),
Opcode::Directive(directive) => solve_directives(&mut self.witness_map, directive),
Opcode::MemoryInit { block_id, init } => {
let solver = self.block_solvers.entry(*block_id).or_default();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1245,7 +1245,8 @@ impl AcirContext {
for i in const_inputs {
field_inputs.push(i?);
}
let modulus = self.big_int_ctx.modulus(field_inputs[0]);
let bigint = self.big_int_ctx.get(field_inputs[0]);
let modulus = self.big_int_ctx.modulus(bigint.modulus_id());
let bytes_len = ((modulus - BigUint::from(1_u32)).bits() - 1) / 8 + 1;
output_count = bytes_len as usize;
(field_inputs, vec![FieldElement::from(bytes_len as u128)])
Expand Down
38 changes: 33 additions & 5 deletions noir/noir_stdlib/src/bigint.nr
Original file line number Diff line number Diff line change
@@ -1,27 +1,55 @@
use crate::ops::{Add, Sub, Mul, Div, Rem,};


global bn254_fq = [0x47, 0xFD, 0x7C, 0xD8, 0x16, 0x8C, 0x20, 0x3C, 0x8d, 0xca, 0x71, 0x68, 0x91, 0x6a, 0x81, 0x97,
0x5d, 0x58, 0x81, 0x81, 0xb6, 0x45, 0x50, 0xb8, 0x29, 0xa0, 0x31, 0xe1, 0x72, 0x4e, 0x64, 0x30];
global bn254_fr = [0x01, 0x00, 0x00, 0x00, 0x3F, 0x59, 0x1F, 0x43, 0x09, 0x97, 0xB9, 0x79, 0x48, 0xE8, 0x33, 0x28,
0x5D, 0x58, 0x81, 0x81, 0xB6, 0x45, 0x50, 0xB8, 0x29, 0xA0, 0x31, 0xE1, 0x72, 0x4E, 0x64, 0x30];
global secpk1_fr = [0x41, 0x41, 0x36, 0xD0, 0x8C, 0x5E, 0xD2, 0xBF, 0x3B, 0xA0, 0x48, 0xAF, 0xE6, 0xDC, 0xAE, 0xBA,
0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF];
global secpk1_fq = [0x2F, 0xFC, 0xFF, 0xFF, 0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF];
global secpr1_fq = [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0xFF, 0xFF, 0xFF, 0xFF];
global secpr1_fr = [0x51, 0x25, 0x63, 0xFC, 0xC2, 0xCA, 0xB9, 0xF3, 0x84, 0x9E, 0x17, 0xA7, 0xAD, 0xFA, 0xE6, 0xBC,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, 0x00,0xFF, 0xFF, 0xFF, 0xFF];


struct BigInt {
pointer: u32,
modulus: u32,
}

impl BigInt {
#[builtin(bigint_add)]
pub fn bigint_add(self, other: BigInt) -> BigInt {
fn bigint_add(self, other: BigInt) -> BigInt {
}
#[builtin(bigint_neg)]
pub fn bigint_neg(self, other: BigInt) -> BigInt {
fn bigint_neg(self, other: BigInt) -> BigInt {
}
#[builtin(bigint_mul)]
pub fn bigint_mul(self, other: BigInt) -> BigInt {
fn bigint_mul(self, other: BigInt) -> BigInt {
}
#[builtin(bigint_div)]
pub fn bigint_div(self, other: BigInt) -> BigInt {
fn bigint_div(self, other: BigInt) -> BigInt {
}
#[builtin(bigint_from_le_bytes)]
pub fn from_le_bytes(bytes: [u8], modulus: [u8]) -> BigInt {}
fn from_le_bytes(bytes: [u8], modulus: [u8]) -> BigInt {}
#[builtin(bigint_to_le_bytes)]
pub fn to_le_bytes(self) -> [u8] {}

pub fn bn254_fr_from_le_bytes(bytes: [u8]) -> BigInt {
BigInt::from_le_bytes(bytes, bn254_fr)
}
pub fn bn254_fq_from_le_bytes(bytes: [u8]) -> BigInt {
BigInt::from_le_bytes(bytes, bn254_fq)
}
pub fn secpk1_fq_from_le_bytes(bytes: [u8]) -> BigInt {
BigInt::from_le_bytes(bytes, secpk1_fq)
}
pub fn secpk1_fr_from_le_bytes(bytes: [u8]) -> BigInt {
BigInt::from_le_bytes(bytes, secpk1_fr)
}
}

impl Add for BigInt {
Expand Down
2 changes: 1 addition & 1 deletion noir/noir_stdlib/src/lib.nr
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ mod ops;
mod default;
mod prelude;
mod uint128;
// mod bigint;
mod bigint;

// Oracle calls are required to be wrapped in an unconstrained function
// Thus, the only argument to the `println` oracle is expected to always be an ident
Expand Down
6 changes: 6 additions & 0 deletions noir/test_programs/execution_success/bigint/Nargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[package]
name = "bigint"
type = "bin"
authors = [""]

[dependencies]
2 changes: 2 additions & 0 deletions noir/test_programs/execution_success/bigint/Prover.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
x = [34,3,5,8,4]
y = [44,7,1,8,8]
21 changes: 21 additions & 0 deletions noir/test_programs/execution_success/bigint/src/main.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
use dep::std::bigint;

fn main(mut x: [u8;5], y: [u8;5]) {
let a = bigint::BigInt::secpk1_fq_from_le_bytes([x[0],x[1],x[2],x[3],x[4]]);
let b = bigint::BigInt::secpk1_fq_from_le_bytes([y[0],y[1],y[2],y[3],y[4]]);

let a_bytes = a.to_le_bytes();
let b_bytes = b.to_le_bytes();
for i in 0..5 {
assert(a_bytes[i] == x[i]);
assert(b_bytes[i] == y[i]);
}

let d = a*b - b;
let d_bytes = d.to_le_bytes();
let d1 = bigint::BigInt::secpk1_fq_from_le_bytes(597243850900842442924.to_le_bytes(10));
let d1_bytes = d1.to_le_bytes();
for i in 0..32 {
assert(d_bytes[i] == d1_bytes[i]);
}
}

0 comments on commit e4a2fe9

Please sign in to comment.