Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: optimize ExecutionReport and create_alu_lookup_id iff mode is trace #1480

Merged
merged 3 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/core/executor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ hex = "0.4.3"
bytemuck = "1.16.3"
tiny-keccak = { version = "2.0.2", features = ["keccak"] }
vec_map = { version = "0.8.2", features = ["serde"] }
enum-map = { version = "2.7.3", features = ["serde"] }

[dev-dependencies]
sp1-zkvm = { workspace = true }
Expand Down
20 changes: 8 additions & 12 deletions crates/core/executor/src/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -696,23 +696,19 @@ impl<'a> Executor<'a> {
if self.executor_mode == ExecutorMode::Trace {
self.memory_accesses = MemoryAccessRecord::default();
}
let lookup_id = if self.executor_mode == ExecutorMode::Simple {
LookupId::default()
} else {
let lookup_id = if self.executor_mode == ExecutorMode::Trace {
create_alu_lookup_id()
};
let syscall_lookup_id = if self.executor_mode == ExecutorMode::Simple {
LookupId::default()
} else {
LookupId::default()
};
let syscall_lookup_id = if self.executor_mode == ExecutorMode::Trace {
create_alu_lookup_id()
} else {
LookupId::default()
};

if self.print_report && !self.unconstrained {
self.report
.opcode_counts
.entry(instruction.opcode)
.and_modify(|c| *c += 1)
.or_insert(1);
self.report.opcode_counts[instruction.opcode] += 1;
}

match instruction.opcode {
Expand Down Expand Up @@ -930,7 +926,7 @@ impl<'a> Executor<'a> {
let syscall = SyscallCode::from_u32(syscall_id);

if self.print_report && !self.unconstrained {
self.report.syscall_counts.entry(syscall).and_modify(|c| *c += 1).or_insert(1);
self.report.syscall_counts[syscall] += 1;
}

// `hint_slice` is allowed in unconstrained mode since it is used to write the hint.
Expand Down
5 changes: 4 additions & 1 deletion crates/core/executor/src/opcode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use std::fmt::Display;

use enum_map::Enum;
use p3_field::Field;
use serde::{Deserialize, Serialize};

Expand All @@ -20,7 +21,9 @@ use serde::{Deserialize, Serialize};
/// Refer to the "RV32I Reference Card" [here](https://github.com/johnwinans/rvalp/releases) for
/// more details.
#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
#[derive(
Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord, Enum,
)]
pub enum Opcode {
/// rd ← rs1 + rs2, pc ← pc + 4
ADD = 0,
Expand Down
27 changes: 12 additions & 15 deletions crates/core/executor/src/report.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
use std::{
collections::{hash_map::Entry, HashMap},
fmt::{Display, Formatter, Result as FmtResult},
hash::Hash,
ops::{Add, AddAssign},
};

use enum_map::{EnumArray, EnumMap};
use hashbrown::HashMap;

use crate::{events::sorted_table_lines, syscalls::SyscallCode, Opcode};

/// An execution report.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct ExecutionReport {
/// The opcode counts.
pub opcode_counts: HashMap<Opcode, u64>,
pub opcode_counts: Box<EnumMap<Opcode, u64>>,
/// The syscall counts.
pub syscall_counts: HashMap<SyscallCode, u64>,
pub syscall_counts: Box<EnumMap<SyscallCode, u64>>,
/// The cycle tracker counts.
pub cycle_tracker: HashMap<String, u64>,
/// The unique memory address counts.
Expand All @@ -35,24 +36,20 @@ impl ExecutionReport {
}

/// Combines two count maps together. If a key is in both maps, the values are added together.
fn hashmap_add_assign<K, V>(lhs: &mut HashMap<K, V>, rhs: HashMap<K, V>)
fn counts_add_assign<K, V>(lhs: &mut EnumMap<K, V>, rhs: EnumMap<K, V>)
where
K: Eq + Hash,
K: EnumArray<V>,
V: AddAssign,
{
for (k, v) in rhs {
// Can't use `.and_modify(...).or_insert(...)` because we want to use `v` in both places.
match lhs.entry(k) {
Entry::Occupied(e) => *e.into_mut() += v,
Entry::Vacant(e) => drop(e.insert(v)),
}
lhs[k] += v;
}
}

impl AddAssign for ExecutionReport {
fn add_assign(&mut self, rhs: Self) {
hashmap_add_assign(&mut self.opcode_counts, rhs.opcode_counts);
hashmap_add_assign(&mut self.syscall_counts, rhs.syscall_counts);
counts_add_assign(&mut self.opcode_counts, *rhs.opcode_counts);
counts_add_assign(&mut self.syscall_counts, *rhs.syscall_counts);
self.touched_memory_addresses += rhs.touched_memory_addresses;
}
}
Expand All @@ -69,12 +66,12 @@ impl Add for ExecutionReport {
impl Display for ExecutionReport {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
writeln!(f, "opcode counts ({} total instructions):", self.total_instruction_count())?;
for line in sorted_table_lines(&self.opcode_counts) {
for line in sorted_table_lines(self.opcode_counts.as_ref()) {
writeln!(f, " {line}")?;
}

writeln!(f, "syscall counts ({} total syscall instructions):", self.total_syscall_count())?;
for line in sorted_table_lines(&self.syscall_counts) {
for line in sorted_table_lines(self.syscall_counts.as_ref()) {
writeln!(f, " {line}")?;
}

Expand Down
3 changes: 2 additions & 1 deletion crates/core/executor/src/syscalls/code.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use enum_map::Enum;
use serde::{Deserialize, Serialize};
use strum_macros::EnumIter;

Expand All @@ -18,7 +19,7 @@ use strum_macros::EnumIter;
/// memory accesses is bounded.
/// - Byte 3: Currently unused.
#[derive(
Debug, Copy, Clone, PartialEq, Eq, Hash, EnumIter, Ord, PartialOrd, Serialize, Deserialize,
Debug, Copy, Clone, PartialEq, Eq, Hash, EnumIter, Ord, PartialOrd, Serialize, Deserialize, Enum,
)]
#[allow(non_camel_case_types)]
#[allow(clippy::upper_case_acronyms)]
Expand Down
100 changes: 41 additions & 59 deletions crates/core/machine/src/riscv/cost.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,151 +33,133 @@ impl CostEstimator for ExecutionReport {
total_area += (cpu_events as u64) * costs[&RiscvAirDiscriminants::Cpu];
total_chips += 1;

let sha_extend_events = *self.syscall_counts.get(&SyscallCode::SHA_EXTEND).unwrap_or(&0);
let sha_extend_events = self.syscall_counts[SyscallCode::SHA_EXTEND];
total_area += (sha_extend_events as u64) * costs[&RiscvAirDiscriminants::Sha256Extend];
total_chips += 1;

let sha_compress_events =
*self.syscall_counts.get(&SyscallCode::SHA_COMPRESS).unwrap_or(&0);
let sha_compress_events = self.syscall_counts[SyscallCode::SHA_COMPRESS];
total_area += (sha_compress_events as u64) * costs[&RiscvAirDiscriminants::Sha256Compress];
total_chips += 1;

let ed_add_events = *self.syscall_counts.get(&SyscallCode::ED_ADD).unwrap_or(&0);
let ed_add_events = self.syscall_counts[SyscallCode::ED_ADD];
total_area += (ed_add_events as u64) * costs[&RiscvAirDiscriminants::Ed25519Add];
total_chips += 1;

let ed_decompress_events =
*self.syscall_counts.get(&SyscallCode::ED_DECOMPRESS).unwrap_or(&0);
let ed_decompress_events = self.syscall_counts[SyscallCode::ED_DECOMPRESS];
total_area +=
(ed_decompress_events as u64) * costs[&RiscvAirDiscriminants::Ed25519Decompress];
total_chips += 1;

let k256_decompress_events =
*self.syscall_counts.get(&SyscallCode::SECP256K1_DECOMPRESS).unwrap_or(&0);
let k256_decompress_events = self.syscall_counts[SyscallCode::SECP256K1_DECOMPRESS];
total_area +=
(k256_decompress_events as u64) * costs[&RiscvAirDiscriminants::K256Decompress];
total_chips += 1;

let secp256k1_add_events =
*self.syscall_counts.get(&SyscallCode::SECP256K1_ADD).unwrap_or(&0);
let secp256k1_add_events = self.syscall_counts[SyscallCode::SECP256K1_ADD];
total_area += (secp256k1_add_events as u64) * costs[&RiscvAirDiscriminants::Secp256k1Add];
total_chips += 1;

let secp256k1_double_events =
*self.syscall_counts.get(&SyscallCode::SECP256K1_DOUBLE).unwrap_or(&0);
let secp256k1_double_events = self.syscall_counts[SyscallCode::SECP256K1_DOUBLE];
total_area +=
(secp256k1_double_events as u64) * costs[&RiscvAirDiscriminants::Secp256k1Double];
total_chips += 1;

let keccak256_permute_events =
*self.syscall_counts.get(&SyscallCode::KECCAK_PERMUTE).unwrap_or(&0);
let keccak256_permute_events = self.syscall_counts[SyscallCode::KECCAK_PERMUTE];
total_area += (keccak256_permute_events as u64) * costs[&RiscvAirDiscriminants::KeccakP];
total_chips += 1;

let bn254_add_events = *self.syscall_counts.get(&SyscallCode::BN254_ADD).unwrap_or(&0);
let bn254_add_events = self.syscall_counts[SyscallCode::BN254_ADD];
total_area += (bn254_add_events as u64) * costs[&RiscvAirDiscriminants::Bn254Add];
total_chips += 1;

let bn254_double_events =
*self.syscall_counts.get(&SyscallCode::BN254_DOUBLE).unwrap_or(&0);
let bn254_double_events = self.syscall_counts[SyscallCode::BN254_DOUBLE];
total_area += (bn254_double_events as u64) * costs[&RiscvAirDiscriminants::Bn254Double];
total_chips += 1;

let bls12381_add_events =
*self.syscall_counts.get(&SyscallCode::BLS12381_ADD).unwrap_or(&0);
let bls12381_add_events = self.syscall_counts[SyscallCode::BLS12381_ADD];
total_area += (bls12381_add_events as u64) * costs[&RiscvAirDiscriminants::Bls12381Add];
total_chips += 1;

let bls12381_double_events =
*self.syscall_counts.get(&SyscallCode::BLS12381_DOUBLE).unwrap_or(&0);
let bls12381_double_events = self.syscall_counts[SyscallCode::BLS12381_DOUBLE];
total_area +=
(bls12381_double_events as u64) * costs[&RiscvAirDiscriminants::Bls12381Double];
total_chips += 1;

let uint256_mul_events = *self.syscall_counts.get(&SyscallCode::UINT256_MUL).unwrap_or(&0);
let uint256_mul_events = self.syscall_counts[SyscallCode::UINT256_MUL];
total_area += (uint256_mul_events as u64) * costs[&RiscvAirDiscriminants::Uint256Mul];
total_chips += 1;

let bls12381_fp_events =
*self.syscall_counts.get(&SyscallCode::BLS12381_FP_ADD).unwrap_or(&0)
+ *self.syscall_counts.get(&SyscallCode::BLS12381_FP_SUB).unwrap_or(&0)
+ *self.syscall_counts.get(&SyscallCode::BLS12381_FP_MUL).unwrap_or(&0);
let bls12381_fp_events = self.syscall_counts[SyscallCode::BLS12381_FP_ADD]
+ self.syscall_counts[SyscallCode::BLS12381_FP_SUB]
+ self.syscall_counts[SyscallCode::BLS12381_FP_MUL];
total_area += (bls12381_fp_events as u64) * costs[&RiscvAirDiscriminants::Bls12381Fp];
total_chips += 1;

let bls12381_fp2_addsub_events =
*self.syscall_counts.get(&SyscallCode::BLS12381_FP2_ADD).unwrap_or(&0)
+ *self.syscall_counts.get(&SyscallCode::BLS12381_FP2_SUB).unwrap_or(&0);
let bls12381_fp2_addsub_events = self.syscall_counts[SyscallCode::BLS12381_FP2_ADD]
+ self.syscall_counts[SyscallCode::BLS12381_FP2_SUB];
total_area +=
(bls12381_fp2_addsub_events as u64) * costs[&RiscvAirDiscriminants::Bls12381Fp2AddSub];
total_chips += 1;

let bls12381_fp2_mul_events =
*self.syscall_counts.get(&SyscallCode::BLS12381_FP2_MUL).unwrap_or(&0);
let bls12381_fp2_mul_events = self.syscall_counts[SyscallCode::BLS12381_FP2_MUL];
total_area +=
(bls12381_fp2_mul_events as u64) * costs[&RiscvAirDiscriminants::Bls12381Fp2Mul];
total_chips += 1;

let bn254_fp_events = *self.syscall_counts.get(&SyscallCode::BN254_FP_ADD).unwrap_or(&0)
+ *self.syscall_counts.get(&SyscallCode::BN254_FP_SUB).unwrap_or(&0)
+ *self.syscall_counts.get(&SyscallCode::BN254_FP_MUL).unwrap_or(&0);
let bn254_fp_events = self.syscall_counts[SyscallCode::BN254_FP_ADD]
+ self.syscall_counts[SyscallCode::BN254_FP_SUB]
+ self.syscall_counts[SyscallCode::BN254_FP_MUL];
total_area += (bn254_fp_events as u64) * costs[&RiscvAirDiscriminants::Bn254Fp];
total_chips += 1;

let bn254_fp2_addsub_events =
*self.syscall_counts.get(&SyscallCode::BN254_FP2_ADD).unwrap_or(&0)
+ *self.syscall_counts.get(&SyscallCode::BN254_FP2_SUB).unwrap_or(&0);
let bn254_fp2_addsub_events = self.syscall_counts[SyscallCode::BN254_FP2_ADD]
+ self.syscall_counts[SyscallCode::BN254_FP2_SUB];
total_area +=
(bn254_fp2_addsub_events as u64) * costs[&RiscvAirDiscriminants::Bn254Fp2AddSub];
total_chips += 1;

let bn254_fp2_mul_events =
*self.syscall_counts.get(&SyscallCode::BN254_FP2_MUL).unwrap_or(&0);
let bn254_fp2_mul_events = self.syscall_counts[SyscallCode::BN254_FP2_MUL];
total_area += (bn254_fp2_mul_events as u64) * costs[&RiscvAirDiscriminants::Bn254Fp2Mul];
total_chips += 1;

let bls12381_decompress_events =
*self.syscall_counts.get(&SyscallCode::BLS12381_DECOMPRESS).unwrap_or(&0);
let bls12381_decompress_events = self.syscall_counts[SyscallCode::BLS12381_DECOMPRESS];
total_area +=
(bls12381_decompress_events as u64) * costs[&RiscvAirDiscriminants::Bls12381Decompress];
total_chips += 1;

let divrem_events = *self.opcode_counts.get(&Opcode::DIV).unwrap_or(&0)
+ *self.opcode_counts.get(&Opcode::REM).unwrap_or(&0)
+ *self.opcode_counts.get(&Opcode::DIVU).unwrap_or(&0)
+ *self.opcode_counts.get(&Opcode::REMU).unwrap_or(&0);
let divrem_events = self.opcode_counts[Opcode::DIV]
+ self.opcode_counts[Opcode::REM]
+ self.opcode_counts[Opcode::DIVU]
+ self.opcode_counts[Opcode::REMU];
total_area += (divrem_events as u64) * costs[&RiscvAirDiscriminants::DivRem];
total_chips += 1;

let addsub_events = *self.opcode_counts.get(&Opcode::ADD).unwrap_or(&0)
+ *self.opcode_counts.get(&Opcode::SUB).unwrap_or(&0);
let addsub_events = self.opcode_counts[Opcode::ADD] + self.opcode_counts[Opcode::SUB];
total_area += (addsub_events as u64) * costs[&RiscvAirDiscriminants::Add];
total_chips += 1;

let bitwise_events = *self.opcode_counts.get(&Opcode::AND).unwrap_or(&0)
+ *self.opcode_counts.get(&Opcode::OR).unwrap_or(&0)
+ *self.opcode_counts.get(&Opcode::XOR).unwrap_or(&0);
let bitwise_events = self.opcode_counts[Opcode::AND]
+ self.opcode_counts[Opcode::OR]
+ self.opcode_counts[Opcode::XOR];
total_area += (bitwise_events as u64) * costs[&RiscvAirDiscriminants::Bitwise];
total_chips += 1;

let mul_events = *self.opcode_counts.get(&Opcode::MUL).unwrap_or(&0)
+ *self.opcode_counts.get(&Opcode::MULH).unwrap_or(&0)
+ *self.opcode_counts.get(&Opcode::MULHU).unwrap_or(&0)
+ *self.opcode_counts.get(&Opcode::MULHSU).unwrap_or(&0);
let mul_events = self.opcode_counts[Opcode::MUL]
+ self.opcode_counts[Opcode::MULH]
+ self.opcode_counts[Opcode::MULHU]
+ self.opcode_counts[Opcode::MULHSU];
total_area += (mul_events as u64) * costs[&RiscvAirDiscriminants::Mul];
total_chips += 1;

let shift_right_events = *self.opcode_counts.get(&Opcode::SRL).unwrap_or(&0)
+ *self.opcode_counts.get(&Opcode::SRA).unwrap_or(&0);
let shift_right_events = self.opcode_counts[Opcode::SRL] + self.opcode_counts[Opcode::SRA];
total_area += (shift_right_events as u64) * costs[&RiscvAirDiscriminants::ShiftRight];
total_chips += 1;

let shift_left_events = *self.opcode_counts.get(&Opcode::SLL).unwrap_or(&0);
let shift_left_events = self.opcode_counts[Opcode::SLL];
total_area += (shift_left_events as u64) * costs[&RiscvAirDiscriminants::ShiftLeft];
total_chips += 1;

let lt_events = *self.opcode_counts.get(&Opcode::SLT).unwrap_or(&0)
+ *self.opcode_counts.get(&Opcode::SLTU).unwrap_or(&0);
let lt_events = self.opcode_counts[Opcode::SLT] + self.opcode_counts[Opcode::SLTU];
total_area += (lt_events as u64) * costs[&RiscvAirDiscriminants::Lt];
total_chips += 1;

Expand Down
4 changes: 2 additions & 2 deletions crates/core/machine/src/utils/prove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -545,11 +545,11 @@ where
// Print the opcode and syscall count tables like `du`: sorted by count (descending) and
// with the count in the first column.
tracing::info!("execution report (opcode counts):");
for line in sorted_table_lines(&report_aggregate.opcode_counts) {
for line in sorted_table_lines(report_aggregate.opcode_counts.as_ref()) {
tracing::info!(" {line}");
}
tracing::info!("execution report (syscall counts):");
for line in sorted_table_lines(&report_aggregate.syscall_counts) {
for line in sorted_table_lines(report_aggregate.syscall_counts.as_ref()) {
tracing::info!(" {line}");
}

Expand Down
Loading
Loading