Skip to content

Commit

Permalink
feat: execute() exposes ExecutionReport (#847)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattstam authored May 30, 2024
1 parent 507b67c commit 99effb1
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 38 deletions.
59 changes: 45 additions & 14 deletions core/src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub use utils::*;

use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::fmt::{Display, Formatter, Result as FmtResult};
use std::fs::File;
use std::io::BufWriter;
use std::io::Write;
Expand Down Expand Up @@ -82,23 +83,53 @@ pub struct Runtime {

pub emit_events: bool,

/// Report of instruction calls.
pub report: InstructionReport,
/// Report of the program execution.
pub report: ExecutionReport,

/// Whether we should write to the report.
pub should_report: bool,
}

#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct InstructionReport {
pub instruction_counts: HashMap<Opcode, u32>,
pub syscall_counts: HashMap<SyscallCode, u32>,
pub struct ExecutionReport {
pub instruction_counts: HashMap<Opcode, u64>,
pub syscall_counts: HashMap<SyscallCode, u64>,
}

impl InstructionReport {
pub fn total_instruction_count(&self) -> u32 {
impl ExecutionReport {
pub fn total_instruction_count(&self) -> u64 {
self.instruction_counts.values().sum()
}

pub fn total_syscall_count(&self) -> u64 {
self.syscall_counts.values().sum()
}
}

impl Display for ExecutionReport {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
writeln!(f, "Instruction Counts:")?;
let mut sorted_instructions = self.instruction_counts.iter().collect::<Vec<_>>();

// Sort instructions by opcode name
sorted_instructions.sort_by_key(|&(opcode, _)| opcode.to_string());
for (opcode, count) in sorted_instructions {
writeln!(f, " {}: {}", opcode, count)?;
}
writeln!(f, "Total Instructions: {}", self.total_instruction_count())?;

writeln!(f, "Syscall Counts:")?;
let mut sorted_syscalls = self.syscall_counts.iter().collect::<Vec<_>>();

// Sort syscalls by syscall name
sorted_syscalls.sort_by_key(|&(syscall, _)| format!("{:?}", syscall));
for (syscall, count) in sorted_syscalls {
writeln!(f, " {}: {}", syscall, count)?;
}
writeln!(f, "Total Syscall Count: {}", self.total_syscall_count())?;

Ok(())
}
}

#[derive(Error, Debug)]
Expand Down Expand Up @@ -555,7 +586,7 @@ impl Runtime {
let mut memory_store_value: Option<u32> = None;
self.memory_accesses = MemoryAccessRecord::default();

if self.should_report {
if self.should_report && !self.unconstrained {
self.report
.instruction_counts
.entry(instruction.opcode)
Expand Down Expand Up @@ -777,7 +808,7 @@ impl Runtime {
b = self.rr(Register::X10, MemoryAccessPosition::B);
let syscall = SyscallCode::from_u32(syscall_id);

if self.should_report {
if self.should_report && !self.unconstrained {
self.report
.syscall_counts
.entry(syscall)
Expand Down Expand Up @@ -990,18 +1021,18 @@ impl Runtime {
tracing::info!("starting execution");
}

pub fn run_untraced(&mut self) -> Result<&InstructionReport, ExecutionError> {
pub fn run_untraced(&mut self) -> Result<(), ExecutionError> {
self.emit_events = false;
self.should_report = true;
while !self.execute()? {}
Ok(&self.report)
Ok(())
}

pub fn run(&mut self) -> Result<&InstructionReport, ExecutionError> {
pub fn run(&mut self) -> Result<(), ExecutionError> {
self.emit_events = true;
self.should_report = true;
while !self.execute()? {}
Ok(&self.report)
Ok(())
}

pub fn dry_run(&mut self) {
Expand Down Expand Up @@ -1156,7 +1187,7 @@ pub mod tests {
assert_eq!(runtime.report, {
use super::Opcode::*;
use super::SyscallCode::*;
super::InstructionReport {
super::ExecutionReport {
instruction_counts: [
(LB, 10723),
(DIVU, 6),
Expand Down
9 changes: 8 additions & 1 deletion core/src/runtime/syscall.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;

use strum_macros::EnumIter;
Expand Down Expand Up @@ -28,7 +29,7 @@ use crate::{runtime::ExecutionRecord, runtime::MemoryReadRecord, runtime::Memory
/// - The second byte is 0/1 depending on whether the syscall has a separate table. This is used
/// in the CPU table to determine whether to lookup the syscall using the syscall interaction.
/// - The third byte is the number of additional cycles the syscall uses.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, EnumIter)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, EnumIter, Ord, PartialOrd)]
#[allow(non_camel_case_types)]
pub enum SyscallCode {
/// Halts the program.
Expand Down Expand Up @@ -149,6 +150,12 @@ impl SyscallCode {
}
}

impl fmt::Display for SyscallCode {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:?}", self)
}
}

pub trait Syscall: Send + Sync {
/// Execute the syscall and return the resulting value of register a0. `arg1` and `arg2` are the
/// values in registers X10 and X11, respectively. While not a hard requirement, the convention
Expand Down
2 changes: 1 addition & 1 deletion examples/fibonacci/script/bin/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ fn main() {

// Only execute the program and get a `SP1PublicValues` object.
let client = ProverClient::new();
let mut public_values = client.execute(ELF, stdin).unwrap();
let (mut public_values, _) = client.execute(ELF, stdin).unwrap();

println!("generated proof");

Expand Down
12 changes: 9 additions & 3 deletions prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use rayon::iter::{IntoParallelIterator, ParallelIterator};
use rayon::prelude::*;
use sp1_core::air::{PublicValues, Word};
pub use sp1_core::io::{SP1PublicValues, SP1Stdin};
use sp1_core::runtime::{ExecutionError, Runtime};
use sp1_core::runtime::{ExecutionError, ExecutionReport, Runtime};
use sp1_core::stark::{Challenge, StarkProvingKey};
use sp1_core::stark::{Challenger, MachineVerificationError};
use sp1_core::utils::{SP1CoreOpts, DIGEST_SIZE};
Expand Down Expand Up @@ -213,7 +213,10 @@ impl SP1Prover {

/// Generate a proof of an SP1 program with the specified inputs.
#[instrument(name = "execute", level = "info", skip_all)]
pub fn execute(elf: &[u8], stdin: &SP1Stdin) -> Result<SP1PublicValues, ExecutionError> {
pub fn execute(
elf: &[u8],
stdin: &SP1Stdin,
) -> Result<(SP1PublicValues, ExecutionReport), ExecutionError> {
let program = Program::from(elf);
let opts = SP1CoreOpts::default();
let mut runtime = Runtime::new(program, opts);
Expand All @@ -222,7 +225,10 @@ impl SP1Prover {
runtime.write_proof(proof.clone(), vkey.clone());
}
runtime.run_untraced()?;
Ok(SP1PublicValues::from(&runtime.state.public_values_stream))
Ok((
SP1PublicValues::from(&runtime.state.public_values_stream),
runtime.report,
))
}

/// Generate shard proofs which split up and prove the valid execution of a RISC-V program with
Expand Down
15 changes: 11 additions & 4 deletions sdk/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ use std::{env, fmt::Debug, fs::File, path::Path};
use anyhow::{Ok, Result};
pub use provers::{LocalProver, MockProver, NetworkProver, Prover};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use sp1_core::stark::{MachineVerificationError, ShardProof};
use sp1_core::{
runtime::ExecutionReport,
stark::{MachineVerificationError, ShardProof},
};
pub use sp1_prover::{
CoreSC, HashableKey, InnerSC, OuterSC, PlonkBn254Proof, SP1Prover, SP1ProvingKey,
SP1PublicValues, SP1Stdin, SP1VerifyingKey,
Expand Down Expand Up @@ -151,7 +154,7 @@ impl ProverClient {

/// Executes the given program on the given input (without generating a proof).
///
/// Returns the public values of the program after it has been executed.
/// Returns the public values and execution report of the program after it has been executed.
///
///
/// ### Examples
Expand All @@ -169,9 +172,13 @@ impl ProverClient {
/// stdin.write(&10usize);
///
/// // Execute the program on the inputs.
/// let public_values = client.execute(elf, stdin).unwrap();
/// let (public_values, report) = client.execute(elf, stdin).unwrap();
/// ```
pub fn execute(&self, elf: &[u8], stdin: SP1Stdin) -> Result<SP1PublicValues> {
pub fn execute(
&self,
elf: &[u8],
stdin: SP1Stdin,
) -> Result<(SP1PublicValues, ExecutionReport)> {
Ok(SP1Prover::execute(elf, &stdin)?)
}

Expand Down
4 changes: 2 additions & 2 deletions sdk/src/provers/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl Prover for MockProver {
}

fn prove(&self, pk: &SP1ProvingKey, stdin: SP1Stdin) -> Result<SP1Proof> {
let public_values = SP1Prover::execute(&pk.elf, &stdin)?;
let (public_values, _) = SP1Prover::execute(&pk.elf, &stdin)?;
Ok(SP1ProofWithPublicValues {
proof: vec![],
stdin,
Expand All @@ -55,7 +55,7 @@ impl Prover for MockProver {
}

fn prove_plonk(&self, pk: &SP1ProvingKey, stdin: SP1Stdin) -> Result<SP1PlonkBn254Proof> {
let public_values = SP1Prover::execute(&pk.elf, &stdin)?;
let (public_values, _) = SP1Prover::execute(&pk.elf, &stdin)?;
Ok(SP1PlonkBn254Proof {
proof: PlonkBn254Proof {
public_inputs: [
Expand Down
26 changes: 13 additions & 13 deletions sdk/src/provers/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ use crate::{
use crate::{SP1CompressedProof, SP1PlonkBn254Proof, SP1Proof, SP1ProvingKey, SP1VerifyingKey};
use anyhow::{Context, Result};
use serde::de::DeserializeOwned;
use sp1_core::runtime::{Program, Runtime};
use sp1_core::utils::SP1CoreOpts;
use sp1_prover::utils::block_on;
use sp1_prover::{SP1Prover, SP1Stdin};
use tokio::{runtime, time::sleep};
Expand Down Expand Up @@ -42,18 +40,20 @@ impl NetworkProver {
mode: ProofMode,
) -> Result<P> {
let client = &self.client;
// Execute the runtime before creating the proof request.
let program = Program::from(elf);
let opts = SP1CoreOpts::default();
let mut runtime = Runtime::new(program, opts);
runtime.write_vecs(&stdin.buffer);
for (proof, vkey) in stdin.proofs.iter() {
runtime.write_proof(proof.clone(), vkey.clone());

let skip_simulation = env::var("SKIP_SIMULATION")
.map(|val| val == "true")
.unwrap_or(false);

if !skip_simulation {
let (_, report) = SP1Prover::execute(elf, &stdin)?;
log::info!(
"Simulation complete, cycles: {}",
report.total_instruction_count()
);
} else {
log::info!("Skipping simulation");
}
runtime
.run_untraced()
.context("Failed to execute program")?;
log::info!("Simulation complete, cycles: {}", runtime.state.global_clk);

let proof_id = client.create_proof(elf, &stdin, mode).await?;
log::info!("Created {}", proof_id);
Expand Down

0 comments on commit 99effb1

Please sign in to comment.