Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: result instead of exit(1) on trap in recursion #1089

Merged
merged 4 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 32 additions & 23 deletions prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,15 +443,10 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
let proofs = inputs
.into_par_iter()
.map(|input| {
let proof = self.compress_machine_proof(
input,
&self.recursion_program,
&self.rec_pk,
opts,
);
(proof, ReduceProgramType::Core)
self.compress_machine_proof(input, &self.recursion_program, &self.rec_pk, opts)
.map(|p| (p, ReduceProgramType::Core))
})
.collect::<Vec<_>>();
.collect::<Result<Vec<_>, _>>()?;
reduce_proofs.extend(proofs);
}

Expand All @@ -460,15 +455,15 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
let proofs = inputs
.into_par_iter()
.map(|input| {
let proof = self.compress_machine_proof(
self.compress_machine_proof(
input,
&self.deferred_program,
&self.deferred_pk,
opts,
);
(proof, ReduceProgramType::Deferred)
)
.map(|p| (p, ReduceProgramType::Deferred))
})
.collect::<Vec<_>>();
.collect::<Result<Vec<_>, _>>()?;
reduce_proofs.extend(proofs);
}

Expand All @@ -482,7 +477,7 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
let batched_compress_inputs =
compress_inputs.chunks(shard_batch_size).collect::<Vec<_>>();
reduce_proofs = batched_compress_inputs
.into_iter()
.into_par_iter()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems kind of free? Or is there a reason we chose to not do this concurrently? just lmk and I'll remove

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah no, good catch. That looks good :)

.flat_map(|batches| {
batches
.par_iter()
Expand All @@ -498,17 +493,17 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
is_complete,
};

let proof = self.compress_machine_proof(
self.compress_machine_proof(
input,
&self.compress_program,
&self.compress_pk,
opts,
);
(proof, ReduceProgramType::Reduce)
)
.map(|p| (p, ReduceProgramType::Reduce))
})
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
.collect::<Result<Vec<_>, _>>()?;

if reduce_proofs.len() == 1 {
break;
Expand All @@ -528,7 +523,7 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
program: &RecursionProgram<BabyBear>,
pk: &StarkProvingKey<InnerSC>,
opts: SP1ProverOpts,
) -> ShardProof<InnerSC> {
) -> Result<ShardProof<InnerSC>, SP1RecursionProverError> {
let mut runtime = RecursionRuntime::<Val<InnerSC>, Challenge<InnerSC>, _>::new(
program,
self.compress_prover.config().perm.clone(),
Expand All @@ -538,11 +533,15 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
witness_stream.extend(input.write());

runtime.witness_stream = witness_stream.into();
runtime.run();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

more canonical way to do this would be to do runtime.run().map_err(..)?;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated


runtime
.run()
.map_err(|e| SP1RecursionProverError::RuntimeError(e.to_string()))?;
runtime.print_stats();

let mut recursive_challenger = self.compress_prover.config().challenger();
self.compress_prover
let proof = self
.compress_prover
.prove(
pk,
vec![runtime.record],
Expand All @@ -552,7 +551,9 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
.unwrap()
.shard_proofs
.pop()
.unwrap()
.unwrap();

Ok(proof)
}

/// Wrap a reduce proof into a STARK proven over a SNARK-friendly field.
Expand All @@ -579,7 +580,11 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
witness_stream.extend(input.write());

runtime.witness_stream = witness_stream.into();
runtime.run();

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

runtime
.run()
.map_err(|e| SP1RecursionProverError::RuntimeError(e.to_string()))?;

runtime.print_stats();
tracing::debug!("Compress program executed successfully");

Expand Down Expand Up @@ -623,7 +628,11 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
witness_stream.extend(input.write());

runtime.witness_stream = witness_stream.into();
runtime.run();

runtime
.run()
.map_err(|e| SP1RecursionProverError::RuntimeError(e.to_string()))?;

runtime.print_stats();
tracing::debug!("Wrap program executed successfully");

Expand Down
5 changes: 4 additions & 1 deletion prover/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,4 +198,7 @@ pub enum SP1ReduceProofWrapper {
}

#[derive(Error, Debug)]
pub enum SP1RecursionProverError {}
pub enum SP1RecursionProverError {
#[error("Runtime error: {0}")]
RuntimeError(String),
}
2 changes: 1 addition & 1 deletion recursion/circuit/src/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ mod tests {
let program = basic_program::<F>();
let config = SC::new();
let mut runtime = Runtime::<F, EF, DiffusionMatrixBabyBear>::new_no_perm(&program);
runtime.run();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unwrap

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

runtime.run().unwrap();
let machine = A::machine(config);
let prover = DefaultProver::new(machine);
let (pk, vk) = prover.setup(&program);
Expand Down
2 changes: 1 addition & 1 deletion recursion/compiler/examples/fibonacci.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ fn main() {

let config = SC::new();
let mut runtime = Runtime::<F, EF, _>::new(&program, config.perm.clone());
runtime.run();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unwrap

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

runtime.run().unwrap();

// let machine = RecursionAir::machine(config);
// let (pk, vk) = machine.setup(&program);
Expand Down
4 changes: 2 additions & 2 deletions recursion/compiler/src/ir/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ mod tests {
let program = builder.compile_program();

let mut runtime = Runtime::<F, EF, _>::new(&program, config.perm.clone());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unrwap

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

runtime.run();
runtime.run().unwrap();
}

#[test]
Expand Down Expand Up @@ -363,6 +363,6 @@ mod tests {
let program = builder.compile_program();

let mut runtime = Runtime::<F, EF, _>::new(&program, config.perm.clone());
runtime.run();
runtime.run().unwrap();
}
}
2 changes: 1 addition & 1 deletion recursion/compiler/tests/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,6 @@ fn test_compiler_arithmetic() {

let config = SC::default();
let mut runtime = Runtime::<F, EF, _>::new(&program, config.perm.clone());
runtime.run();
runtime.run().unwrap();
runtime.print_stats();
}
2 changes: 1 addition & 1 deletion recursion/compiler/tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,5 +121,5 @@ fn test_compiler_array() {

let config = SC::default();
let mut runtime = Runtime::<F, EF, _>::new(&program, config.perm.clone());
runtime.run();
runtime.run().unwrap();
}
4 changes: 2 additions & 2 deletions recursion/compiler/tests/conditionals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ fn test_compiler_conditionals() {

let config = SC::default();
let mut runtime = Runtime::<F, EF, _>::new(&program, config.perm.clone());
runtime.run();
runtime.run().unwrap();
}

#[test]
Expand Down Expand Up @@ -93,5 +93,5 @@ fn test_compiler_conditionals_v2() {

let config = SC::default();
let mut runtime = Runtime::<F, EF, _>::new(&program, config.perm.clone());
runtime.run();
runtime.run().unwrap();
}
10 changes: 5 additions & 5 deletions recursion/compiler/tests/for_loops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ fn test_compiler_for_loops() {

let config = SC::default();
let mut runtime = Runtime::<F, EF, _>::new(&program, config.perm.clone());
runtime.run();
runtime.run().unwrap();
}

#[test]
Expand Down Expand Up @@ -86,7 +86,7 @@ fn test_compiler_nested_array_loop() {

let config = SC::default();
let mut runtime = Runtime::<F, EF, _>::new(&program, config.perm.clone());
runtime.run();
runtime.run().unwrap();
}

#[test]
Expand Down Expand Up @@ -170,7 +170,7 @@ fn test_compiler_break() {

let config = SC::default();
let mut runtime = Runtime::<F, EF, _>::new(&program, config.perm.clone());
runtime.run();
runtime.run().unwrap();
}

#[test]
Expand All @@ -197,7 +197,7 @@ fn test_compiler_step_by() {

let config = SC::default();
let mut runtime = Runtime::<F, EF, _>::new(&program, config.perm.clone());
runtime.run();
runtime.run().unwrap();
}

#[test]
Expand Down Expand Up @@ -225,5 +225,5 @@ fn test_compiler_bneinc() {

let config = SC::default();
let mut runtime = Runtime::<F, EF, _>::new(&program, config.perm.clone());
runtime.run();
runtime.run().unwrap();
}
2 changes: 1 addition & 1 deletion recursion/compiler/tests/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,5 @@ fn test_io() {
vec![F::one().into(), F::one().into(), F::two().into()],
]
.into();
runtime.run();
runtime.run().unwrap();
}
2 changes: 1 addition & 1 deletion recursion/compiler/tests/lt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ fn test_compiler_less_than() {

let config = SC::default();
let mut runtime = Runtime::<F, EF, _>::new(&program, config.perm.clone());
runtime.run();
runtime.run().unwrap();
}
6 changes: 3 additions & 3 deletions recursion/compiler/tests/poseidon2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ fn test_compiler_poseidon2_permute() {
let program = builder.compile_program();

let mut runtime = Runtime::<F, EF, _>::new(&program, config.perm.clone());
runtime.run();
runtime.run().unwrap();
println!(
"The program executed successfully, number of cycles: {}",
runtime.clk.as_canonical_u32() / 4
Expand Down Expand Up @@ -115,7 +115,7 @@ fn test_compiler_poseidon2_hash() {
let program = builder.compile_program();

let mut runtime = Runtime::<F, EF, _>::new(&program, config.perm.clone());
runtime.run();
runtime.run().unwrap();
println!(
"The program executed successfully, number of cycles: {}",
runtime.clk.as_canonical_u32() / 4
Expand Down Expand Up @@ -151,7 +151,7 @@ fn test_compiler_poseidon2_hash_v2() {
let program = builder.compile_program();

let mut runtime = Runtime::<F, EF, _>::new(&program, config.perm.clone());
runtime.run();
runtime.run().unwrap();
println!(
"The program executed successfully, number of cycles: {}",
runtime.clk.as_canonical_u32() / 4
Expand Down
2 changes: 1 addition & 1 deletion recursion/compiler/tests/public_values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,5 @@ fn test_compiler_public_values() {

let config = SC::default();
let mut runtime = Runtime::<F, EF, _>::new(&program, config.perm.clone());
runtime.run();
runtime.run().unwrap();
}
Loading
Loading