Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: parallel recursion tracegen #1095

Merged
merged 10 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions core/src/runtime/record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,8 +395,6 @@ impl ExecutionRecord {
pub fn split(&mut self, last: bool, opts: SplitOpts) -> Vec<ExecutionRecord> {
let mut shards = Vec::new();

println!("keccak split {}", opts.keccak_split_threshold);

macro_rules! split_events {
($self:ident, $events:ident, $shards:ident, $threshold:expr, $exact:expr) => {
let events = std::mem::take(&mut $self.$events);
Expand Down
48 changes: 38 additions & 10 deletions core/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub use programs::*;

use crate::{memory::MemoryCols, operations::field::params::Limbs};
use generic_array::ArrayLength;
use p3_maybe_rayon::prelude::{ParallelBridge, ParallelIterator};

pub const fn indices_arr<const N: usize>() -> [usize; N] {
let mut indices_arr = [0; N];
Expand Down Expand Up @@ -88,30 +89,36 @@ pub fn pad_rows_fixed<R: Clone>(
) {
let nb_rows = rows.len();
let dummy_row = row_fn();
match size_log2 {
Some(size_log2) => {
let padded_nb_rows = 1 << size_log2;
if nb_rows * 2 < padded_nb_rows {
rows.resize(next_power_of_two(nb_rows, size_log2), dummy_row);
}

/// Returns the next power of two that is >= `n` and >= 16. If `fixed_power` is set, it will return
/// `2^fixed_power` after checking that `n <= 2^fixed_power`.
pub fn next_power_of_two(n: usize, fixed_power: Option<usize>) -> usize {
match fixed_power {
Some(power) => {
let padded_nb_rows = 1 << power;
if n * 2 < padded_nb_rows {
tracing::warn!(
"fixed log2 rows can be potentially reduced: got {}, expected {}",
nb_rows,
n,
padded_nb_rows
);
}
if nb_rows > padded_nb_rows {
if n > padded_nb_rows {
panic!(
"fixed log2 rows is too small: got {}, expected {}",
nb_rows, padded_nb_rows
n, padded_nb_rows
);
}
rows.resize(padded_nb_rows, dummy_row);
padded_nb_rows
}
None => {
let mut padded_nb_rows = nb_rows.next_power_of_two();
let mut padded_nb_rows = n.next_power_of_two();
if padded_nb_rows < 16 {
padded_nb_rows = 16;
}
rows.resize(padded_nb_rows, dummy_row);
padded_nb_rows
}
}
}
Expand Down Expand Up @@ -186,3 +193,24 @@ pub fn log2_strict_usize(n: usize) -> usize {
assert_eq!(n.wrapping_shr(res), 1, "Not a power of two: {n}");
res as usize
}

pub fn par_for_each_row<P, F>(vec: &mut [F], num_cols: usize, processor: P)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OO nice.

where
F: Send,
P: Fn(usize, &mut [F]) + Send + Sync,
{
// Split the vector into `num_cpus` chunks, but at least `num_cpus` rows per chunk.
let len = vec.len();
let cpus = num_cpus::get();
let ceil_div = (len + cpus - 1) / cpus;
let chunk_size = std::cmp::max(ceil_div, cpus);

vec.chunks_mut(chunk_size * num_cols)
.enumerate()
.par_bridge()
.for_each(|(i, chunk)| {
chunk.chunks_mut(num_cols).enumerate().for_each(|(j, row)| {
processor(i * chunk_size + j, row);
});
});
}
8 changes: 2 additions & 6 deletions prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
pub use sp1_recursion_program::machine::{
SP1DeferredMemoryLayout, SP1RecursionMemoryLayout, SP1ReduceMemoryLayout, SP1RootMemoryLayout,
};
use tracing::instrument;
use tracing::{info_span, instrument};

Check warning on line 62 in prover/src/lib.rs

View workflow job for this annotation

GitHub Actions / Test (x86-64)

unused import: `info_span`

Check warning on line 62 in prover/src/lib.rs

View workflow job for this annotation

GitHub Actions / Test (x86-64)

unused import: `info_span`

Check warning on line 62 in prover/src/lib.rs

View workflow job for this annotation

GitHub Actions / Test (ARM)

unused import: `info_span`

Check warning on line 62 in prover/src/lib.rs

View workflow job for this annotation

GitHub Actions / Test (ARM)

unused import: `info_span`

Check failure on line 62 in prover/src/lib.rs

View workflow job for this annotation

GitHub Actions / Formatting & Clippy

unused import: `info_span`
pub use types::*;
use utils::words_to_bytes;

Expand Down Expand Up @@ -295,10 +295,6 @@
for batch in shard_proofs.chunks(batch_size) {
let proofs = batch.to_vec();

let public_values: &PublicValues<Word<BabyBear>, BabyBear> =
proofs.last().unwrap().public_values.as_slice().borrow();
println!("core execution shard: {}", public_values.execution_shard);

core_inputs.push(SP1RecursionMemoryLayout {
vk,
machine: self.core_prover.machine(),
Expand Down Expand Up @@ -517,6 +513,7 @@
})
}

/// Generate a proof with the compress machine.
pub fn compress_machine_proof(
&self,
input: impl Hintable<InnerConfig>,
Expand All @@ -533,7 +530,6 @@
witness_stream.extend(input.write());

runtime.witness_stream = witness_stream.into();

runtime
.run()
.map_err(|e| SP1RecursionProverError::RuntimeError(e.to_string()))?;
Expand Down
1 change: 1 addition & 0 deletions recursion/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ serde_with = "3.8.3"
backtrace = { version = "0.3.71", features = ["serde"] }
arrayref = "0.3.7"
static_assertions = "1.1.0"
num_cpus = "1.16.0"

[dev-dependencies]
rand = "0.8.5"
10 changes: 1 addition & 9 deletions recursion/core/src/cpu/columns/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use std::mem::{size_of, transmute};
use std::mem::size_of;

use crate::memory::{MemoryReadCols, MemoryReadWriteCols};
use p3_air::BaseAir;
use sp1_core::utils::indices_arr;
use sp1_derive::AlignedBorrow;

mod branch;
Expand All @@ -23,13 +22,6 @@ use super::CpuChip;

pub const NUM_CPU_COLS: usize = size_of::<CpuCols<u8>>();

/// Builds a `CpuCols<usize>` by reinterpreting the array returned by
/// `indices_arr::<NUM_CPU_COLS>()` as a column struct.
///
/// NOTE(review): presumably the array is `[0, 1, .., NUM_CPU_COLS - 1]`, so that each field ends
/// up holding its own column index — confirm against `sp1_core::utils::indices_arr`.
const fn make_col_map() -> CpuCols<usize> {
let indices_arr = indices_arr::<NUM_CPU_COLS>();
// SAFETY (to be confirmed): `transmute` only compiles when both types have the same size, so
// `CpuCols<usize>` must be exactly `NUM_CPU_COLS` `usize`s wide. Mapping array slot `i` to the
// `i`-th column additionally assumes the struct has a stable, declaration-order layout
// (e.g. `#[repr(C)]`) — TODO verify on the `CpuCols` definition.
unsafe { transmute::<[usize; NUM_CPU_COLS], CpuCols<usize>>(indices_arr) }
}

/// Compile-time constant holding the result of `make_col_map()`.
pub(crate) const CPU_COL_MAP: CpuCols<usize> = make_col_map();

impl<F: Send + Sync, const L: usize> BaseAir<F> for CpuChip<F, L> {
fn width(&self) -> usize {
NUM_CPU_COLS
Expand Down
Loading
Loading