Skip to content

Commit

Permalink
support bn254 circuit format
Browse files Browse the repository at this point in the history
  • Loading branch information
sixbigsquare committed Jul 8, 2024
1 parent af04269 commit 9bfee3b
Show file tree
Hide file tree
Showing 9 changed files with 122 additions and 45 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Cargo.lock
/data/Extracted*
/data/circuit8.txt
/data/compiler_out
/data/keccak*
notes.md

# Programming env
Expand Down
3 changes: 3 additions & 0 deletions arith/src/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,4 +136,7 @@ pub trait FieldSerde {

/// deserialize bytes into field
fn deserialize_from(buffer: &[u8]) -> Self;

/// deserialize bytes into field following ecc format
fn deserialize_from_ecc_format(bytes: &[u8; 32]) -> Self;
}
4 changes: 4 additions & 0 deletions arith/src/field/bn254.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,4 +117,8 @@ impl FieldSerde for Fr {
fn deserialize_from(buffer: &[u8]) -> Self {
Fr::from_bytes(buffer[..Fr::SIZE].try_into().unwrap()).unwrap()
}

fn deserialize_from_ecc_format(bytes: &[u8; 32]) -> Self {
Fr::deserialize_from(bytes) // same as deserialize_from
}
}
6 changes: 6 additions & 0 deletions arith/src/field/bn254/vectorized_bn254.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ impl FieldSerde for VectorizedFr {
let v = Fr::deserialize_from(buffer);
Self { v: [v] }
}

#[inline(always)]
fn deserialize_from_ecc_format(bytes: &[u8; 32]) -> Self {
let v = Fr::deserialize_from_ecc_format(bytes);
Self { v: [v] }
}
}

impl Field for VectorizedFr {
Expand Down
9 changes: 9 additions & 0 deletions arith/src/field/m31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,15 @@ impl FieldSerde for M31 {
v = mod_reduce_i32(v);
M31 { v: v as u32 }
}

#[inline(always)]
fn deserialize_from_ecc_format(bytes: &[u8; 32]) -> Self {
let val: [u8; 32] = bytes[0..32].try_into().unwrap();
for (i, v) in val.iter().enumerate().skip(4).take(28) {
assert_eq!(*v, 0, "non-zero byte found in witness at {}'th byte", i);
}
Self::from(u32::from_le_bytes(val[..4].try_into().unwrap()))
}
}

impl Field for M31 {
Expand Down
9 changes: 9 additions & 0 deletions arith/src/field/m31/vectorized_m31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ impl FieldSerde for VectorizedM31 {
}
}
}

#[inline(always)]
fn deserialize_from_ecc_format(bytes: &[u8; 32]) -> Self {
let val: [u8; 32] = bytes[0..32].try_into().unwrap();
for (i, v) in val.iter().enumerate().skip(4).take(28) {
assert_eq!(*v, 0, "non-zero byte found in witness at {}'th byte", i);
}
Self::from(u32::from_le_bytes(val[..4].try_into().unwrap()))
}
}

impl Field for VectorizedM31 {
Expand Down
30 changes: 14 additions & 16 deletions src/circuit.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use arith::{Field, MultiLinearPoly};
use arith::{Field, FieldSerde, MultiLinearPoly};
use ark_std::test_rng;
use std::{cmp::max, collections::HashMap, fs};

Expand Down Expand Up @@ -200,16 +200,7 @@ pub struct Segment<F: Field> {
pub gate_consts: Vec<GateConst<F>>,
}

fn read_f_u32_val(file_bytes: &[u8]) -> u32 {
// hard-coded to read 32 bytes for now
let val: [u8; 32] = file_bytes[0..32].try_into().unwrap();
for (i, v) in val.iter().enumerate().skip(4).take(28) {
assert_eq!(*v, 0, "non-zero byte found in witness at {}'th byte", i);
}
u32::from_le_bytes(val[..4].try_into().unwrap())
}

impl<F: Field> Circuit<F> {
impl<F: Field + FieldSerde> Circuit<F> {
pub fn load_witness_file(&mut self, filename: &str) {
// note that, for data parallel, one should load multiple witnesses into different slot in the vectorized F
let file_bytes = fs::read(filename).unwrap();
Expand All @@ -221,9 +212,10 @@ impl<F: Field> Circuit<F> {
let mut cur = 0;
self.layers[0].input_vals.evals = (0..(1 << self.log_input_size()))
.map(|_| {
let u32_val = read_f_u32_val(&file_bytes[cur..cur + 32]);
let ret =
F::deserialize_from_ecc_format(file_bytes[cur..cur + 32].try_into().unwrap());
cur += 32;
F::from(u32_val)
ret
})
.collect();
}
Expand Down Expand Up @@ -282,7 +274,9 @@ impl<F: Field> Segment<F> {
],
o_id: u64::from_le_bytes(file_bytes[*cur + 16..*cur + 24].try_into().unwrap())
as usize,
coef: F::BaseField::from(read_f_u32_val(&file_bytes[*cur + 24..*cur + 56])),
coef: F::BaseField::deserialize_from_ecc_format(
&file_bytes[*cur + 24..*cur + 56].try_into().unwrap(),
),
};
*cur += 56;
ret.gate_muls.push(gate);
Expand All @@ -297,7 +291,9 @@ impl<F: Field> Segment<F> {
],
o_id: u64::from_le_bytes(file_bytes[*cur + 8..*cur + 16].try_into().unwrap())
as usize,
coef: F::BaseField::from(read_f_u32_val(&file_bytes[*cur + 16..*cur + 48])),
coef: F::BaseField::deserialize_from_ecc_format(
&file_bytes[*cur + 16..*cur + 48].try_into().unwrap(),
),
};
*cur += 48;
ret.gate_adds.push(gate);
Expand All @@ -317,7 +313,9 @@ impl<F: Field> Segment<F> {
let gate = GateConst {
i_ids: [],
o_id: u64::from_le_bytes(file_bytes[*cur..*cur + 8].try_into().unwrap()) as usize,
coef: F::BaseField::from(read_f_u32_val(&file_bytes[*cur + 8..*cur + 40])),
coef: F::BaseField::deserialize_from_ecc_format(
&file_bytes[*cur + 8..*cur + 40].try_into().unwrap(),
),
};
*cur += 40;
ret.gate_consts.push(gate);
Expand Down
10 changes: 10 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,16 @@ pub enum FieldType {
BN254,
}

pub const SENTINEL_M31: [u8; 32] = [
255, 255, 255, 127, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0,
];

pub const SENTINEL_BN254: [u8; 32] = [
1, 0, 0, 240, 147, 245, 225, 67, 145, 112, 185, 121, 72, 232, 51, 40, 93, 88, 129, 129, 182,
69, 80, 184, 41, 160, 49, 225, 114, 78, 100, 48,
];

#[derive(Debug, Clone, PartialEq)]
pub enum FiatShamirHashType {
SHA256,
Expand Down
95 changes: 66 additions & 29 deletions src/exec.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
use std::{
fs,
process::exit,
sync::{Arc, Mutex},
vec,
};

use arith::{Field, FieldSerde, VectorizedM31};
use expander_rs::{Circuit, Config, Proof, Prover, Verifier};
use arith::{Field, FieldSerde, VectorizedField, VectorizedFr, VectorizedM31};
use expander_rs::{
Circuit, Config, FieldType, Proof, Prover, Verifier, SENTINEL_BN254, SENTINEL_M31,
};
use log::debug;
use warp::Filter;

type F = VectorizedM31;

fn dump_proof_and_claimed_v(proof: &Proof, claimed_v: &[F]) -> Vec<u8> {
fn dump_proof_and_claimed_v<F: Field + FieldSerde>(proof: &Proof, claimed_v: &[F]) -> Vec<u8> {
let mut bytes = Vec::new();
let proof_len = proof.bytes.len();
let claimed_v_len = claimed_v.len();
Expand All @@ -25,7 +27,7 @@ fn dump_proof_and_claimed_v(proof: &Proof, claimed_v: &[F]) -> Vec<u8> {
bytes
}

fn load_proof_and_claimed_v(bytes: &[u8]) -> (Proof, Vec<F>) {
fn load_proof_and_claimed_v<F: Field + FieldSerde>(bytes: &[u8]) -> (Proof, Vec<F>) {
let mut offset = 0;
let proof_len = u64::from_le_bytes(bytes[offset..offset + 8].try_into().unwrap()) as usize;
offset += 8;
Expand All @@ -35,7 +37,7 @@ fn load_proof_and_claimed_v(bytes: &[u8]) -> (Proof, Vec<F>) {
offset += 8;
let mut claimed_v = Vec::new();
for _ in 0..claimed_v_len {
let mut buffer = [0u8; F::SIZE];
let mut buffer = vec![0u8; F::SIZE];
buffer.copy_from_slice(&bytes[offset..offset + F::SIZE]);
offset += F::SIZE;
claimed_v.push(F::deserialize_from(&buffer));
Expand All @@ -45,30 +47,35 @@ fn load_proof_and_claimed_v(bytes: &[u8]) -> (Proof, Vec<F>) {
(proof, claimed_v)
}

#[tokio::main]
async fn main() {
// examples:
// expander-exec prove <input:circuit_file> <input:witness_file> <output:proof>
// expander-exec verify <input:circuit_file> <input:witness_file> <input:proof>
// expander-exec serve <input:circuit_file> <input:ip> <input:port>
let args = std::env::args().collect::<Vec<String>>();
if args.len() < 4 {
println!(
"Usage: expander-exec prove <input:circuit_file> <input:witness_file> <output:proof>"
);
println!(
"Usage: expander-exec verify <input:circuit_file> <input:witness_file> <input:proof>"
);
println!("Usage: expander-exec serve <input:circuit_file> <input:host> <input:port>");
return;
fn detect_field_type_from_circuit_file(circuit_file: &str) -> FieldType {
// read last 32 byte of sentinal field element to determine field type
let bytes = fs::read(circuit_file).expect("Unable to read circuit file.");
let field_bytes = &bytes[bytes.len() - 32..bytes.len()];
match field_bytes.try_into().unwrap() {
SENTINEL_M31 => FieldType::M31,
SENTINEL_BN254 => FieldType::BN254,
_ => {
println!("Unknown field type.");
exit(1);
}
}
let command = &args[1];
let circuit_file = &args[2];
match command.as_str() {
}

async fn run_command<F>(field_type: FieldType, command: &str, circuit_file: &str, args: &[String])
where
F: VectorizedField + FieldSerde + Send + 'static,
F::BaseField: Send,
F::PackedBaseField: Field<BaseField = F::BaseField>,
{
let config = match field_type {
FieldType::M31 => Config::m31_config(),
FieldType::BN254 => Config::bn254_config(),
_ => todo!("baby bear"),
};
match command {
"prove" => {
let witness_file = &args[3];
let output_file = &args[4];
let config = Config::m31_config();
let mut circuit = Circuit::<F>::load_circuit(circuit_file);
circuit.load_witness_file(witness_file);
circuit.evaluate();
Expand All @@ -81,7 +88,6 @@ async fn main() {
"verify" => {
let witness_file = &args[3];
let output_file = &args[4];
let config = Config::m31_config();
let mut circuit = Circuit::<F>::load_circuit(circuit_file);
circuit.load_witness_file(witness_file);
let bytes = fs::read(output_file).expect("Unable to read proof from file.");
Expand All @@ -98,7 +104,6 @@ async fn main() {
.try_into()
.unwrap();
let port = args[4].parse().unwrap();
let config = Config::m31_config();
let circuit = Circuit::<F>::load_circuit(circuit_file);
let mut prover = Prover::new(&config);
prover.prepare_mem(&circuit);
Expand Down Expand Up @@ -155,3 +160,35 @@ async fn main() {
}
}
}

#[tokio::main]
async fn main() {
// examples:
// expander-exec prove <input:circuit_file> <input:witness_file> <output:proof>
// expander-exec verify <input:circuit_file> <input:witness_file> <input:proof>
// expander-exec serve <input:circuit_file> <input:ip> <input:port>
let args = std::env::args().collect::<Vec<String>>();
if args.len() < 4 {
println!(
"Usage: expander-exec prove <input:circuit_file> <input:witness_file> <output:proof>"
);
println!(
"Usage: expander-exec verify <input:circuit_file> <input:witness_file> <input:proof>"
);
println!("Usage: expander-exec serve <input:circuit_file> <input:host> <input:port>");
return;
}
let command = &args[1];
let circuit_file = &args[2];
let field_type = detect_field_type_from_circuit_file(circuit_file);
debug!("field type: {:?}", field_type);
match field_type {
FieldType::M31 => {
run_command::<VectorizedM31>(field_type, command, circuit_file, &args).await;
}
FieldType::BN254 => {
run_command::<VectorizedFr>(field_type, command, circuit_file, &args).await;
}
_ => unreachable!(),
}
}

0 comments on commit 9bfee3b

Please sign in to comment.