Skip to content

Commit

Permalink
main simplified
Browse files Browse the repository at this point in the history
  • Loading branch information
sixbigsquare committed Jul 8, 2024
1 parent 381fa80 commit af04269
Showing 1 changed file with 13 additions and 73 deletions.
86 changes: 13 additions & 73 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{
thread,
};

use arith::{Field, VectorizedField, VectorizedFr, VectorizedM31, M31};
use arith::{Field, FieldSerde, VectorizedField, VectorizedFr, VectorizedM31, M31};
use clap::Parser;
use expander_rs::{Circuit, Config, Prover};
use halo2curves::bn256::Fr;
Expand Down Expand Up @@ -33,15 +33,20 @@ fn main() {
let args = Args::parse();
print_info(&args);
match args.field {
31 => run_keccak_bench_m31(&args),
254 => run_keccak_bench_fr(&args),
31 => run_keccak_bench::<M31, VectorizedM31>(&args, Config::m31_config()),
254 => run_keccak_bench::<Fr, VectorizedFr>(&args, Config::bn254_config()),
_ => unreachable!(),
}
}

fn run_keccak_bench_m31(args: &Args) {
let config = Config::m31_config();
println!("benchmarking keccak over {}", M31::NAME);
fn run_keccak_bench<F, VecF>(args: &Args, config: Config)
where
F: Field,
VecF: VectorizedField + FieldSerde + Send + 'static,
VecF::BaseField: Send,
VecF::PackedBaseField: Field<BaseField = VecF::BaseField>,
{
println!("benchmarking keccak over {}", F::NAME);
println!(
"Default parallel repetition config {}",
config.get_num_repetitions()
Expand All @@ -53,72 +58,7 @@ fn run_keccak_bench_m31(args: &Args) {
let start_time = std::time::Instant::now();

// load circuit
let circuit_template =
Circuit::<VectorizedM31>::load_extracted_gates(FILENAME_MUL, FILENAME_ADD);
let circuits = (0..args.threads)
.map(|_| {
let mut c = circuit_template.clone();
c.set_random_bool_input_for_test();
c.evaluate();
c
})
.collect::<Vec<_>>();

println!("Circuit loaded!");
let _ = circuits
.into_iter()
.enumerate()
.map(|(i, c)| {
let partial_proof_cnt = partial_proof_cnts[i].clone();
let local_config = config.clone();
thread::spawn(move || {
loop {
// bench func
let mut prover = Prover::new(&local_config);
prover.prepare_mem(&c);
prover.prove(&c);
// update cnt
let mut cnt = partial_proof_cnt.lock().unwrap();
const CIRCUIT_COPY_SIZE: usize = 8;
let proof_cnt_this_round = CIRCUIT_COPY_SIZE
* VectorizedM31::PACK_SIZE
* VectorizedM31::VECTORIZE_SIZE;
*cnt += proof_cnt_this_round;
}
})
})
.collect::<Vec<_>>();

println!("We are now calculating average throughput, please wait for 1 minutes");
for i in 0..args.repeats {
thread::sleep(std::time::Duration::from_secs(60));
let stop_time = std::time::Instant::now();
let duration = stop_time.duration_since(start_time);
let mut total_proof_cnt = 0;
for cnt in &partial_proof_cnts {
total_proof_cnt += *cnt.lock().unwrap();
}
let throughput = total_proof_cnt as f64 / duration.as_secs_f64();
println!("{}-bench: throughput: {} keccaks/s", i, throughput.round());
}
}

fn run_keccak_bench_fr(args: &Args) {
let config = Config::bn254_config();
println!("benchmarking keccak over {}", Fr::NAME);
println!(
"Default parallel repetition config {}",
config.get_num_repetitions()
);

let partial_proof_cnts = (0..args.threads)
.map(|_| Arc::new(Mutex::new(0)))
.collect::<Vec<_>>();
let start_time = std::time::Instant::now();

// load circuit
let circuit_template =
Circuit::<VectorizedFr>::load_extracted_gates(FILENAME_MUL, FILENAME_ADD);
let circuit_template = Circuit::<VecF>::load_extracted_gates(FILENAME_MUL, FILENAME_ADD);
let circuits = (0..args.threads)
.map(|_| {
let mut c = circuit_template.clone();
Expand All @@ -145,7 +85,7 @@ fn run_keccak_bench_fr(args: &Args) {
let mut cnt = partial_proof_cnt.lock().unwrap();
const CIRCUIT_COPY_SIZE: usize = 8;
let proof_cnt_this_round =
CIRCUIT_COPY_SIZE * VectorizedFr::PACK_SIZE * VectorizedFr::VECTORIZE_SIZE;
CIRCUIT_COPY_SIZE * VecF::PACK_SIZE * VecF::VECTORIZE_SIZE;
*cnt += proof_cnt_this_round;
}
})
Expand Down

0 comments on commit af04269

Please sign in to comment.