feat: multi-threaded tracing #1124

Merged: 24 commits, Jul 18, 2024
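This change threads tracing spans through SP1's multi-threaded proving pipeline: per-chip dependency generation and nonce registration in `StarkMachine` are wrapped in debug spans, the commit and prove loops in `prove.rs` are split into named "phase 1" / "phase 2" record-generator, committer, and prover spans, and the current span is cloned and re-entered inside spawned threads and rayon workers so parallel work is attributed to the right parent span.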
core/src/stark/machine.rs (12 changes: 7 additions & 5 deletions)
@@ -244,12 +244,14 @@ impl<SC: StarkGenericConfig, A: MachineAir<Val<SC>>> StarkMachine<SC, A> {
let chips = self.chips();
records.iter_mut().for_each(|record| {
chips.iter().for_each(|chip| {
let mut output = A::Record::default();
chip.generate_dependencies(record, &mut output);
record.append(&mut output);
tracing::debug_span!("chip dependencies", chip = chip.name()).in_scope(|| {
let mut output = A::Record::default();
chip.generate_dependencies(record, &mut output);
record.append(&mut output);
});
});
record.register_nonces(opts);
});
tracing::debug_span!("register nonces").in_scope(|| record.register_nonces(opts));
})
}

pub const fn config(&self) -> &SC {
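The change in this file is plain `tracing` span scoping: each per-chip step now runs inside `debug_span!(...).in_scope(...)`, so its wall-clock time is attributed to a named span, and nonce registration gets its own span. A minimal, self-contained sketch of the idiom (the chip names and the work inside are placeholders, not SP1's types):

use tracing::debug_span;

fn generate_all_dependencies(chips: &[&str]) {
    for chip in chips {
        // Everything executed inside `in_scope` is recorded under this span,
        // which yields per-chip timings in the trace output.
        debug_span!("chip dependencies", chip = *chip).in_scope(|| {
            // ... generate this chip's dependencies ...
        });
    }
}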
core/src/utils/prove.rs (253 changes: 139 additions & 114 deletions)
@@ -147,7 +147,7 @@ where
let (pk, vk) = prover.setup(runtime.program.as_ref());

// Execute the program, saving checkpoints at the start of every `shard_batch_size` cycle range.
let make_checkpoint_span = tracing::debug_span!("Execute and save checkpoints").entered();
let create_checkpoints_span = tracing::debug_span!("create checkpoints").entered();
let mut checkpoints = Vec::new();
let (public_values_stream, public_values) = loop {
// Execute the runtime until we reach a checkpoint.
@@ -174,94 +174,109 @@
);
}
};
make_checkpoint_span.exit();
create_checkpoints_span.exit();

// Commit to the shards.
#[cfg(debug_assertions)]
let mut debug_records: Vec<ExecutionRecord> = Vec::new();

let commit_span = tracing::debug_span!("Commit to shards").entered();
let mut deferred = ExecutionRecord::new(program.clone().into());
let mut state = public_values.reset();
let nb_checkpoints = checkpoints.len();
let mut challenger = prover.config().challenger();
vk.observe_into(&mut challenger);

let scope_span = tracing::Span::current().clone();
std::thread::scope(move |s| {
let _span = scope_span.enter();

// Spawn a thread for committing to the shards.
let span = tracing::Span::current().clone();
let (records_tx, records_rx) =
sync_channel::<Vec<ExecutionRecord>>(opts.commit_stream_capacity);
let challenger_handle = s.spawn(move || {
for records in records_rx.iter() {
let commitments = records
.par_iter()
.map(|record| prover.commit(record))
.collect::<Vec<_>>();
for (commit, record) in commitments.into_iter().zip(records) {
prover.update(
&mut challenger,
commit,
&record.public_values::<SC::Val>()[0..prover.machine().num_pv_elts()],
);
let _span = span.enter();
tracing::debug_span!("phase 1 commiter").in_scope(|| {
for records in records_rx.iter() {
let commitments = tracing::debug_span!("batch").in_scope(|| {
let span = tracing::Span::current().clone();
records
.par_iter()
.map(|record| {
let _span = span.enter();
prover.commit(record)
})
.collect::<Vec<_>>()
});
for (commit, record) in commitments.into_iter().zip(records) {
prover.update(
&mut challenger,
commit,
&record.public_values::<SC::Val>()[0..prover.machine().num_pv_elts()],
);
}
}
}
});

challenger
});

for (checkpoint_idx, checkpoint_file) in checkpoints.iter_mut().enumerate() {
// Trace the checkpoint and reconstruct the execution records.
let (mut records, _) = trace_checkpoint(program.clone(), checkpoint_file, opts);
reset_seek(&mut *checkpoint_file);

// Update the public values & prover state for the shards which contain "cpu events".
for record in records.iter_mut() {
state.shard += 1;
state.execution_shard = record.public_values.execution_shard;
state.start_pc = record.public_values.start_pc;
state.next_pc = record.public_values.next_pc;
record.public_values = state;
}
tracing::debug_span!("phase 1 record generator").in_scope(|| {
for (checkpoint_idx, checkpoint_file) in checkpoints.iter_mut().enumerate() {
// Trace the checkpoint and reconstruct the execution records.
let (mut records, _) = tracing::debug_span!("trace checkpoint")
.in_scope(|| trace_checkpoint(program.clone(), checkpoint_file, opts));
reset_seek(&mut *checkpoint_file);

// Update the public values & prover state for the shards which contain "cpu events".
for record in records.iter_mut() {
state.shard += 1;
state.execution_shard = record.public_values.execution_shard;
state.start_pc = record.public_values.start_pc;
state.next_pc = record.public_values.next_pc;
record.public_values = state;
}

// Generate the dependencies.
tracing::debug_span!("Generate dependencies", checkpoint_idx = checkpoint_idx)
.in_scope(|| prover.machine().generate_dependencies(&mut records, &opts));
// Generate the dependencies.
tracing::debug_span!("generate dependencies")
.in_scope(|| prover.machine().generate_dependencies(&mut records, &opts));

// Defer events that are too expensive to include in every shard.
for record in records.iter_mut() {
deferred.append(&mut record.defer());
}
// Defer events that are too expensive to include in every shard.
for record in records.iter_mut() {
deferred.append(&mut record.defer());
}

// See if any deferred shards are ready to be committed to.
let is_last_checkpoint = checkpoint_idx == nb_checkpoints - 1;
let mut deferred = deferred.split(is_last_checkpoint, opts.split_opts);
// See if any deferred shards are ready to be committed to.
let is_last_checkpoint = checkpoint_idx == nb_checkpoints - 1;
let mut deferred = deferred.split(is_last_checkpoint, opts.split_opts);

// Update the public values & prover state for the shards which do not contain "cpu events"
// before committing to them.
if !is_last_checkpoint {
state.execution_shard += 1;
}
for record in deferred.iter_mut() {
state.shard += 1;
state.previous_init_addr_bits = record.public_values.previous_init_addr_bits;
state.last_init_addr_bits = record.public_values.last_init_addr_bits;
state.previous_finalize_addr_bits =
record.public_values.previous_finalize_addr_bits;
state.last_finalize_addr_bits = record.public_values.last_finalize_addr_bits;
state.start_pc = state.next_pc;
record.public_values = state;
}
records.append(&mut deferred);
// Update the public values & prover state for the shards which do not contain "cpu events"
// before committing to them.
if !is_last_checkpoint {
state.execution_shard += 1;
}
for record in deferred.iter_mut() {
state.shard += 1;
state.previous_init_addr_bits = record.public_values.previous_init_addr_bits;
state.last_init_addr_bits = record.public_values.last_init_addr_bits;
state.previous_finalize_addr_bits =
record.public_values.previous_finalize_addr_bits;
state.last_finalize_addr_bits = record.public_values.last_finalize_addr_bits;
state.start_pc = state.next_pc;
record.public_values = state;
}
records.append(&mut deferred);

#[cfg(debug_assertions)]
{
debug_records.extend(records.clone());
}
#[cfg(debug_assertions)]
{
debug_records.extend(records.clone());
}

records_tx.send(records).unwrap();
}
records_tx.send(records).unwrap();
}
});
drop(records_tx);
let challenger = challenger_handle.join().unwrap();
commit_span.exit();

// Debug the constraints if debug assertions are enabled.
#[cfg(debug_assertions)]
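A note on the span handling above: `tracing` spans are entered per thread, so work on a spawned thread or a rayon worker is not automatically attributed to the span that was current at the call site. That is why the code clones `tracing::Span::current()` and re-enters it inside the scoped thread and inside each parallel closure. A minimal sketch of the idiom, with placeholder work rather than the prover types:

use std::thread;

fn spawn_with_current_span() {
    // Capture the span that is current on this thread...
    let span = tracing::Span::current().clone();
    thread::spawn(move || {
        // ...and re-enter it on the worker, so events emitted here are
        // parented to the caller's span instead of being orphaned.
        let _guard = span.enter();
        tracing::debug!("attributed to the caller's span");
    })
    .join()
    .unwrap();
}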
@@ -279,65 +294,76 @@ where
let (records_tx, records_rx) =
sync_channel::<Vec<ExecutionRecord>>(opts.prove_stream_capacity);

let commit_and_open = tracing::Span::current().clone();
let shard_proofs = s.spawn(move || {
let _span = commit_and_open.enter();
let mut shard_proofs = Vec::new();
for records in records_rx.iter() {
shard_proofs.par_extend(records.into_par_iter().map(|record| {
prover
.commit_and_open(&pk, record, &mut challenger.clone())
.unwrap()
}));
}
tracing::debug_span!("phase 2 prover").in_scope(|| {
for records in records_rx.iter() {
tracing::debug_span!("batch").in_scope(|| {
let span = tracing::Span::current().clone();
shard_proofs.par_extend(records.into_par_iter().map(|record| {
let _span = span.enter();
prover
.commit_and_open(&pk, record, &mut challenger.clone())
.unwrap()
}));
});
}
});
shard_proofs
});

// let mut shard_proofs = Vec::new();
for (checkpoint_idx, mut checkpoint_file) in checkpoints.into_iter().enumerate() {
// Trace the checkpoint and reconstruct the execution records.
let (mut records, report) = trace_checkpoint(program.clone(), &checkpoint_file, opts);
report_aggregate += report;
reset_seek(&mut checkpoint_file);

// Update the public values & prover state for the shards which contain "cpu events".
for record in records.iter_mut() {
state.shard += 1;
state.execution_shard = record.public_values.execution_shard;
state.start_pc = record.public_values.start_pc;
state.next_pc = record.public_values.next_pc;
record.public_values = state;
}
tracing::debug_span!("phase 2 record generator").in_scope(|| {
for (checkpoint_idx, mut checkpoint_file) in checkpoints.into_iter().enumerate() {
// Trace the checkpoint and reconstruct the execution records.
let (mut records, report) = tracing::debug_span!("trace checkpoint")
.in_scope(|| trace_checkpoint(program.clone(), &checkpoint_file, opts));
report_aggregate += report;
reset_seek(&mut checkpoint_file);

// Update the public values & prover state for the shards which contain "cpu events".
for record in records.iter_mut() {
state.shard += 1;
state.execution_shard = record.public_values.execution_shard;
state.start_pc = record.public_values.start_pc;
state.next_pc = record.public_values.next_pc;
record.public_values = state;
}

// Generate the dependencies.
prover.machine().generate_dependencies(&mut records, &opts);
// Generate the dependencies.
tracing::debug_span!("generate dependencies")
.in_scope(|| prover.machine().generate_dependencies(&mut records, &opts));

// Defer events that are too expensive to include in every shard.
for record in records.iter_mut() {
deferred.append(&mut record.defer());
}
// Defer events that are too expensive to include in every shard.
for record in records.iter_mut() {
deferred.append(&mut record.defer());
}

// See if any deferred shards are ready to be committed to.
let is_last_checkpoint = checkpoint_idx == nb_checkpoints - 1;
let mut deferred = deferred.split(is_last_checkpoint, opts.split_opts);
// See if any deferred shards are ready to be committed to.
let is_last_checkpoint = checkpoint_idx == nb_checkpoints - 1;
let mut deferred = deferred.split(is_last_checkpoint, opts.split_opts);

// Update the public values & prover state for the shards which do not contain "cpu events"
// before committing to them.
if !is_last_checkpoint {
state.execution_shard += 1;
}
for record in deferred.iter_mut() {
state.shard += 1;
state.previous_init_addr_bits = record.public_values.previous_init_addr_bits;
state.last_init_addr_bits = record.public_values.last_init_addr_bits;
state.previous_finalize_addr_bits =
record.public_values.previous_finalize_addr_bits;
state.last_finalize_addr_bits = record.public_values.last_finalize_addr_bits;
state.start_pc = state.next_pc;
record.public_values = state;
}
records.append(&mut deferred);
// Update the public values & prover state for the shards which do not contain "cpu events"
// before committing to them.
if !is_last_checkpoint {
state.execution_shard += 1;
}
for record in deferred.iter_mut() {
state.shard += 1;
state.previous_init_addr_bits = record.public_values.previous_init_addr_bits;
state.last_init_addr_bits = record.public_values.last_init_addr_bits;
state.previous_finalize_addr_bits =
record.public_values.previous_finalize_addr_bits;
state.last_finalize_addr_bits = record.public_values.last_finalize_addr_bits;
state.start_pc = state.next_pc;
record.public_values = state;
}
records.append(&mut deferred);

records_tx.send(records).unwrap();
}
records_tx.send(records).unwrap();
}
});
drop(records_tx);
let shard_proofs = shard_proofs.join().unwrap();
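Both phases share the same pipeline shape: a record-generator loop feeds batches through a bounded `sync_channel` to a consumer thread, which processes each batch in parallel with rayon; the channel capacity (`commit_stream_capacity` / `prove_stream_capacity`) caps how far tracing can run ahead of committing or proving. A simplified sketch of that shape, with `u64`s standing in for execution records:

use rayon::prelude::*;
use std::sync::mpsc::sync_channel;

fn pipeline(batches: Vec<Vec<u64>>) -> Vec<u64> {
    // Bounded channel: the producer can run at most two batches ahead.
    let (tx, rx) = sync_channel::<Vec<u64>>(2);
    std::thread::scope(|s| {
        let consumer = s.spawn(move || {
            let mut out = Vec::new();
            for batch in rx.iter() {
                // Mirrors the `par_iter` / `par_extend` batch processing above.
                out.par_extend(batch.into_par_iter().map(|x| x * 2));
            }
            out
        });
        for batch in batches {
            tx.send(batch).unwrap();
        }
        // Close the channel so the consumer's `rx.iter()` terminates.
        drop(tx);
        consumer.join().unwrap()
    })
}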

@@ -381,7 +407,7 @@ pub fn run_test_io<P: MachineProver<BabyBearPoseidon2, RiscvAir<BabyBear>>>(
program: Program,
inputs: SP1Stdin,
) -> Result<SP1PublicValues, crate::stark::MachineVerificationError<BabyBearPoseidon2>> {
let runtime = tracing::info_span!("runtime.run(...)").in_scope(|| {
let runtime = tracing::debug_span!("runtime.run(...)").in_scope(|| {
let mut runtime = Runtime::new(program, SP1CoreOpts::default());
runtime.write_vecs(&inputs.buffer);
runtime.run().unwrap();
Expand All @@ -398,7 +424,7 @@ pub fn run_test<P: MachineProver<BabyBearPoseidon2, RiscvAir<BabyBear>>>(
crate::stark::MachineProof<BabyBearPoseidon2>,
crate::stark::MachineVerificationError<BabyBearPoseidon2>,
> {
let runtime = tracing::info_span!("runtime.run(...)").in_scope(|| {
let runtime = tracing::debug_span!("runtime.run(...)").in_scope(|| {
let mut runtime = Runtime::new(program, SP1CoreOpts::default());
runtime.run().unwrap();
runtime
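Both test helpers also demote the `runtime.run(...)` span from `info_span!` to `debug_span!`, so it only appears when debug-level tracing is enabled.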
@@ -483,8 +509,7 @@ fn trace_checkpoint(
// We already passed the deferred proof verifier when creating checkpoints, so the proofs
// were already verified. Here we use a no-op verifier to avoid printing spurious warnings.
runtime.subproof_verifier = Arc::new(NoOpSubproofVerifier);
let (events, _) =
tracing::debug_span!("runtime.trace").in_scope(|| runtime.execute_record().unwrap());
let (events, _) = runtime.execute_record().unwrap();
(events, runtime.report)
}
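The inner `runtime.trace` span around `execute_record` is removed here, presumably because both call sites above now wrap `trace_checkpoint` in their own "trace checkpoint" span, which would make the nested span redundant.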
