Skip to content

Commit

Permalink
refactor: faster poly slice reads
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto committed Jan 5, 2025
1 parent d39c8e4 commit 6d72498
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 7 deletions.
63 changes: 62 additions & 1 deletion halo2_proofs/src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ pub fn unpack(byte: u8, bits: &mut [bool]) {
}
}

#[cfg(not(feature = "parallel-poly-read"))]
/// Reads a vector of polynomials from buffer
pub(crate) fn read_polynomial_vec<R: io::Read, F: SerdePrimeField, B>(
reader: &mut R,
Expand All @@ -130,18 +131,77 @@ pub(crate) fn read_polynomial_vec<R: io::Read, F: SerdePrimeField, B>(
reader.read_exact(&mut len)?;
let len = u32::from_be_bytes(len);

let poly_lens: Result<Vec<_>, _> = (0..len)
.map(|_| {
let mut poly_len = [0u8; 4];
reader.read_exact(&mut poly_len)?;
Ok::<_, std::io::Error>(u32::from_be_bytes(poly_len))
})
.collect();

let _poly_lens = poly_lens?;

(0..len)
.map(|_| Polynomial::<F, B>::read(reader, format))
.collect::<io::Result<Vec<_>>>()
}

#[cfg(feature = "parallel-poly-read")]
/// Reads a vector of polynomials from buffer, decoding the polynomials through
/// `maybe_rayon`'s parallel iterator.
///
/// Layout consumed from `reader` (every integer is a big-endian `u32`):
/// 1. the number of polynomials,
/// 2. one coefficient count per polynomial,
/// 3. for each polynomial, its own serialization: a `u32` length prefix followed by its
///    field elements.
///
/// The reader itself is drained sequentially into one owned buffer per polynomial; only the
/// decoding of those buffers via [`Polynomial::read_serial`] is handed to the parallel iterator.
///
/// # Errors
///
/// Returns any I/O error raised while reading from `reader` or while decoding a polynomial,
/// and an [`io::ErrorKind::InvalidData`] error if a declared polynomial length does not fit
/// in an addressable buffer size on the current target.
pub(crate) fn read_polynomial_vec<R: io::Read, F: SerdePrimeField, B: std::marker::Send>(
    reader: &mut R,
    format: SerdeFormat,
) -> io::Result<Vec<Polynomial<F, B>>> {
    use maybe_rayon::iter::IntoParallelIterator;
    use maybe_rayon::iter::ParallelIterator;

    /// Size of the `u32` length prefix that precedes each polynomial's elements.
    const LEN_PREFIX_BYTES: usize = std::mem::size_of::<u32>();

    let mut len = [0u8; 4];
    reader.read_exact(&mut len)?;
    let len = u32::from_be_bytes(len);

    // Read all polynomial lengths first; they are collected as they are read, the same way
    // the serial variant of this function consumes the length table.
    let poly_lens = (0..len)
        .map(|_| -> io::Result<u32> {
            let mut poly_len = [0u8; 4];
            reader.read_exact(&mut poly_len)?;
            Ok(u32::from_be_bytes(poly_len))
        })
        .collect::<io::Result<Vec<_>>>()?;

    // Byte width of one field element; identical for every polynomial, so computed once.
    // NOTE(review): this assumes every `SerdeFormat` serializes an element in exactly
    // `to_repr().len()` bytes -- confirm against `SerdePrimeField`.
    let repr_len = F::default().to_repr().as_ref().len();

    // Pre-read all polynomial data into separate buffers.
    let poly_buffers = poly_lens
        .into_iter()
        .map(|poly_len| -> io::Result<Vec<u8>> {
            // Sum of all the field elements AND the prepended u32 length bytes. Checked so
            // that a corrupt length yields an error instead of overflowing `usize`.
            let buffer_size = usize::try_from(poly_len)
                .ok()
                .and_then(|coeffs| coeffs.checked_mul(repr_len))
                .and_then(|bytes| bytes.checked_add(LEN_PREFIX_BYTES))
                .ok_or_else(|| {
                    io::Error::new(
                        io::ErrorKind::InvalidData,
                        "polynomial length overflows the addressable buffer size",
                    )
                })?;
            let mut buffer = vec![0u8; buffer_size];
            reader.read_exact(&mut buffer)?;
            Ok(buffer)
        })
        .collect::<io::Result<Vec<_>>>()?;

    // Process buffers in parallel; each buffer holds exactly one serialized polynomial.
    poly_buffers
        .into_par_iter()
        .map(|buffer| {
            let mut cursor = std::io::Cursor::new(buffer);
            Polynomial::<F, B>::read_serial(&mut cursor, format)
        })
        .collect::<io::Result<Vec<_>>>()
}

/// Writes a slice of polynomials to buffer
pub(crate) fn write_polynomial_slice<W: io::Write, F: SerdePrimeField, B>(
slice: &[Polynomial<F, B>],
writer: &mut W,
format: SerdeFormat,
) -> io::Result<()> {
writer.write_all(&(slice.len() as u32).to_be_bytes())?;
// then write each polynomial's len
for poly in slice.iter() {
writer.write_all(&(poly.num_coeffs() as u32).to_be_bytes())?;
}

for poly in slice.iter() {
poly.write(writer, format)?;
}
Expand All @@ -151,5 +211,6 @@ pub(crate) fn write_polynomial_slice<W: io::Write, F: SerdePrimeField, B>(
/// Gets the total number of bytes of a slice of polynomials, assuming all polynomials are the same length
///
/// The total is the sum of:
/// - one `u32` holding the number of polynomials,
/// - one `u32` per polynomial for the table of polynomial lengths,
/// - for every polynomial, its own `u32` length prefix plus `field_len` bytes per element.
///
/// The element count is taken from the first polynomial only; an empty slice contributes
/// just the leading count header.
pub(crate) fn polynomial_slice_byte_length<F: PrimeField, B>(slice: &[Polynomial<F, B>]) -> usize {
    /// Width in bytes of every `u32` length field in the serialized layout.
    const U32_BYTES: usize = std::mem::size_of::<u32>();

    let field_len = F::default().to_repr().as_ref().len();
    let poly_len = slice.first().map_or(0, |poly| poly.len());

    // count header + length table + each polynomial (own prefix + elements)
    U32_BYTES + U32_BYTES * slice.len() + slice.len() * (U32_BYTES + field_len * poly_len)
}
5 changes: 5 additions & 0 deletions halo2_proofs/src/poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,11 @@ impl<F: SerdePrimeField, B> Polynomial<F, B> {
/// Reads polynomial from buffer using `SerdePrimeField::read`.
///
/// This is a thin wrapper that forwards `reader` and `format` unchanged to
/// [`Self::read_serial`] and returns its result. It is only compiled when the
/// `parallel-poly-read` feature is disabled.
#[cfg(not(feature = "parallel-poly-read"))]
pub fn read<R: io::Read>(reader: &mut R, format: SerdeFormat) -> io::Result<Self> {
Self::read_serial(reader, format)
}

/// Reads polynomial from buffer using `SerdePrimeField::read`.
pub fn read_serial<R: io::Read>(reader: &mut R, format: SerdeFormat) -> io::Result<Self> {
let mut poly_len = [0u8; 4];
reader.read_exact(&mut poly_len)?;
let poly_len = u32::from_be_bytes(poly_len);
Expand Down
17 changes: 11 additions & 6 deletions halo2_proofs/src/poly/domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use crate::{
use super::{Coeff, ExtendedLagrangeCoeff, LagrangeCoeff, Polynomial, Rotation};

use group::ff::{BatchInvert, Field, WithSmallOrderMulGroup};
use maybe_rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};

Check failure on line 13 in halo2_proofs/src/poly/domain.rs

View workflow job for this annotation

GitHub Actions / Build target wasm32-wasi

unresolved import `maybe_rayon::iter::IndexedParallelIterator`

Check failure on line 13 in halo2_proofs/src/poly/domain.rs

View workflow job for this annotation

GitHub Actions / Build target wasm32-unknown-unknown

unresolved import `maybe_rayon::iter::IndexedParallelIterator`

Check warning on line 13 in halo2_proofs/src/poly/domain.rs

View workflow job for this annotation

GitHub Actions / Build target wasm32-unknown-unknown

unused import: `ParallelIterator`

use std::{collections::HashMap, marker::PhantomData};

Expand Down Expand Up @@ -142,12 +143,16 @@ impl<F: WithSmallOrderMulGroup<3>> EvaluationDomain<F> {

let omega_inv = omegas_inv[0];
let extended_omega_inv = *omegas_inv.last().unwrap();
let mut fft_data = HashMap::new();
for (i, (omega, omega_inv)) in omegas.into_iter().zip(omegas_inv).enumerate() {
let intermediate_k = k as usize + i;
let len = 1usize << intermediate_k;
fft_data.insert(len, FFTData::<F>::new(len, omega, omega_inv));
}
let fft_data = omegas
.into_par_iter()
.zip(omegas_inv)
.enumerate()
.map(|(i, (omega, omega_inv))| {
let intermediate_k = k as usize + i;
let len = 1usize << intermediate_k;
(len, FFTData::<F>::new(len, omega, omega_inv))
})
.collect::<HashMap<usize, FFTData<F>>>();

EvaluationDomain {
n,
Expand Down

0 comments on commit 6d72498

Please sign in to comment.