Skip to content

Commit

Permalink
feat: batch sized recursion (#785)
Browse files Browse the repository at this point in the history
Co-authored-by: Ratan Kaliani <ratankaliani@berkeley.edu>
  • Loading branch information
jtguibas and ratankaliani authored May 23, 2024
1 parent d7e1851 commit 8af2e2f
Showing 1 changed file with 88 additions and 71 deletions.
159 changes: 88 additions & 71 deletions prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,6 @@ impl SP1Prover {
// Run the recursion and reduce programs.

// Run the recursion programs.
let mut records = Vec::new();

let (core_inputs, deferred_inputs) = self.get_first_layer_inputs(
vk,
Expand All @@ -398,62 +397,71 @@ impl SP1Prover {
batch_size,
);

for input in core_inputs {
let mut runtime = RecursionRuntime::<Val<InnerSC>, Challenge<InnerSC>, _>::new(
&self.recursion_program,
self.compress_machine.config().perm.clone(),
);

let mut witness_stream = Vec::new();
witness_stream.extend(input.write());

runtime.witness_stream = witness_stream.into();
runtime.run();
runtime.print_stats();
let mut first_layer_proofs = Vec::new();
let shard_batch_size = sp1_core::utils::env::shard_batch_size() as usize;
for inputs in core_inputs.chunks(shard_batch_size) {
let proofs = inputs
.into_par_iter()
.map(|input| {
let mut runtime = RecursionRuntime::<Val<InnerSC>, Challenge<InnerSC>, _>::new(
&self.recursion_program,
self.compress_machine.config().perm.clone(),
);

records.push((runtime.record, ReduceProgramType::Core));
let mut witness_stream = Vec::new();
witness_stream.extend(input.write());

runtime.witness_stream = witness_stream.into();
runtime.run();
runtime.print_stats();

let pk = &self.rec_pk;
let mut recursive_challenger = self.compress_machine.config().challenger();
(
self.compress_machine.prove::<LocalProver<_, _>>(
pk,
runtime.record,
&mut recursive_challenger,
),
ReduceProgramType::Core,
)
})
.collect::<Vec<_>>();
first_layer_proofs.extend(proofs);
}

// Run the deferred proofs programs.
for input in deferred_inputs {
let mut runtime = RecursionRuntime::<Val<InnerSC>, Challenge<InnerSC>, _>::new(
&self.deferred_program,
self.compress_machine.config().perm.clone(),
);

let mut witness_stream = Vec::new();
witness_stream.extend(input.write());

runtime.witness_stream = witness_stream.into();
runtime.run();
runtime.print_stats();
for inputs in deferred_inputs.chunks(shard_batch_size) {
let proofs = inputs
.into_par_iter()
.map(|input| {
let mut runtime = RecursionRuntime::<Val<InnerSC>, Challenge<InnerSC>, _>::new(
&self.deferred_program,
self.compress_machine.config().perm.clone(),
);

records.push((runtime.record, ReduceProgramType::Deferred));
let mut witness_stream = Vec::new();
witness_stream.extend(input.write());

runtime.witness_stream = witness_stream.into();
runtime.run();
runtime.print_stats();

let pk = &self.deferred_pk;
let mut recursive_challenger = self.compress_machine.config().challenger();
(
self.compress_machine.prove::<LocalProver<_, _>>(
pk,
runtime.record,
&mut recursive_challenger,
),
ReduceProgramType::Deferred,
)
})
.collect::<Vec<_>>();
first_layer_proofs.extend(proofs);
}

// Prove all recursion programs and recursion deferred programs and verify the proofs.

// Make the recursive proofs for core and deferred proofs.
let first_layer_proofs = records
.into_par_iter()
.map(|(record, kind)| {
let pk = match kind {
ReduceProgramType::Core => &self.rec_pk,
ReduceProgramType::Deferred => &self.deferred_pk,
ReduceProgramType::Reduce => unreachable!(),
};
let mut recursive_challenger = self.compress_machine.config().challenger();
(
self.compress_machine.prove::<LocalProver<_, _>>(
pk,
record,
&mut recursive_challenger,
),
kind,
)
})
.collect::<Vec<_>>();

// Chain all the individual shard proofs.
let mut reduce_proofs = first_layer_proofs
.into_iter()
Expand All @@ -465,28 +473,37 @@ impl SP1Prover {
loop {
tracing::debug!("Recursive proof layer size: {}", reduce_proofs.len());
is_complete = reduce_proofs.len() <= batch_size;
reduce_proofs = reduce_proofs
.par_chunks(batch_size)
.map(|batch| {
let (shard_proofs, kinds) =
batch.iter().cloned().unzip::<_, _, Vec<_>, Vec<_>>();

let input = SP1ReduceMemoryLayout {
compress_vk: &self.compress_vk,
recursive_machine: &self.compress_machine,
shard_proofs,
kinds,
is_complete,
};

let proof = self.compress_machine_proof(
input,
&self.compress_program,
&self.compress_pk,
);
(proof, ReduceProgramType::Reduce)

let compress_inputs = reduce_proofs.chunks(batch_size).collect::<Vec<_>>();
let batched_compress_inputs =
compress_inputs.chunks(shard_batch_size).collect::<Vec<_>>();
reduce_proofs = batched_compress_inputs
.into_iter()
.flat_map(|batches| {
batches
.par_iter()
.map(|batch| {
let (shard_proofs, kinds) =
batch.iter().cloned().unzip::<_, _, Vec<_>, Vec<_>>();

let input = SP1ReduceMemoryLayout {
compress_vk: &self.compress_vk,
recursive_machine: &self.compress_machine,
shard_proofs,
kinds,
is_complete,
};

let proof = self.compress_machine_proof(
input,
&self.compress_program,
&self.compress_pk,
);
(proof, ReduceProgramType::Reduce)
})
.collect::<Vec<_>>()
})
.collect();
.collect::<Vec<_>>();

if reduce_proofs.len() == 1 {
break;
Expand Down

0 comments on commit 8af2e2f

Please sign in to comment.