fix: record nodes for writes in memtrie #10841

Merged · 7 commits · Mar 25, 2024
Changes from 3 commits
4 changes: 2 additions & 2 deletions core/store/src/trie/mem/mod.rs
@@ -143,7 +143,7 @@ impl MemTries {
pub fn update(
&self,
root: CryptoHash,
track_disk_changes: bool,
track_trie_changes: bool,
) -> Result<MemTrieUpdate, StorageError> {
let root_id = if root == CryptoHash::default() {
None
@@ -163,7 +163,7 @@ impl MemTries {
root_id,
&self.arena.memory(),
self.shard_uid.to_string(),
track_disk_changes,
track_trie_changes,
))
}
}
103 changes: 72 additions & 31 deletions core/store/src/trie/mem/updating.rs
@@ -8,7 +8,8 @@ use crate::{NibbleSlice, RawTrieNode, RawTrieNodeWithSize, TrieChanges};
use near_primitives::hash::{hash, CryptoHash};
use near_primitives::state::FlatStateValue;
use near_primitives::types::BlockHeight;
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

/// An old node means a node in the current in-memory trie. An updated node means a
/// node we're going to store in the in-memory trie but have not constructed there yet.
@@ -43,6 +44,28 @@ pub enum UpdatedMemTrieNode {
},
}

/// Keeps values and internal nodes accessed on updating memtrie.
pub(crate) struct TrieAccesses {
/// Hashes and encoded trie nodes.
pub nodes: HashMap<CryptoHash, Arc<[u8]>>,
/// Hashes of accessed values - because values themselves are not
/// necessarily present in memtrie.
pub values: HashSet<CryptoHash>,
Contributor: Should this have an optional value, if the value is present in memtrie?

}
Contributor: I don't fully understand what additional nodes this change records.

AFAIU the runtime updates the trie by calling storage_write, which always does a trie read before writing a value:

    let evicted_ptr = self.ext.storage_get(&key, StorageGetMode::Trie)?;

This trie read records all nodes that were accessed to reach this value, even in the case of memtries:

    if let Some(recorder) = &self.recorder {

Doesn't that record everything that is needed to prove execution of the contract? Why do we need an additional access log?

Contributor (@jancionear, Mar 21, 2024): Ah, I guess it doesn't record nodes that are created when adding new values...

Member Author: And this path is called only during contract execution. The gateway for updating the trie outside of it is set<T: BorshSerialize>(state_update: &mut TrieUpdate, key: TrieKey, value: &T).
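To illustrate the distinction discussed above, here is a minimal, self-contained sketch (the ToyStore type and its methods are toy stand-ins, not the nearcore API) of why a recorder hooked only into the read path captures read-then-write keys but misses keys that go through a write-only path:

```rust
use std::collections::{HashMap, HashSet};

struct ToyStore {
    data: HashMap<String, Vec<u8>>,
    recorded_reads: HashSet<String>,
}

impl ToyStore {
    /// Read path: everything looked up here ends up in the recording,
    /// analogous to the recorder being fed on every get.
    fn get(&mut self, key: &str) -> Option<Vec<u8>> {
        self.recorded_reads.insert(key.to_string());
        self.data.get(key).cloned()
    }

    /// Write-only path (think of a set-style update outside contract
    /// execution): nothing is recorded here.
    fn set(&mut self, key: &str, value: Vec<u8>) {
        self.data.insert(key.to_string(), value);
    }
}

fn main() {
    let mut store = ToyStore { data: HashMap::new(), recorded_reads: HashSet::new() };

    // storage_write-style update: a read happens first, so the key is recorded.
    let _prev = store.get("read_then_written");
    store.set("read_then_written", vec![1]);

    // Write-only update: never read, so nothing about it is recorded.
    store.set("written_only", vec![2]);

    assert!(store.recorded_reads.contains("read_then_written"));
    assert!(!store.recorded_reads.contains("written_only"));
}
```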


/// Tracks intermediate trie changes, final version of which is to be committed
/// to disk after finishing trie update.
struct TrieChangesTracker {
/// Changes of reference count on disk for each impacted node.
refcount_changes: TrieRefcountDeltaMap,
/// All observed values and internal nodes.
/// Needed to prepare recorded storage.
/// Note that negative `refcount_changes` does not fully cover it, as node
/// or value of the same hash can be removed and inserted for the same
/// update in different parts of trie!
accesses: TrieAccesses,
}

/// Structure to build an update to the in-memory trie.
pub struct MemTrieUpdate<'a> {
/// The original root before updates. It is None iff the original trie had no keys.
@@ -53,8 +76,9 @@ pub struct MemTrieUpdate<'a> {
/// (1) temporarily we take out the node from the slot to process it and put it back
/// later; or (2) the node is deleted afterwards.
pub updated_nodes: Vec<Option<UpdatedMemTrieNode>>,
/// Refcount changes to on-disk trie nodes.
pub trie_refcount_changes: Option<TrieRefcountDeltaMap>,
/// Tracks trie changes necessary to make on-disk updates and recorded
/// storage.
tracked_trie_changes: Option<TrieChangesTracker>,
}

impl UpdatedMemTrieNode {
@@ -97,15 +121,18 @@ impl<'a> MemTrieUpdate<'a> {
root: Option<MemTrieNodeId>,
arena: &'a ArenaMemory,
shard_uid: String,
track_disk_changes: bool,
track_trie_changes: bool,
) -> Self {
let mut trie_update = Self {
root,
arena,
shard_uid,
updated_nodes: vec![],
trie_refcount_changes: if track_disk_changes {
Some(TrieRefcountDeltaMap::new())
tracked_trie_changes: if track_trie_changes {
Some(TrieChangesTracker {
refcount_changes: TrieRefcountDeltaMap::new(),
accesses: TrieAccesses { nodes: HashMap::new(), values: HashSet::new() },
})
} else {
None
},
@@ -145,8 +172,16 @@ impl<'a> MemTrieUpdate<'a> {
match node {
None => self.new_updated_node(UpdatedMemTrieNode::Empty),
Some(node) => {
if let Some(trie_refcount_changes) = self.trie_refcount_changes.as_mut() {
trie_refcount_changes.subtract(node.as_ptr(self.arena).view().node_hash(), 1);
if let Some(tracked_trie_changes) = self.tracked_trie_changes.as_mut() {
let node_view = node.as_ptr(self.arena).view();
let node_hash = node_view.node_hash();
let raw_node_serialized =
borsh::to_vec(&node_view.to_raw_trie_node_with_size()).unwrap();
tracked_trie_changes
.accesses
.nodes
.insert(node_hash, raw_node_serialized.into());
tracked_trie_changes.refcount_changes.subtract(node_hash, 1);
}
self.new_updated_node(UpdatedMemTrieNode::from_existing_node_view(
node.as_ptr(self.arena).view(),
@@ -164,14 +199,15 @@
}

fn add_refcount_to_value(&mut self, hash: CryptoHash, value: Option<Vec<u8>>) {
if let Some(trie_refcount_changes) = self.trie_refcount_changes.as_mut() {
trie_refcount_changes.add(hash, value.unwrap(), 1);
if let Some(tracked_node_changes) = self.tracked_trie_changes.as_mut() {
tracked_node_changes.refcount_changes.add(hash, value.unwrap(), 1);
}
}

fn subtract_refcount_for_value(&mut self, hash: CryptoHash) {
if let Some(trie_refcount_changes) = self.trie_refcount_changes.as_mut() {
trie_refcount_changes.subtract(hash, 1);
if let Some(tracked_node_changes) = self.tracked_trie_changes.as_mut() {
tracked_node_changes.accesses.values.insert(hash);
Contributor: Similar to the comment below: does subtracting the refcount for the value guarantee that we must need the previous value? If so, why?

tracked_node_changes.refcount_changes.subtract(hash, 1);
}
}

@@ -779,31 +815,36 @@ impl<'a> MemTrieUpdate<'a> {
}

/// Converts the updates to trie changes as well as memtrie changes.
pub fn to_trie_changes(self) -> TrieChanges {
let Self { root, arena, shard_uid, trie_refcount_changes, updated_nodes } = self;
let mut trie_refcount_changes =
trie_refcount_changes.expect("Cannot to_trie_changes for memtrie changes only");
pub(crate) fn to_trie_changes(self) -> (TrieChanges, TrieAccesses) {
let Self { root, arena, shard_uid, tracked_trie_changes, updated_nodes } = self;
let TrieChangesTracker { mut refcount_changes, accesses } =
tracked_trie_changes.expect("Cannot to_trie_changes for memtrie changes only");
let (mem_trie_changes, hashes_and_serialized) =
Self::to_mem_trie_changes_internal(shard_uid, arena, updated_nodes);

// We've accounted for the dereferenced nodes, as well as value addition/subtractions.
// The only thing left is to increment refcount for all new nodes.
for (node_hash, node_serialized) in hashes_and_serialized {
trie_refcount_changes.add(node_hash, node_serialized, 1);
}
let (insertions, deletions) = trie_refcount_changes.into_changes();

TrieChanges {
old_root: root.map(|root| root.as_ptr(arena).view().node_hash()).unwrap_or_default(),
new_root: mem_trie_changes
.node_ids_with_hashes
.last()
.map(|(_, hash)| *hash)
.unwrap_or_default(),
insertions,
deletions,
mem_trie_changes: Some(mem_trie_changes),
refcount_changes.add(node_hash, node_serialized, 1);
}
let (insertions, deletions) = refcount_changes.into_changes();

(
TrieChanges {
old_root: root
.map(|root| root.as_ptr(arena).view().node_hash())
.unwrap_or_default(),
new_root: mem_trie_changes
.node_ids_with_hashes
.last()
.map(|(_, hash)| *hash)
.unwrap_or_default(),
insertions,
deletions,
mem_trie_changes: Some(mem_trie_changes),
},
accesses,
)
}
}

@@ -917,7 +958,7 @@ mod tests {
update.delete(&key);
}
}
update.to_trie_changes()
update.to_trie_changes().0
}

fn make_memtrie_changes_only(
31 changes: 30 additions & 1 deletion core/store/src/trie/mod.rs
@@ -1498,7 +1498,36 @@ impl Trie {
None => trie_update.delete(&key),
}
}
Ok(trie_update.to_trie_changes())
let (trie_changes, trie_accesses) = trie_update.to_trie_changes();

// Sanity check for tests: all modified trie items must be
// present in ever accessed trie items.
#[cfg(test)]
{
for t in trie_changes.deletions.iter() {
let hash = t.trie_node_or_value_hash;
assert!(
trie_accesses.values.contains(&hash)
|| trie_accesses.nodes.contains_key(&hash),
"Hash {} is not present in trie accesses",
hash
);
}
}

// Retroactively record all accessed trie items to account for
// key-value pairs which were only written but never read, thus
Contributor: There's a gap in my logical understanding: where is the link that goes from "a key is written to but not read from" to "we must have the previous value for that key"?

For the specific trie-restructuring case I knew of (converting a branch with one child into an extension), I know the child node must be read in order to do the conversion, and at the same time the child node is deleted. So in that case, yes, the child is written to (its refcount is decremented) but not read from (it is never queried), and its previous value is needed (to carry out the restructuring). But is that, in general, the only case where the update phase would cause a value to be written to but not read from?

In other words, is the logic: "if a node is written to but not read from, then that node must have been replaced during a trie restructuring, and therefore its previous value must be known" (which I would need to be convinced of), or is there some other way to draw this conclusion?

Member Author (@Longarithm, Mar 20, 2024):

> if a node is written to but not read from, then that node must have been replaced during a trie restructuring, and therefore its previous value must be known

I think "key" got replaced with "node" somewhere in the message.

The logic is: "if a key is written to but not read from" => "all nodes on the trie path must be known". Let's say you don't know some node on the path to it, and take the lowest such node. Then you can't recompute it after the write, because you don't know the hashes of the key's neighbours, so you can't compute the new state root.

And, yes, when we descend into the node corresponding to a key, all refcounts on the way are decremented.

What I actually want to propose with this PR is a simpler rule: "for all keys for which any state operation was called (get/has/write/remove), all nodes on the old path must be recorded". That's what the current disk-trie logic does, and it's easy to explain and follow.
For memtries, the union of (nodes recorded on chunk application) and (nodes recorded on trie restructuring) should give this set. So I didn't think too hard about whether this is strictly necessary; it's enough for me that it is verifiable.
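To make the argument about the lowest unknown node concrete, here is a small self-contained sketch. It uses a simplified binary Merkle tree with std's DefaultHasher standing in for a cryptographic hash (this is not the actual NEAR trie layout): recomputing the root after changing one leaf requires the hash of every sibling along the old path from that leaf to the root, which is why every node on the path must be known.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

fn h1(data: &[u8]) -> u64 {
    let mut s = DefaultHasher::new();
    data.hash(&mut s);
    s.finish()
}

fn h2(left: u64, right: u64) -> u64 {
    let mut s = DefaultHasher::new();
    left.hash(&mut s);
    right.hash(&mut s);
    s.finish()
}

/// Recompute the root given a new leaf value and the sibling hashes along the
/// path, ordered from the leaf up; `true` means the sibling is on the left.
fn root_after_write(new_leaf: &[u8], path_siblings: &[(u64, bool)]) -> u64 {
    let mut acc = h1(new_leaf);
    for (sibling, sibling_is_left) in path_siblings {
        acc = if *sibling_is_left { h2(*sibling, acc) } else { h2(acc, *sibling) };
    }
    acc
}

fn main() {
    // Four leaves a, b, c, d; we overwrite c.
    let (a, b, c, d) = (h1(b"a"), h1(b"b"), h1(b"c"), h1(b"d"));
    let n_ab = h2(a, b);
    let n_cd = h2(c, d);
    let old_root = h2(n_ab, n_cd);

    // To go from old_root to the new root we must know d and n_ab (the nodes
    // next to c's path), even though only c's value changes.
    let new_root = root_after_write(b"c2", &[(d, false), (n_ab, true)]);
    assert_ne!(new_root, old_root);
    assert_eq!(new_root, h2(n_ab, h2(h1(b"c2"), d)));
}
```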

Contributor: I see, ok, thanks for clarifying your intent in mathematical form :)

I think my question is a bit more detailed than that.

Suppose we replace a node in the trie. Do we actually need the node that we're replacing? The parent of that node already holds the hashes of the sibling nodes that we need to reconstruct the parent. For example, we have A -> B -> C -> D and we're deleting the old D: to reconstruct A' -> B' -> C' -> D', we only need A, B, C, and the new D'. The old D is not needed. But the logic in this PR seems to require that D be known, and that's the part I don't understand.

Member Author: That makes sense. Then my argument comes down to:

  • That's what the current disk logic does. Every time we descend to a node, we pull it out via move_node_to_mutable or delete_value. By default I would avoid adding "if this is the last node, don't do it" logic in a couple of places, plus a corner case for when this node is a restructured branch.
  • The rule "we take all nodes on the path" is simpler to explain than "we take all nodes on reads, plus all nodes except the last one on writes".
  • The savings from not including the previous value shouldn't be high, because modifying a value without reading it is a rare case: action_deploy_contract reads the contract before redeploying it, etc. We could still store only the ValueRefs of previous values, but I'm not 100% confident I'd get that right the first time :D

Contributor: If we are recording node D in A -> B -> C -> D for on-disk tries, let's keep things consistent and do the same for memtries, even if it's not the most efficient.

// not recorded before.
if let Some(recorder) = &self.recorder {
for (node_hash, serialized_node) in trie_accesses.nodes {
recorder.borrow_mut().record(&node_hash, serialized_node);
}
for value_hash in trie_accesses.values {
let value = self.storage.retrieve_raw_bytes(&value_hash)?;
recorder.borrow_mut().record(&value_hash, value);
}
}
Ok(trie_changes)
Contributor: It's sad that the changes are recorded retroactively.
For the state witness size limit it would be best to record them as they happen; that would make it possible to stop executing a receipt once it generates more than X MB of PartialState. With retroactive recording we can only measure how much PartialState was generated after the update is applied, which AFAIU happens after applying the whole chunk :/

Actually, doesn't this break #10703? The soft size limit looks at the size of the recorder state after applying each receipt and stops executing new ones when that size gets above the limit. It depends on online recording; it wouldn't work with retroactive recording.

Contributor: Would it be possible to somehow put the recording logic in TrieUpdate::set?

Contributor: And I guess it works the same way with normal trie writes. That sucks, it really throws a wrench into the witness size limit. This solution makes sense for now, but I think the whole logic of trie updating will need a refactor for the size limit :/

Contributor: Okay, actually we don't need to record writes for the state witness, only the reads need to be recorded :)

Contributor: As discussed in the meeting today, it turns out we do want to record the nodes for updates/deletions :(
I believe this update function is only called in trie_update finalize for all practical purposes?

We would need to revisit this and find a better solution for recording on the fly, like in the get function, instead of in trie_update finalize. This is required for both the soft and hard state witness size limits.
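For reference, a minimal sketch of what such on-the-fly recording could look like (the SizeLimitedRecorder type and its methods are hypothetical, not the nearcore recorder API): the recorder is fed during the update itself, so a soft size limit can be checked between receipts instead of only after the whole chunk is applied.

```rust
use std::collections::HashMap;

/// Hypothetical recorder with a soft size limit.
struct SizeLimitedRecorder {
    recorded: HashMap<u64, Vec<u8>>, // node hash -> serialized node
    recorded_bytes: usize,
    soft_limit_bytes: usize,
}

impl SizeLimitedRecorder {
    /// Called on every node access while the update is being applied,
    /// not retroactively after finalize.
    fn record(&mut self, hash: u64, serialized: Vec<u8>) {
        if !self.recorded.contains_key(&hash) {
            self.recorded_bytes += serialized.len();
            self.recorded.insert(hash, serialized);
        }
    }

    /// Checked between receipts: once the recorded PartialState exceeds the
    /// soft limit, stop scheduling new receipts for this chunk.
    fn over_soft_limit(&self) -> bool {
        self.recorded_bytes > self.soft_limit_bytes
    }
}

fn main() {
    let mut recorder = SizeLimitedRecorder {
        recorded: HashMap::new(),
        recorded_bytes: 0,
        soft_limit_bytes: 64,
    };
    for receipt in 0u64..10 {
        // Pretend each receipt touches one node that serializes to 16 bytes.
        recorder.record(receipt, vec![0u8; 16]);
        if recorder.over_soft_limit() {
            println!("soft limit hit after receipt {receipt}; stop executing new receipts");
            break;
        }
    }
}
```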

Contributor: Currently (other than trie restructuring), I guess we just rely on the fact that the runtime first calls a get before an update/delete for gas cost estimation.

}
None => {
let mut memory = NodesStorage::new();
32 changes: 28 additions & 4 deletions core/store/src/trie/trie_recording.rs
@@ -44,12 +44,13 @@ mod trie_recording_tests {
use crate::trie::mem::metrics::MEM_TRIE_NUM_LOOKUPS;
use crate::trie::TrieNodesCount;
use crate::{DBCol, Store, Trie};
use borsh::BorshDeserialize;
use near_primitives::hash::{hash, CryptoHash};
use near_primitives::shard_layout::{get_block_shard_uid, get_block_shard_uid_rev, ShardUId};
use near_primitives::shard_layout::{get_block_shard_uid, ShardUId};
use near_primitives::state::ValueRef;
use near_primitives::types::chunk_extra::ChunkExtra;
use near_primitives::types::StateRoot;
use rand::{thread_rng, Rng};
use rand::{random, thread_rng, Rng};
use std::collections::{HashMap, HashSet};
use std::num::NonZeroU32;

@@ -66,6 +67,8 @@
/// The keys that we should be using to call get_optimized_ref() on the
/// trie with.
keys_to_get_ref: Vec<Vec<u8>>,
/// The keys to be updated after trie reads.
updates: Vec<(Vec<u8>, Option<Vec<u8>>)>,
state_root: StateRoot,
}

@@ -121,13 +124,26 @@
}
key
})
.partition::<Vec<_>, _>(|_| thread_rng().gen());
.partition::<Vec<_>, _>(|_| random());
let updates = trie_changes
.iter()
.map(|(key, _)| {
let value = if thread_rng().gen_bool(0.5) {
Some(vec![thread_rng().gen_range(0..10) as u8])
} else {
None
};
(key.clone(), value)
})
.filter(|_| random())
.collect::<Vec<_>>();
PreparedTrie {
store: tries_for_building.get_store(),
shard_uid,
data_in_trie,
keys_to_get,
keys_to_get_ref,
updates,
state_root,
}
}
@@ -146,7 +162,7 @@
for result in store.iter_raw_bytes(DBCol::State) {
let (key, value) = result.unwrap();
let (_, refcount) = decode_value_with_rc(&value);
let (key_hash, _) = get_block_shard_uid_rev(&key).unwrap();
let key_hash: CryptoHash = CryptoHash::try_from_slice(&key[8..]).unwrap();
if !key_hashes_to_keep.contains(&key_hash) {
update.decrement_refcount_by(
DBCol::State,
@@ -174,6 +190,7 @@
data_in_trie,
keys_to_get,
keys_to_get_ref,
updates,
state_root,
} = prepare_trie(use_missing_keys);
let tries = if use_in_memory_tries {
@@ -206,6 +223,7 @@
}
let baseline_trie_nodes_count = trie.get_trie_nodes_count();
println!("Baseline trie nodes count: {:?}", baseline_trie_nodes_count);
trie.update(updates.iter().cloned()).unwrap();

// Now let's do this again while recording, and make sure that the counters
// we get are exactly the same.
@@ -223,6 +241,7 @@
);
}
assert_eq!(trie.get_trie_nodes_count(), baseline_trie_nodes_count);
trie.update(updates.iter().cloned()).unwrap();

// Now, let's check that when doing the same lookups with the captured partial storage,
// we still get the same counters.
@@ -246,6 +265,7 @@
);
}
assert_eq!(trie.get_trie_nodes_count(), baseline_trie_nodes_count);
trie.update(updates.iter().cloned()).unwrap();

if use_in_memory_tries {
// sanity check that we did indeed use in-memory tries.
@@ -310,6 +330,7 @@
data_in_trie,
keys_to_get,
keys_to_get_ref,
updates,
state_root,
} = prepare_trie(use_missing_keys);
let tries = if use_in_memory_tries {
@@ -364,6 +385,7 @@
}
let baseline_trie_nodes_count = trie.get_trie_nodes_count();
println!("Baseline trie nodes count: {:?}", baseline_trie_nodes_count);
trie.update(updates.iter().cloned()).unwrap();

// Let's do this again, but this time recording reads. We'll make sure
// the counters are exactly the same even when we're recording.
@@ -388,6 +410,7 @@
);
}
assert_eq!(trie.get_trie_nodes_count(), baseline_trie_nodes_count);
trie.update(updates.iter().cloned()).unwrap();

// Now, let's check that when doing the same lookups with the captured partial storage,
// we still get the same counters.
@@ -411,6 +434,7 @@
);
}
assert_eq!(trie.get_trie_nodes_count(), baseline_trie_nodes_count);
trie.update(updates.iter().cloned()).unwrap();

if use_in_memory_tries {
// sanity check that we did indeed use in-memory tries.
8 changes: 5 additions & 3 deletions core/store/src/trie/trie_tests.rs
@@ -419,10 +419,12 @@ mod trie_storage_tests {
assert_eq!(count_delta.mem_reads, 1);
}

// TODO(#10769): Make this test pass.
// Checks that for keys only touched on writes recorded storage for
// memtrie matches recorded storage for disk.
// Required because recording on read and write happen on different code
// paths for memtrie.
#[test]
#[should_panic]
fn test_memtrie_discrepancy() {
fn test_memtrie_recorded_writes() {
init_test_logger();
let tries = TestTriesBuilder::new().build();
let shard_uid = ShardUId::single_shard();