Skip to content

Commit

Permalink
refactor: batched polynomial reads
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto committed Jan 2, 2025
1 parent 0654e92 commit 948e8ae
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 23 deletions.
4 changes: 4 additions & 0 deletions halo2_proofs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ rustdoc-args = ["--cfg", "docsrs", "--html-in-header", "katex-header.html"]
name = "arithmetic"
harness = false

[[bench]]
name = "polyread"
harness = false

[[bench]]
name = "commit_zk"
harness = false
Expand Down
91 changes: 91 additions & 0 deletions halo2_proofs/benches/polyread.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
use std::io::Cursor;

use criterion::{black_box, criterion_group, criterion_main, Criterion};
use halo2_proofs::{poly::Polynomial, SerdeFormat, SerdePrimeField};
use halo2curves::bn256::Fr;
use maybe_rayon::{iter::ParallelIterator, slice::ParallelSlice};
use rand_core::OsRng;

/// Benchmarks [`read`] with `SerdeFormat::RawBytesUnchecked` over a range of batch sizes.
///
/// One serialized random polynomial of 100 million coefficients is generated up front and
/// shared by every batch size, so all measurements decode exactly the same bytes.
pub fn parallel_poly_read_benchmark_unchecked(c: &mut Criterion) {
    let mut group = c.benchmark_group("parallel_poly_read_unchecked");

    // Build the input once: regenerating it per batch size is expensive and would make the
    // individual measurements operate on different data.
    let data = setup_random_poly(100_000_000);

    for &batch_size in [64usize, 256, 1024, 4096, 100000, 1000000].iter() {
        group.bench_function(format!("batch_{}", batch_size), |b| {
            b.iter(|| {
                // Read from a borrowed slice so that no copy of the input buffer is timed.
                let mut reader = Cursor::new(data.as_slice());
                black_box(
                    read::<_, Fr>(&mut reader, SerdeFormat::RawBytesUnchecked, batch_size)
                        .unwrap(),
                )
            });
        });
    }

    group.finish();
}

/// Benchmarks [`read`] with `SerdeFormat::RawBytes` over a range of batch sizes.
///
/// One serialized random polynomial of 100 million coefficients is generated up front and
/// shared by every batch size, so all measurements decode exactly the same bytes.
pub fn parallel_poly_read_benchmark_checked(c: &mut Criterion) {
    let mut group = c.benchmark_group("parallel_poly_read_checked");

    // Build the input once: regenerating it per batch size is expensive and would make the
    // individual measurements operate on different data.
    let data = setup_random_poly(100_000_000);

    for &batch_size in [64usize, 256, 1024, 4096, 100000, 1000000].iter() {
        group.bench_function(format!("batch_{}", batch_size), |b| {
            b.iter(|| {
                // Read from a borrowed slice so that no copy of the input buffer is timed.
                let mut reader = Cursor::new(data.as_slice());
                black_box(read::<_, Fr>(&mut reader, SerdeFormat::RawBytes, batch_size).unwrap())
            });
        });
    }

    group.finish();
}

// Register both polynomial-read benchmarks under the `benches` group.
criterion_group!(
benches,
parallel_poly_read_benchmark_checked,
parallel_poly_read_benchmark_unchecked
);
// Generate the benchmark binary's entry point for the `benches` group.
criterion_main!(benches);

/// Returns the `SerdeFormat::RawBytes` serialization of a random polynomial with `n`
/// coefficients over `Fr`.
///
/// # Panics
///
/// Panics if writing the polynomial into the in-memory buffer fails.
fn setup_random_poly(n: usize) -> Vec<u8> {
    let poly = Polynomial::<Fr, usize>::random(n, &mut OsRng);

    let mut serialized = Vec::new();
    poly.write(&mut serialized, SerdeFormat::RawBytes).unwrap();
    serialized
}

/// Reads a length-prefixed sequence of field elements from `reader`, decoding it in parallel
/// batches of `batch_size` elements.
///
/// The input is expected to start with a big-endian `u32` element count, followed by that many
/// fixed-width element representations, each decoded with `F::read` according to `format`.
///
/// # Errors
///
/// Returns an error if
/// - `batch_size` is zero,
/// - the encoded byte length (`count * repr_len`) does not fit in a `usize`,
/// - `reader` ends before the declared number of bytes has been read, or
/// - any element fails to decode with `F::read`.
pub fn read<R: std::io::Read, F: SerdePrimeField>(
    reader: &mut R,
    format: SerdeFormat,
    batch_size: usize,
) -> std::io::Result<Vec<F>> {
    // `par_chunks` panics on a zero chunk size; report that as an ordinary error instead.
    if batch_size == 0 {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "batch_size must be non-zero",
        ));
    }

    let mut len_bytes = [0u8; 4];
    reader.read_exact(&mut len_bytes)?;
    let poly_len = u32::from_be_bytes(len_bytes) as usize;

    let repr_len = F::default().to_repr().as_ref().len();
    let byte_len = poly_len.checked_mul(repr_len).ok_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            "polynomial byte length overflows usize",
        )
    })?;

    // Pull all element bytes into memory first so they can be decoded in parallel.
    let mut buffer = vec![0u8; byte_len];
    reader.read_exact(&mut buffer)?;

    // Each parallel task decodes one batch of `batch_size` elements sequentially.
    let batches = buffer
        .par_chunks(repr_len.saturating_mul(batch_size))
        .map(|batch| {
            batch
                .chunks(repr_len)
                .map(|chunk| F::read(&mut std::io::Cursor::new(chunk), format))
                .collect::<std::io::Result<Vec<F>>>()
        })
        .collect::<std::io::Result<Vec<Vec<F>>>>()?;

    // The final length is known up front, so allocate once rather than growing while
    // flattening the per-batch vectors.
    let mut values = Vec::with_capacity(poly_len);
    for batch in batches {
        values.extend(batch);
    }
    Ok(values)
}
1 change: 1 addition & 0 deletions halo2_proofs/src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ pub trait SerdeCurveAffine: CurveAffine + SerdeObject {
}
impl<C: CurveAffine + SerdeObject> SerdeCurveAffine for C {}

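/// A prime field whose elements can be read from and written to byte buffers according to a `SerdeFormat`.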
pub trait SerdePrimeField: PrimeField + SerdeObject {
/// Reads a field element as bytes from the buffer according to the `format`:
/// - `Processed`: Reads a field element in standard form, with endianness specified by the
Expand Down
1 change: 1 addition & 0 deletions halo2_proofs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ pub mod transcript;
pub mod dev;
mod helpers;
pub use helpers::SerdeFormat;
pub use helpers::SerdePrimeField;

#[cfg(feature = "icicle_gpu")]
#[allow(unsafe_code)]
Expand Down
58 changes: 35 additions & 23 deletions halo2_proofs/src/poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use crate::plonk::Assigned;
use crate::SerdeFormat;
use group::ff::{BatchInvert, Field};
#[cfg(feature = "parallel-poly-read")]
use maybe_rayon::{iter::ParallelIterator, prelude::ParallelSliceMut};
use maybe_rayon::{iter::ParallelIterator, prelude::ParallelSlice};
use rand_core::RngCore;

use std::fmt::Debug;
use std::io;
Expand Down Expand Up @@ -162,35 +163,50 @@ impl<F, B> Polynomial<F, B> {
}

impl<F: SerdePrimeField, B> Polynomial<F, B> {
/// Creates a polynomial whose `num_coeffs` values are each drawn with `F::random` from `rng`.
pub fn random<R: RngCore>(num_coeffs: usize, rng: &mut R) -> Self {
    let values = std::iter::repeat_with(|| F::random(&mut *rng))
        .take(num_coeffs)
        .collect();

    Polynomial {
        values,
        _marker: PhantomData,
    }
}

/// Reads polynomial from buffer using `SerdePrimeField::read`.
#[cfg(feature = "parallel-poly-read")]
pub(crate) fn read<R: io::Read>(reader: &mut R, format: SerdeFormat) -> io::Result<Self> {
let mut poly_len = [0u8; 4];
reader.read_exact(&mut poly_len)?;
let poly_len = u32::from_be_bytes(poly_len) as usize;
pub fn read<R: io::Read>(reader: &mut R, format: SerdeFormat) -> io::Result<Self> {
const BATCH_SIZE: usize = 1024; // Adjusted based on testing

let repr_len = F::default().to_repr().as_ref().len();

let mut new_vals = vec![0u8; poly_len * repr_len];
reader.read_exact(&mut new_vals)?;
let poly_len = u32::from_be_bytes({
let mut buf = [0u8; 4];
reader.read_exact(&mut buf)?;
buf
}) as usize;

// parallel read
new_vals
.par_chunks_mut(repr_len)
.map(|chunk| {
let mut chunk = io::Cursor::new(chunk);
F::read(&mut chunk, format)
let repr_len = F::default().to_repr().as_ref().len();
let buffer = {
let mut buf = vec![0u8; poly_len * repr_len];
reader.read_exact(&mut buf)?;
buf
};

buffer
.par_chunks(repr_len * BATCH_SIZE)
.map(|batch| {
batch
.chunks(repr_len)
.map(|chunk| F::read(&mut io::Cursor::new(chunk), format))
.collect::<io::Result<Vec<_>>>()
})
.collect::<io::Result<Vec<_>>>()
.collect::<Result<Vec<_>, _>>()
.map(|values| Self {
values,
values: values.into_iter().flatten().collect(),
_marker: PhantomData,
})
}

/// Reads polynomial from buffer using `SerdePrimeField::read`.
#[cfg(not(feature = "parallel-poly-read"))]
pub(crate) fn read<R: io::Read>(reader: &mut R, format: SerdeFormat) -> io::Result<Self> {
pub fn read<R: io::Read>(reader: &mut R, format: SerdeFormat) -> io::Result<Self> {
let mut poly_len = [0u8; 4];
reader.read_exact(&mut poly_len)?;
let poly_len = u32::from_be_bytes(poly_len);
Expand All @@ -205,11 +221,7 @@ impl<F: SerdePrimeField, B> Polynomial<F, B> {
}

/// Writes polynomial to buffer using `SerdePrimeField::write`.
pub(crate) fn write<W: io::Write>(
&self,
writer: &mut W,
format: SerdeFormat,
) -> io::Result<()> {
pub fn write<W: io::Write>(&self, writer: &mut W, format: SerdeFormat) -> io::Result<()> {
writer.write_all(&(self.values.len() as u32).to_be_bytes())?;
for value in self.values.iter() {
value.write(writer, format)?;
Expand Down

0 comments on commit 948e8ae

Please sign in to comment.