Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: enable the mpt cache #62

Merged
merged 2 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 14 additions & 26 deletions primitives/src/mpt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

use core::{
cell::RefCell,
cmp,
fmt::{Debug, Write},
iter, mem,
Expand Down Expand Up @@ -131,10 +132,10 @@ pub fn keccak(data: impl AsRef<[u8]>) -> [u8; 32] {
pub struct MptNode {
/// The type and data of the node.
data: MptNodeData,
// / Cache for a previously computed reference of this node. This is skipped during
// / serialization.
// #[serde(skip)]
// cached_reference: RefCell<Option<MptNodeReference>>,
/// Cache for a previously computed reference of this node. This is skipped during
/// serialization.
#[serde(skip)]
cached_reference: RefCell<Option<MptNodeReference>>,
}

/// Represents custom error types for the sparse Merkle Patricia Trie (MPT).
Expand Down Expand Up @@ -209,7 +210,7 @@ impl From<MptNodeData> for MptNode {
fn from(value: MptNodeData) -> Self {
Self {
data: value,
// cached_reference: RefCell::new(None),
cached_reference: RefCell::new(None),
}
}
}
Expand Down Expand Up @@ -371,11 +372,10 @@ impl MptNode {
/// storage or transmission purposes.
#[inline]
pub fn reference(&self) -> MptNodeReference {
// self.cached_reference
// .borrow_mut()
// .get_or_insert_with(|| self.calc_reference())
// .clone()
self.calc_reference()
self.cached_reference
.borrow_mut()
.get_or_insert_with(|| self.calc_reference())
.clone()
}

/// Computes and returns the 256-bit hash of the node.
Expand All @@ -385,11 +385,7 @@ impl MptNode {
pub fn hash(&self) -> B256 {
match self.data {
MptNodeData::Null => EMPTY_ROOT,
// _ => match self
// .cached_reference
// .borrow_mut()
// .get_or_insert_with(|| self.calc_reference())
_ => match self.calc_reference() {
_ => match self.reference() {
MptNodeReference::Digest(digest) => digest,
MptNodeReference::Bytes(bytes) => keccak(bytes).into(),
},
Expand All @@ -398,11 +394,7 @@ impl MptNode {

/// Encodes the [MptNodeReference] of this node into the `out` buffer.
fn reference_encode(&self, out: &mut dyn alloy_rlp::BufMut) {
// match self
// .cached_reference
// .borrow_mut()
// .get_or_insert_with(|| self.calc_reference())
match self.calc_reference() {
match self.reference() {
// if the reference is an RLP-encoded byte slice, copy it directly
MptNodeReference::Bytes(bytes) => out.put_slice(&bytes),
// if the reference is a digest, RLP-encode it with its fixed known length
Expand All @@ -415,11 +407,7 @@ impl MptNode {

/// Returns the length of the encoded [MptNodeReference] of this node.
fn reference_length(&self) -> usize {
// match self
// .cached_reference
// .borrow_mut()
// .get_or_insert_with(|| self.calc_reference())
match self.calc_reference() {
match self.reference() {
MptNodeReference::Bytes(bytes) => bytes.len(),
MptNodeReference::Digest(_) => 1 + 32,
}
Expand Down Expand Up @@ -774,7 +762,7 @@ impl MptNode {
}

fn invalidate_ref_cache(&mut self) {
// self.cached_reference.borrow_mut().take();
self.cached_reference.borrow_mut().take();
}

/// Returns the number of traversable nodes in the trie.
Expand Down
6 changes: 3 additions & 3 deletions raiko-host/src/prover/proof/risc0/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@ pub async fn execute_risc0(
req: &Risc0ProofParams,
) -> Result<Risc0Response, String> {
println!("elf code length: {}", RISC0_METHODS_ELF.len());
let encoded_input = to_vec(&input).expect("Could not serialize proving input!");

let result = maybe_prove::<GuestInput<EthereumTxEssence>, GuestOutput>(
req,
&input,
encoded_input,
RISC0_METHODS_ELF,
&output,
Default::default(),
Expand Down Expand Up @@ -227,13 +228,12 @@ pub async fn verify_bonsai_receipt<O: Eq + Debug + DeserializeOwned>(

pub async fn maybe_prove<I: Serialize, O: Eq + Debug + Serialize + DeserializeOwned>(
req: &Risc0ProofParams,
input: &I,
encoded_input: Vec<u32>,
elf: &[u8],
expected_output: &O,
assumptions: (Vec<Assumption>, Vec<String>),
) -> Option<(String, Receipt)> {
let (assumption_instances, assumption_uuids) = assumptions;
let encoded_input = to_vec(input).expect("Could not serialize proving input!");

let encoded_output =
to_vec(expected_output).expect("Could not serialize expected proving output!");
Expand Down
Loading