diff --git a/halo2_proofs/src/helpers.rs b/halo2_proofs/src/helpers.rs
index 051d18061..1dc01bbce 100644
--- a/halo2_proofs/src/helpers.rs
+++ b/halo2_proofs/src/helpers.rs
@@ -121,6 +121,7 @@ pub fn unpack(byte: u8, bits: &mut [bool]) {
     }
 }
 
+#[cfg(not(feature = "parallel-poly-read"))]
 /// Reads a vector of polynomials from buffer
 pub(crate) fn read_polynomial_vec<R: io::Read, F: SerdePrimeField, B>(
     reader: &mut R,
@@ -130,11 +131,77 @@ pub(crate) fn read_polynomial_vec<R: io::Read, F: SerdePrimeField, B>(
     reader.read_exact(&mut len)?;
     let len = u32::from_be_bytes(len);
 
+    // The writer emits a table of per-polynomial lengths right after the count
+    // (it is consumed by the parallel reader). Every polynomial still carries its
+    // own length prefix, so the serial path only has to skip over the table.
+    let mut poly_len = [0u8; 4];
+    for _ in 0..len {
+        reader.read_exact(&mut poly_len)?;
+    }
+
     (0..len)
         .map(|_| Polynomial::<F, B>::read(reader, format))
         .collect::<io::Result<Vec<_>>>()
 }
 
+#[cfg(feature = "parallel-poly-read")]
+/// Reads a vector of polynomials from buffer, decoding the polynomials in parallel.
+///
+/// Expects a big-endian `u32` count, then one big-endian `u32` length per
+/// polynomial, then the polynomials themselves (each with its own `u32` prefix).
+///
+/// # Errors
+/// Returns an I/O error if the reader runs out of data, if a declared length
+/// would overflow `usize`, or if decoding any polynomial fails.
+pub(crate) fn read_polynomial_vec<R: io::Read, F: SerdePrimeField, B>(
+    reader: &mut R,
+    format: SerdeFormat,
+) -> io::Result<Vec<Polynomial<F, B>>> {
+    use maybe_rayon::iter::IntoParallelIterator;
+    use maybe_rayon::iter::ParallelIterator;
+
+    let mut len = [0u8; 4];
+    reader.read_exact(&mut len)?;
+    let len = u32::from_be_bytes(len);
+
+    // Read all polynomial lengths first. Collecting through `io::Result` lets the
+    // vector grow as entries are actually read rather than pre-allocating from `len`.
+    let poly_lens = (0..len)
+        .map(|_| {
+            let mut poly_len = [0u8; 4];
+            reader.read_exact(&mut poly_len)?;
+            Ok::<_, io::Error>(u32::from_be_bytes(poly_len))
+        })
+        .collect::<io::Result<Vec<u32>>>()?;
+
+    // Pre-read all polynomial data into separate buffers.
+    // NOTE(review): assumes every coefficient occupies `repr_len` bytes for every
+    // `SerdeFormat` variant -- confirm against `SerdePrimeField::write`.
+    let repr_len = F::default().to_repr().as_ref().len();
+    let mut poly_buffers = Vec::with_capacity(poly_lens.len());
+    for &poly_len in &poly_lens {
+        // sum of all the Field elements AND also the prepended u32 bytes
+        let buffer_size = repr_len
+            .checked_mul(poly_len as usize)
+            .and_then(|bytes| bytes.checked_add(std::mem::size_of::<u32>()))
+            .ok_or_else(|| {
+                io::Error::new(io::ErrorKind::InvalidData, "polynomial size overflows usize")
+            })?;
+        let mut buffer = vec![0u8; buffer_size];
+        reader.read_exact(&mut buffer)?;
+        poly_buffers.push(buffer);
+    }
+
+    // Process buffers in parallel
+    poly_buffers
+        .into_par_iter()
+        .map(|buffer| {
+            let mut cursor = std::io::Cursor::new(buffer);
+            Polynomial::<F, B>::read_serial(&mut cursor, format)
+        })
+        .collect::<io::Result<Vec<_>>>()
+}
+
 /// Writes a slice of polynomials to buffer
 pub(crate) fn write_polynomial_slice<W: io::Write, F: SerdePrimeField, B>(
     slice: &[Polynomial<F, B>],
@@ -142,6 +209,11 @@ pub(crate) fn write_polynomial_slice<W: io::Write, F: SerdePrimeField, B>(
     format: SerdeFormat,
 ) -> io::Result<()> {
     writer.write_all(&(slice.len() as u32).to_be_bytes())?;
+    // then write each polynomial's len
+    for poly in slice.iter() {
+        writer.write_all(&(poly.num_coeffs() as u32).to_be_bytes())?;
+    }
+
     for poly in slice.iter() {
         poly.write(writer, format)?;
     }
@@ -151,5 +223,6 @@ pub(crate) fn write_polynomial_slice<W: io::Write, F: SerdePrimeField, B>(
 /// Gets the total number of bytes of a slice of polynomials, assuming all polynomials are the same length
 pub(crate) fn polynomial_slice_byte_length<F: PrimeField, B>(slice: &[Polynomial<F, B>]) -> usize {
     let field_len = F::default().to_repr().as_ref().len();
-    4 + slice.len() * (4 + field_len * slice.first().map(|poly| poly.len()).unwrap_or(0))
+    4 + 4 * slice.len()
+        + slice.len() * (4 + field_len * slice.first().map(|poly| poly.len()).unwrap_or(0))
 }
diff --git a/halo2_proofs/src/poly.rs b/halo2_proofs/src/poly.rs
index 2076eaff2..e0cff4520 100644
--- a/halo2_proofs/src/poly.rs
+++ b/halo2_proofs/src/poly.rs
@@ -201,6 +201,11 @@ impl<F: SerdePrimeField, B> Polynomial<F, B> {
     /// Reads polynomial from buffer using `SerdePrimeField::read`.
     #[cfg(not(feature = "parallel-poly-read"))]
     pub fn read<R: io::Read>(reader: &mut R, format: SerdeFormat) -> io::Result<Self> {
+        Self::read_serial(reader, format)
+    }
+
+    /// Reads polynomial from buffer using `SerdePrimeField::read`.
+    pub fn read_serial<R: io::Read>(reader: &mut R, format: SerdeFormat) -> io::Result<Self> {
         let mut poly_len = [0u8; 4];
         reader.read_exact(&mut poly_len)?;
         let poly_len = u32::from_be_bytes(poly_len);
diff --git a/halo2_proofs/src/poly/domain.rs b/halo2_proofs/src/poly/domain.rs
index 71f214423..a9f5b8b65 100644
--- a/halo2_proofs/src/poly/domain.rs
+++ b/halo2_proofs/src/poly/domain.rs
@@ -10,6 +10,7 @@ use crate::{
 
 use super::{Coeff, ExtendedLagrangeCoeff, LagrangeCoeff, Polynomial, Rotation};
 use group::ff::{BatchInvert, Field, WithSmallOrderMulGroup};
+use maybe_rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
 
 use std::{collections::HashMap, marker::PhantomData};
 
@@ -142,12 +143,16 @@ impl<F: WithSmallOrderMulGroup<3>> EvaluationDomain<F> {
         let omega_inv = omegas_inv[0];
         let extended_omega_inv = *omegas_inv.last().unwrap();
 
-        let mut fft_data = HashMap::new();
-        for (i, (omega, omega_inv)) in omegas.into_iter().zip(omegas_inv).enumerate() {
-            let intermediate_k = k as usize + i;
-            let len = 1usize << intermediate_k;
-            fft_data.insert(len, FFTData::<F>::new(len, omega, omega_inv));
-        }
+        let fft_data = omegas
+            .into_par_iter()
+            .zip(omegas_inv)
+            .enumerate()
+            .map(|(i, (omega, omega_inv))| {
+                let intermediate_k = k as usize + i;
+                let len = 1usize << intermediate_k;
+                (len, FFTData::<F>::new(len, omega, omega_inv))
+            })
+            .collect::<HashMap<usize, FFTData<F>>>();
 
         EvaluationDomain {
             n,