Skip to content
24 changes: 22 additions & 2 deletions crates/core/executor/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ pub struct SP1Context<'a> {
/// The maximum number of cpu cycles to use for execution.
pub max_cycles: Option<u64>,

/// The maximum amount of prover gas to use for execution.
pub max_gas: Option<u64>,

/// Deferred proof verification.
pub deferred_proof_verification: bool,

Expand All @@ -48,6 +51,7 @@ pub struct SP1ContextBuilder<'a> {
hook_registry_entries: Vec<(u32, BoxedHook<'a>)>,
subproof_verifier: Option<&'a dyn SubproofVerifier>,
max_cycles: Option<u64>,
max_gas: Option<u64>,
deferred_proof_verification: bool,
calculate_gas: bool,
io_options: IoOptions<'a>,
Expand All @@ -60,6 +64,7 @@ impl Default for SP1ContextBuilder<'_> {
hook_registry_entries: Vec::new(),
subproof_verifier: None,
max_cycles: None,
max_gas: None,
// Always verify deferred proofs by default.
deferred_proof_verification: true,
calculate_gas: true,
Expand Down Expand Up @@ -117,12 +122,14 @@ impl<'a> SP1ContextBuilder<'a> {

let subproof_verifier = take(&mut self.subproof_verifier);
let cycle_limit = take(&mut self.max_cycles);
let gas_limit = take(&mut self.max_gas);
let deferred_proof_verification = take(&mut self.deferred_proof_verification);
let calculate_gas = take(&mut self.calculate_gas);
SP1Context {
hook_registry,
subproof_verifier,
max_cycles: cycle_limit,
max_gas: gas_limit,
deferred_proof_verification,
calculate_gas,
io_options: take(&mut self.io_options),
Expand Down Expand Up @@ -183,6 +190,13 @@ impl<'a> SP1ContextBuilder<'a> {
self
}

/// Set the maximum amount of gas to use for execution.
/// `report.gas` will be less than or equal to `max_gas`.
pub fn max_gas(&mut self, max_gas: u64) -> &mut Self {
self.max_gas = Some(max_gas);
self
}

/// Set the deferred proof verification flag.
pub fn set_deferred_proof_verification(&mut self, value: bool) -> &mut Self {
self.deferred_proof_verification = value;
Expand Down Expand Up @@ -232,11 +246,17 @@ mod tests {

#[test]
fn defaults() {
let SP1Context { hook_registry, subproof_verifier, max_cycles: cycle_limit, .. } =
SP1Context::builder().build();
let SP1Context {
hook_registry,
subproof_verifier,
max_cycles: cycle_limit,
max_gas: gas_limit,
..
} = SP1Context::builder().build();
assert!(hook_registry.is_none());
assert!(subproof_verifier.is_none());
assert!(cycle_limit.is_none());
assert!(gas_limit.is_none());
}

#[test]
Expand Down
45 changes: 41 additions & 4 deletions crates/core/executor/src/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ pub const UNUSED_PC: u32 = 1;
/// The maximum number of instructions in a program.
pub const MAX_PROGRAM_SIZE: usize = 1 << 22;

/// A thread-safe boxed closure that calculates gas consumption from `RecordEstimator`.
pub type GasCalculator = Box<dyn FnMut(&RecordEstimator) -> u64 + Send>;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// Whether to verify deferred proofs during execution.
pub enum DeferredProofVerification {
Expand Down Expand Up @@ -132,6 +135,15 @@ pub struct Executor<'a> {
/// The maximum number of cpu cycles to use for execution.
pub max_cycles: Option<u64>,

/// The maximum amount of gas to use for execution.
pub max_gas: Option<u64>,

/// The total gas consumed so far during execution.
pub gas_used: u64,

/// The gas calculator function used to calculate gas from execution records.
pub gas_calculator: Option<GasCalculator>,

/// The current trace of the execution that is being collected.
pub record: Box<ExecutionRecord>,

Expand Down Expand Up @@ -238,6 +250,10 @@ pub enum ExecutionError {
#[error("exceeded cycle limit of {0}")]
ExceededCycleLimit(u64),

/// The execution failed with an exceeded gas limit.
#[error("exceeded gas limit of {0}")]
ExceededGasLimit(u64),

/// The execution failed because the syscall was called in unconstrained mode.
#[error("syscall called in unconstrained mode")]
InvalidSyscallUsage(u64),
Expand Down Expand Up @@ -323,6 +339,8 @@ impl<'a> Executor<'a> {
let costs: HashMap<RiscvAirId, usize> =
costs.into_iter().map(|(k, v)| (RiscvAirId::from_str(&k).unwrap(), v)).collect();

tracing::info!("max_cycles: {:?}, max_gas: {:?}", context.max_cycles, context.max_gas);

Self {
record: Box::new(record),
records: vec![],
Expand All @@ -349,6 +367,9 @@ impl<'a> Executor<'a> {
hook_registry,
opts,
max_cycles: context.max_cycles,
max_gas: context.max_gas,
gas_used: 0,
gas_calculator: None,
deferred_proof_verification: context.deferred_proof_verification.into(),
memory_checkpoint: Memory::default(),
uninitialized_memory_checkpoint: Memory::default(),
Expand Down Expand Up @@ -1801,7 +1822,7 @@ impl<'a> Executor<'a> {
}

if cpu_exit || !shape_match_found {
self.bump_record();
self.bump_record()?;
self.state.current_shard += 1;
self.state.clk = 0;
}
Expand All @@ -1825,7 +1846,22 @@ impl<'a> Executor<'a> {
}

/// Bump the record.
pub fn bump_record(&mut self) {
pub fn bump_record(&mut self) -> Result<(), ExecutionError> {
if let (Some(ref mut gas_calculator), Some(ref estimator)) =
(&mut self.gas_calculator, &self.record_estimator)
{
let shard_gas = gas_calculator(estimator);
self.gas_used += shard_gas;

tracing::info!("[bump_record] gas_used: {}", self.gas_used);

if let Some(gas_limit) = self.max_gas {
if self.gas_used > gas_limit {
return Err(ExecutionError::ExceededGasLimit(gas_limit));
}
}
}

if let Some(estimator) = &mut self.record_estimator {
self.local_counts.local_mem = std::mem::take(&mut estimator.current_local_mem);
Self::estimate_riscv_event_counts(
Expand All @@ -1852,6 +1888,7 @@ impl<'a> Executor<'a> {
let public_values = removed_record.public_values;
self.record.public_values = public_values;
self.records.push(removed_record);
Ok(())
}

/// Execute up to `self.shard_batch_size` cycles, returning the events emitted and whether the
Expand Down Expand Up @@ -2062,7 +2099,7 @@ impl<'a> Executor<'a> {
self.postprocess();

// Push the remaining execution record with memory initialize & finalize events.
self.bump_record();
self.bump_record()?;

// Flush stdout and stderr.
if let Some(ref mut w) = self.io_options.stdout {
Expand All @@ -2080,7 +2117,7 @@ impl<'a> Executor<'a> {

// Push the remaining execution record, if there are any CPU events.
if !self.record.cpu_events.is_empty() {
self.bump_record();
self.bump_record()?;
}

// Set the global public values for all shards.
Expand Down
2 changes: 1 addition & 1 deletion crates/core/machine/src/shape/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ const MAXIMAL_SHAPES: &[u8] = include_bytes!("maximal_shapes.json");
const SMALL_SHAPES: &[u8] = include_bytes!("small_shapes.json");

/// A configuration for what shapes are allowed to be used by the prover.
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct CoreShapeConfig<F: PrimeField32> {
partial_preprocessed_shapes: ShapeCluster<RiscvAirId>,
partial_core_shapes: BTreeMap<usize, Vec<ShapeCluster<RiscvAirId>>>,
Expand Down
23 changes: 12 additions & 11 deletions crates/core/machine/src/utils/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ pub fn prove_core_stream<SC: StarkGenericConfig, P: MachineProver<SC, RiscvAir<S
proof_tx: Sender<ShardProof<SC>>,
shape_and_done_tx: Sender<(OrderedShape, bool)>,
malicious_trace_pv_generator: Option<MaliciousTracePVGeneratorType<SC::Val, P>>, /* This is used for failure test cases that generate malicious traces and public values. */
gas_calculator: Option<Box<dyn FnOnce(&RecordEstimator) -> Result<u64, Box<dyn Error>> + '_>>,
gas_calculator: Option<Box<dyn FnMut(&RecordEstimator) -> Result<u64, Box<dyn Error>> + Send>>,
) -> Result<(Vec<u8>, u64), SP1CoreProverError>
where
SC::Val: PrimeField32,
Expand All @@ -116,9 +116,15 @@ where
let (proof, vk) = proof.clone();
runtime.write_proof(proof, vk);
}
// Set the record estimator to collect data for gas calculation.
if gas_calculator.is_some() {
// Set up gas calculator for shard-based gas calculation and record estimator.
if let Some(mut calculator) = gas_calculator {
runtime.record_estimator = Some(Box::default());
runtime.gas_calculator = Some(Box::new(move |estimator: &RecordEstimator| -> u64 {
calculator(estimator).unwrap_or_else(|e| {
tracing::error!("Gas calculation failed during proving: {}", e);
0
})
}));
}

#[cfg(feature = "debug")]
Expand Down Expand Up @@ -519,7 +525,6 @@ where

// Wait until the checkpoint generator handle has fully finished.
let runtime = checkpoint_generator_handle.join().unwrap().unwrap();
let gas = gas_calculator.map(|calc| calc(runtime.record_estimator.as_ref().unwrap()));
let public_values_stream = runtime.state.public_values_stream;

// Wait until the records and traces have been fully generated for phase 2.
Expand All @@ -536,13 +541,9 @@ where
report_aggregate.total_syscall_count(),
report_aggregate.touched_memory_addresses,
);
match gas {
Some(Ok(gas)) => {
tracing::debug!("execution report (gas): {}", gas);
report_aggregate.gas = Some(gas);
}
Some(Err(err)) => tracing::error!("Encountered error while calculating gas: {}", err),
None => (),
if runtime.gas_used > 0 {
tracing::debug!("execution report (gas): {}", runtime.gas_used);
report_aggregate.gas = Some(runtime.gas_used);
}

// Print the opcode and syscall count tables like `du`: sorted by count (descending) and
Expand Down
6 changes: 3 additions & 3 deletions crates/prover/src/gas/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,9 @@ pub fn fit_records_to_shapes<'a, F: PrimeField32>(
})
}

struct CoreShard<'a> {
shard_index: u32,
record: &'a EnumMap<RiscvAirId, u64>,
pub struct CoreShard<'a> {
pub shard_index: u32,
pub record: &'a EnumMap<RiscvAirId, u64>,
}

impl Shapeable for CoreShard<'_> {
Expand Down
84 changes: 68 additions & 16 deletions crates/prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,11 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
Ok(program)
}

fn get_gas_calculator(
/// Get gas calculator for post-execution report gas calculation.
///
/// Processes all execution data (core, precompile, memory) after execution completes.
/// Provides optimal gas calculation by fitting all shards together.
fn get_post_execution_gas_calculator(
&self,
preprocessed_shape: Shape<RiscvAirId>,
split_opts: SplitOpts,
Expand All @@ -316,6 +320,41 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
}
}

/// Get gas calculator for executor shard-based calculation.
///
/// Processes only the current core shard during execution for gas limit checks.
/// Ignores precompile/memory data and fits shards individually, which may be less optimal.
fn get_shard_gas_calculator(
&self,
preprocessed_shape: Shape<RiscvAirId>,
) -> Box<dyn FnMut(&RecordEstimator) -> u64 + Send> {
let core_shape_config = self.core_shape_config.clone().unwrap();
Box::new(move |estimator: &RecordEstimator| -> u64 {
if let Some(last_shard) = estimator.core_records.last() {
let core_shard = gas::CoreShard {
shard_index: (estimator.core_records.len() - 1) as u32,
record: last_shard,
};
let raw_gas = match core_shape_config.find_shape(&core_shard) {
Ok(mut shape) => {
shape.extend(preprocessed_shape.iter().map(|(k, v)| (*k, *v)));
gas::predict(enum_map::EnumMap::from_iter(shape).as_array())
}
Err(e) => {
tracing::error!("Shape fitting failed for current shard: {}", e);
0.0
}
};
gas::final_transform(raw_gas).unwrap_or_else(|e| {
tracing::error!("Gas calculation failed: {}", e);
0
})
} else {
0
}
})
}

/// Execute an SP1 program with the specified inputs.
#[instrument(name = "execute", level = "info", skip_all)]
pub fn execute<'a>(
Expand All @@ -326,7 +365,7 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
) -> Result<(SP1PublicValues, [u8; 32], ExecutionReport), ExecutionError> {
context.subproof_verifier = Some(self);

let calculate_gas = context.calculate_gas;
let calculate_gas = context.calculate_gas || context.max_gas.is_some();

let (opts, program) = if calculate_gas {
(gas::GAS_OPTS, self.get_program(elf).unwrap())
Expand All @@ -343,6 +382,10 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
config.maximal_core_shapes(opts.shard_size.ilog2() as usize).into_iter().collect()
});
runtime.record_estimator = Some(Box::default());

// Set up gas calculator for shard-based gas calculation.
runtime.gas_calculator =
Some(self.get_shard_gas_calculator(preprocessed_shape.clone().unwrap()));
}

runtime.maybe_setup_profiler(elf);
Expand All @@ -354,13 +397,21 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
runtime.run_fast()?;

if calculate_gas {
let gas = self.get_gas_calculator(preprocessed_shape.unwrap(), opts.split_opts)(
runtime.record_estimator.as_ref().unwrap(),
);
runtime.report.gas = gas
.inspect(|g| tracing::info!("gas: {}", g))
.inspect_err(|e| tracing::error!("Encountered error while calculating gas: {}", e))
.ok();
if runtime.gas_used > 0 {
runtime.report.gas = Some(runtime.gas_used);
tracing::info!("gas: {}", runtime.gas_used);
} else {
let gas = self.get_post_execution_gas_calculator(
preprocessed_shape.unwrap(),
opts.split_opts,
)(runtime.record_estimator.as_ref().unwrap());
runtime.report.gas = gas
.inspect(|g| tracing::info!("gas: {}", g))
.inspect_err(|e| {
tracing::error!("Encountered error while calculating gas: {}", e)
})
.ok();
}
}

let mut committed_value_digest = [0u8; 32];
Expand Down Expand Up @@ -413,13 +464,13 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
// This ensures that the gas number is consistent between `execute` and `prove_core`.
// This behavior is undocumented because it is confusing and not very useful.
//
// If `context.calculate_gas` is set, we use the logic from the `gas` module
// after checkpoint execution to print gas as part of the execution report.
// If `context.calculate_gas` or `context.max_gas` is set, we use the logic from the
// `gas` module after checkpoint execution to calculate gas.
#[allow(clippy::type_complexity)]
let gas_calculator = (context.calculate_gas
let gas_calculator = ((context.calculate_gas || context.max_gas.is_some())
&& std::env::var("SP1_FORCE_GAS").is_ok())
.then(
|| -> Box<dyn FnOnce(&RecordEstimator) -> Result<u64, Box<dyn Error>> + '_> {
|| -> Box<dyn FnMut(&RecordEstimator) -> Result<u64, Box<dyn Error>> + Send> {
tracing::info!("Forcing calculation of gas while proving.");
if opts.core_opts == gas::GAS_OPTS {
tracing::info!(
Expand All @@ -432,9 +483,10 @@ impl<C: SP1ProverComponents> SP1Prover<C> {
);
}
let preprocessed_shape = program.preprocessed_shape.clone().unwrap();
Box::new(
self.get_gas_calculator(preprocessed_shape, opts.core_opts.split_opts),
)
let mut calculator = self.get_shard_gas_calculator(preprocessed_shape);
Box::new(move |estimator: &RecordEstimator| -> Result<u64, Box<dyn Error>> {
Ok(calculator(estimator))
})
},
);

Expand Down
Loading
Loading