diff --git a/chain/chain/src/chain.rs b/chain/chain/src/chain.rs index adc9dce2a89..3999fbef3e0 100644 --- a/chain/chain/src/chain.rs +++ b/chain/chain/src/chain.rs @@ -1737,15 +1737,18 @@ impl Chain { let shard_state_header = self.get_state_header(shard_id, sync_hash)?; let mut height = shard_state_header.chunk_height_included(); let state_root = shard_state_header.chunk_prev_state_root(); - let mut parts = vec![]; for part_id in 0..num_parts { let key = StatePartKey(sync_hash, shard_id, part_id).try_to_vec()?; - parts.push(self.store.owned_store().get(ColStateParts, &key)?.unwrap()); + let part = self.store.owned_store().get(ColStateParts, &key)?.unwrap(); + self.runtime_adapter.apply_state_part( + shard_id, + &state_root, + part_id, + num_parts, + &part, + )?; } - // Confirm that state matches the parts we received - self.runtime_adapter.confirm_state(shard_id, &state_root, &parts)?; - // Applying the chunk starts here let mut chain_update = self.chain_update(); chain_update.set_state_finalize(shard_id, sync_hash, shard_state_header)?; diff --git a/chain/chain/src/test_utils.rs b/chain/chain/src/test_utils.rs index d8743341485..3fc65dec75c 100644 --- a/chain/chain/src/test_utils.rs +++ b/chain/chain/src/test_utils.rs @@ -816,15 +816,12 @@ impl RuntimeAdapter for KeyValueRuntime { num_parts: u64, ) -> Result, Error> { assert!(part_id < num_parts); + if part_id != 0 { + return Ok(vec![]); + } let state = self.state.read().unwrap().get(&state_root).unwrap().clone(); let data = state.try_to_vec().expect("should never fall"); - let state_size = data.len() as u64; - let begin = state_size / num_parts * part_id; - let mut end = state_size / num_parts * (part_id + 1); - if part_id + 1 == num_parts { - end = state_size; - } - Ok(data[begin as usize..end as usize].to_vec()) + Ok(data) } fn validate_state_part( @@ -839,18 +836,18 @@ impl RuntimeAdapter for KeyValueRuntime { true } - fn confirm_state( + fn apply_state_part( &self, _shard_id: ShardId, state_root: &StateRoot, - parts: &Vec>, + part_id: u64, + _num_parts: u64, + data: &[u8], ) -> Result<(), Error> { - let mut data = vec![]; - for part in parts { - data.push(part.clone()); + if part_id != 0 { + return Ok(()); } - let data_flatten: Vec = data.iter().flatten().cloned().collect(); - let state = KVState::try_from_slice(&data_flatten).unwrap(); + let state = KVState::try_from_slice(data).unwrap(); self.state.write().unwrap().insert(state_root.clone(), state.clone()); let data = state.try_to_vec()?; let state_size = data.len() as u64; diff --git a/chain/chain/src/types.rs b/chain/chain/src/types.rs index f256969bff1..3c6b7612845 100644 --- a/chain/chain/src/types.rs +++ b/chain/chain/src/types.rs @@ -566,11 +566,13 @@ pub trait RuntimeAdapter: Send + Sync { ) -> bool; /// Should be executed after accepting all the parts to set up a new state. - fn confirm_state( + fn apply_state_part( &self, shard_id: ShardId, state_root: &StateRoot, - parts: &Vec>, + part_id: u64, + num_parts: u64, + part: &[u8], ) -> Result<(), Error>; /// Returns StateRootNode of a state. diff --git a/chain/client/tests/process_blocks.rs b/chain/client/tests/process_blocks.rs index a04e7aa811f..52de4c37b0a 100644 --- a/chain/client/tests/process_blocks.rs +++ b/chain/client/tests/process_blocks.rs @@ -1401,7 +1401,7 @@ fn test_process_block_after_state_sync() { env.clients[0].chain.reset_data_pre_state_sync(sync_hash).unwrap(); env.clients[0] .runtime_adapter - .confirm_state(0, &chunk_extra.state_root, &vec![state_part]) + .apply_state_part(0, &chunk_extra.state_root, 0, 1, &state_part) .unwrap(); let block = env.clients[0].produce_block(sync_height + 1).unwrap().unwrap(); let (_, res) = env.clients[0].process_block(block, Provenance::PRODUCED); diff --git a/core/store/src/trie/iterator.rs b/core/store/src/trie/iterator.rs index 20bf19224e8..e8a1e645146 100644 --- a/core/store/src/trie/iterator.rs +++ b/core/store/src/trie/iterator.rs @@ -1,7 +1,7 @@ use near_primitives::hash::CryptoHash; use crate::trie::nibble_slice::NibbleSlice; -use crate::trie::{NodeHandle, TrieNode, TrieNodeWithSize, ValueHandle}; +use crate::trie::{TrieNode, TrieNodeWithSize, ValueHandle}; use crate::{StorageError, Trie}; #[derive(Debug)] @@ -39,7 +39,7 @@ pub struct TrieIterator<'a> { root: CryptoHash, } -pub type TrieItem<'a> = Result<(Vec, Vec), StorageError>; +pub type TrieItem = Result<(Vec, Vec), StorageError>; impl<'a> TrieIterator<'a> { #![allow(clippy::new_ret_no_self)] @@ -52,95 +52,74 @@ impl<'a> TrieIterator<'a> { root: *root, }; let node = trie.retrieve_node(root)?; - r.descend_into_node(&node); + r.descend_into_node(node); Ok(r) } /// Position the iterator on the first element with key => `key`. pub fn seek>(&mut self, key: K) -> Result<(), StorageError> { - self.seek_nibble_slice(NibbleSlice::new(key.as_ref())) + self.seek_nibble_slice(NibbleSlice::new(key.as_ref())).map(drop) } + /// Returns the hash of the last node pub(crate) fn seek_nibble_slice( &mut self, mut key: NibbleSlice<'_>, - ) -> Result<(), StorageError> { + ) -> Result { self.trail.clear(); self.key_nibbles.clear(); - let mut hash = NodeHandle::Hash(self.root); + let mut hash = self.root; loop { - let node = match hash { - NodeHandle::Hash(hash) => self.trie.retrieve_node(&hash)?, - NodeHandle::InMemory(_node) => unreachable!(), - }; - let copy_node = node.clone(); - match node.node { - TrieNode::Empty => return Ok(()), + let node = self.trie.retrieve_node(&hash)?; + self.trail.push(Crumb { status: CrumbStatus::Entering, node }); + let Crumb { status, node } = self.trail.last_mut().unwrap(); + match &node.node { + TrieNode::Empty => break, TrieNode::Leaf(leaf_key, _) => { let existing_key = NibbleSlice::from_encoded(&leaf_key).0; - self.trail.push(Crumb { - status: if existing_key >= key { - CrumbStatus::Entering - } else { - CrumbStatus::Exiting - }, - node: copy_node, - }); - self.key_nibbles.extend(existing_key.iter()); - return Ok(()); + if existing_key < key { + self.key_nibbles.extend(existing_key.iter()); + *status = CrumbStatus::Exiting; + } + break; } - TrieNode::Branch(mut children, _) => { + TrieNode::Branch(children, _) => { if key.is_empty() { - self.trail.push(Crumb { status: CrumbStatus::Entering, node: copy_node }); - return Ok(()); + break; } else { let idx = key.at(0) as usize; - self.trail.push(Crumb { - status: CrumbStatus::AtChild(idx as usize), - node: copy_node, - }); self.key_nibbles.push(key.at(0)); - if let Some(child) = children[idx].take() { - hash = child; + *status = CrumbStatus::AtChild(idx as usize); + if let Some(child) = &children[idx] { + hash = *child.unwrap_hash(); key = key.mid(1); } else { - return Ok(()); + break; } } } TrieNode::Extension(ext_key, child) => { let existing_key = NibbleSlice::from_encoded(&ext_key).0; if key.starts_with(&existing_key) { - self.trail.push(Crumb { status: CrumbStatus::At, node: copy_node }); - self.key_nibbles.extend(existing_key.iter()); - hash = child; key = key.mid(existing_key.len()); - } else { - self.trail.push(Crumb { - status: if existing_key >= key { - CrumbStatus::Entering - } else { - CrumbStatus::Exiting - }, - node: copy_node, - }); + hash = *child.unwrap_hash(); + *status = CrumbStatus::At; self.key_nibbles.extend(existing_key.iter()); - return Ok(()); + } else { + if existing_key < key { + *status = CrumbStatus::Exiting; + self.key_nibbles.extend(existing_key.iter()); + } + break; } } } } + Ok(hash) } - fn descend_into_node(&mut self, node: &TrieNodeWithSize) { - self.trail.push(Crumb { status: CrumbStatus::Entering, node: node.clone() }); - match &self.trail.last().expect("Just pushed item").node.node { - TrieNode::Leaf(ref key, _) | TrieNode::Extension(ref key, _) => { - let key = NibbleSlice::from_encoded(key).0; - self.key_nibbles.extend(key.iter()); - } - _ => {} - } + fn descend_into_node(&mut self, node: TrieNodeWithSize) { + self.trail.push(Crumb { status: CrumbStatus::Entering, node }); } fn key(&self) -> Vec { @@ -150,94 +129,215 @@ impl<'a> TrieIterator<'a> { } result } -} -impl<'a> Iterator for TrieIterator<'a> { - type Item = TrieItem<'a>; + fn iter_step(&mut self) -> Option { + self.trail.last_mut()?.increment(); + let b = self.trail.last().expect("Trail finished."); + match (b.status.clone(), &b.node.node) { + (CrumbStatus::Exiting, n) => { + match n { + TrieNode::Leaf(ref key, _) | TrieNode::Extension(ref key, _) => { + let existing_key = NibbleSlice::from_encoded(&key).0; + let l = self.key_nibbles.len(); + self.key_nibbles.truncate(l - existing_key.len()); + } + TrieNode::Branch(_, _) => { + self.key_nibbles.pop(); + } + _ => {} + } + Some(IterStep::PopTrail) + } + (CrumbStatus::At, TrieNode::Branch(_, Some(value))) => { + let hash = match value { + ValueHandle::HashAndSize(_, hash) => *hash, + ValueHandle::InMemory(_node) => unreachable!(), + }; + Some(IterStep::Value(hash)) + } + (CrumbStatus::At, TrieNode::Branch(_, None)) => Some(IterStep::Continue), + (CrumbStatus::At, TrieNode::Leaf(key, value)) => { + let hash = match value { + ValueHandle::HashAndSize(_, hash) => *hash, + ValueHandle::InMemory(_node) => unreachable!(), + }; + let key = NibbleSlice::from_encoded(&key).0; + self.key_nibbles.extend(key.iter()); + Some(IterStep::Value(hash)) + } + (CrumbStatus::At, TrieNode::Extension(key, child)) => { + let hash = *child.unwrap_hash(); + let key = NibbleSlice::from_encoded(&key).0; + self.key_nibbles.extend(key.iter()); + Some(IterStep::Descend(hash)) + } + (CrumbStatus::AtChild(i), TrieNode::Branch(children, _)) if children[i].is_some() => { + match i { + 0 => self.key_nibbles.push(0), + i => *self.key_nibbles.last_mut().expect("Pushed child value before") = i as u8, + } + let hash = *children[i].as_ref().unwrap().unwrap_hash(); + Some(IterStep::Descend(hash)) + } + (CrumbStatus::AtChild(i), TrieNode::Branch(_, _)) => { + if i == 0 { + self.key_nibbles.push(0); + } + Some(IterStep::Continue) + } + _ => panic!("Should never see Entering or AtChild without a Branch here."), + } + } - fn next(&mut self) -> Option { - enum IterStep { - Continue, - PopTrail, - Descend(Result, StorageError>), + fn common_prefix(str1: &[u8], str2: &[u8]) -> usize { + let mut prefix = 0; + while prefix < str1.len() && prefix < str2.len() && str1[prefix] == str2[prefix] { + prefix += 1; + } + prefix + } + + /// Returns hashes of nodes with paths in [path_begin, path_end). Used by state parts + pub(crate) fn visit_nodes_interval( + &mut self, + path_begin: &[u8], + path_end: &[u8], + ) -> Result, StorageError> { + let path_begin_encoded = NibbleSlice::encode_nibbles(path_begin, true); + let last_hash = self.seek_nibble_slice(NibbleSlice::from_encoded(&path_begin_encoded).0)?; + let mut prefix = Self::common_prefix(path_end, &self.key_nibbles); + if self.key_nibbles[prefix..] >= path_end[prefix..] { + return Ok(vec![]); + } + let mut nodes_list = Vec::new(); + // Actually (self.key_nibbles[..] == path_begin) always because path_begin always ends in a node + if &self.key_nibbles[..] >= path_begin { + nodes_list.push(last_hash); } + loop { - let iter_step = { - self.trail.last_mut()?.increment(); - let b = self.trail.last().expect("Trail finished."); - match (b.status.clone(), &b.node.node) { - (CrumbStatus::Exiting, n) => { - match n { - TrieNode::Leaf(ref key, _) | TrieNode::Extension(ref key, _) => { - let existing_key = NibbleSlice::from_encoded(&key).0; - let l = self.key_nibbles.len(); - self.key_nibbles.truncate(l - existing_key.len()); - } - TrieNode::Branch(_, _) => { - self.key_nibbles.pop(); - } - _ => {} - } - IterStep::PopTrail - } - (CrumbStatus::At, TrieNode::Branch(_, Some(value))) => { - let value = match value { - ValueHandle::HashAndSize(_, hash) => self.trie.retrieve_raw_bytes(hash), - ValueHandle::InMemory(_node) => unreachable!(), - }; - return Some(value.map(|value| (self.key(), value))); - } - (CrumbStatus::At, TrieNode::Branch(_, None)) => IterStep::Continue, - (CrumbStatus::At, TrieNode::Leaf(_, value)) => { - let value = match value { - ValueHandle::HashAndSize(_, hash) => self.trie.retrieve_raw_bytes(hash), - ValueHandle::InMemory(_node) => unreachable!(), - }; - return Some(value.map(|value| (self.key(), value))); - } - (CrumbStatus::At, TrieNode::Extension(_, child)) => { - let next_node = match child { - NodeHandle::Hash(hash) => self.trie.retrieve_node(hash).map(Box::new), - NodeHandle::InMemory(_node) => unreachable!(), - }; - IterStep::Descend(next_node) - } - (CrumbStatus::AtChild(i), TrieNode::Branch(children, _)) - if children[i].is_some() => - { - match i { - 0 => self.key_nibbles.push(0), - i => { - *self.key_nibbles.last_mut().expect("Pushed child value before") = - i as u8 - } - } - let next_node = match &children[i] { - Some(NodeHandle::Hash(hash)) => { - self.trie.retrieve_node(&hash).map(Box::new) - } - Some(NodeHandle::InMemory(_node)) => unreachable!(), - _ => panic!("Wrapped with is_some()"), - }; - IterStep::Descend(next_node) - } - (CrumbStatus::AtChild(i), TrieNode::Branch(_, _)) => { - if i == 0 { - self.key_nibbles.push(0); - } - IterStep::Continue + let iter_step = match self.iter_step() { + Some(iter_step) => iter_step, + None => break, + }; + match iter_step { + IterStep::PopTrail => { + self.trail.pop(); + prefix = std::cmp::min(self.key_nibbles.len(), prefix); + } + IterStep::Descend(hash) => { + prefix += Self::common_prefix(&path_end[prefix..], &self.key_nibbles[prefix..]); + if self.key_nibbles[prefix..] >= path_end[prefix..] { + break; } - _ => panic!("Should never see Entering or AtChild without a Branch here."), + let node = self.trie.retrieve_node(&hash)?; + self.descend_into_node(node); + nodes_list.push(hash); } - }; + IterStep::Continue => {} + IterStep::Value(hash) => { + self.trie.retrieve_raw_bytes(&hash)?; + nodes_list.push(hash); + } + } + } + Ok(nodes_list) + } +} + +enum IterStep { + Continue, + PopTrail, + Descend(CryptoHash), + Value(CryptoHash), +} + +impl<'a> Iterator for TrieIterator<'a> { + type Item = TrieItem; + + fn next(&mut self) -> Option { + loop { + let iter_step = self.iter_step()?; match iter_step { IterStep::PopTrail => { self.trail.pop(); } - IterStep::Descend(Ok(node)) => self.descend_into_node(&node), - IterStep::Descend(Err(e)) => return Some(Err(e)), + IterStep::Descend(hash) => match self.trie.retrieve_node(&hash) { + Ok(node) => self.descend_into_node(node), + Err(e) => return Some(Err(e)), + }, IterStep::Continue => {} + IterStep::Value(hash) => { + return Some( + self.trie.retrieve_raw_bytes(&hash).map(|value| (self.key(), value)), + ) + } + } + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::BTreeMap; + + use rand::seq::SliceRandom; + use rand::Rng; + + use near_primitives::hash::CryptoHash; + + use crate::test_utils::{create_tries, gen_changes, simplify_changes, test_populate_trie}; + use crate::Trie; + + #[test] + fn test_iterator() { + let mut rng = rand::thread_rng(); + for _ in 0..100 { + let tries = create_tries(); + let trie = tries.get_trie_for_shard(0); + let trie_changes = gen_changes(&mut rng, 10); + let trie_changes = simplify_changes(&trie_changes); + + let mut map = BTreeMap::new(); + for (key, value) in trie_changes.iter() { + if let Some(value) = value { + map.insert(key.clone(), value.clone()); + } + } + let state_root = + test_populate_trie(&tries, &Trie::empty_root(), 0, trie_changes.clone()); + + { + let result1: Vec<_> = trie.iter(&state_root).unwrap().map(Result::unwrap).collect(); + let result2: Vec<_> = map.iter().map(|(k, v)| (k.clone(), v.clone())).collect(); + assert_eq!(result1, result2); + } + test_seek(&trie, &map, &state_root, &[]); + + for (seek_key, _) in trie_changes.iter() { + test_seek(&trie, &map, &state_root, &seek_key); + } + for _ in 0..20 { + let alphabet = &b"abcdefgh"[0..rng.gen_range(2, 8)]; + let key_length = rng.gen_range(1, 8); + let seek_key: Vec = + (0..key_length).map(|_| alphabet.choose(&mut rng).unwrap().clone()).collect(); + test_seek(&trie, &map, &state_root, &seek_key); } } } + + fn test_seek( + trie: &Trie, + map: &BTreeMap, Vec>, + state_root: &CryptoHash, + seek_key: &[u8], + ) { + let mut iterator = trie.iter(&state_root).unwrap(); + iterator.seek(&seek_key).unwrap(); + let result1: Vec<_> = iterator.map(Result::unwrap).take(5).collect(); + let result2: Vec<_> = + map.range(seek_key.to_vec()..).map(|(k, v)| (k.clone(), v.clone())).take(5).collect(); + assert_eq!(result1, result2); + } } diff --git a/core/store/src/trie/mod.rs b/core/store/src/trie/mod.rs index b95f91b3753..beafa6ee529 100644 --- a/core/store/src/trie/mod.rs +++ b/core/store/src/trie/mod.rs @@ -1,3 +1,4 @@ +use std::cell::RefCell; use std::cmp::Ordering; use std::collections::HashMap; use std::convert::TryFrom; @@ -6,7 +7,6 @@ use std::io::{Cursor, Read, Write}; use std::sync::Arc; use borsh::{BorshDeserialize, BorshSerialize}; - use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use near_primitives::challenge::PartialState; @@ -22,7 +22,6 @@ use crate::trie::trie_storage::{ }; pub(crate) use crate::trie::trie_storage::{TrieCache, TrieCachingStorage}; use crate::StorageError; -use std::cell::RefCell; mod insert_delete; pub mod iterator; @@ -63,6 +62,15 @@ enum NodeHandle { Hash(CryptoHash), } +impl NodeHandle { + fn unwrap_hash(&self) -> &CryptoHash { + match self { + Self::Hash(hash) => hash, + Self::InMemory(_) => unreachable!(), + } + } +} + #[derive(Clone, Hash, Debug)] enum ValueHandle { InMemory(StorageValueHandle), @@ -424,7 +432,7 @@ pub struct Trie { /// Having old_root and values in deletions allows to apply TrieChanges in reverse /// /// StoreUpdate are the changes from current state refcount to refcount + delta. -#[derive(BorshSerialize, BorshDeserialize, Clone)] +#[derive(BorshSerialize, BorshDeserialize, Clone, PartialEq, Eq, Debug)] pub struct TrieChanges { pub old_root: StateRoot, pub new_root: StateRoot, @@ -673,7 +681,7 @@ impl Trie { } } - fn convert_to_insertions_and_deletions( + pub(crate) fn convert_to_insertions_and_deletions( changes: HashMap, i32)>, ) -> (Vec<(CryptoHash, Vec, u32)>, Vec<(CryptoHash, Vec, u32)>) { let mut deletions = Vec::new(); @@ -725,12 +733,12 @@ impl Trie { mod tests { use rand::Rng; + use crate::db::DBCol::ColState; use crate::test_utils::{ create_test_store, create_tries, gen_changes, simplify_changes, test_populate_trie, }; use super::*; - use crate::db::DBCol::ColState; type TrieChanges = Vec<(Vec, Option>)>; diff --git a/core/store/src/trie/state_parts.rs b/core/store/src/trie/state_parts.rs index ce2688c8da8..d784e511b8c 100644 --- a/core/store/src/trie/state_parts.rs +++ b/core/store/src/trie/state_parts.rs @@ -1,13 +1,11 @@ -use std::cmp::min; use std::collections::HashMap; use near_primitives::challenge::PartialState; use near_primitives::hash::CryptoHash; use near_primitives::types::StateRoot; -use crate::trie::iterator::CrumbStatus; use crate::trie::nibble_slice::NibbleSlice; -use crate::trie::{NodeHandle, RawTrieNodeWithSize, TrieNode, TrieNodeWithSize, ValueHandle}; +use crate::trie::{NodeHandle, RawTrieNodeWithSize, TrieNode, TrieNodeWithSize}; use crate::{PartialStorage, StorageError, Trie, TrieChanges, TrieIterator}; impl Trie { @@ -27,13 +25,9 @@ impl Trie { ) -> Result { assert!(part_id < num_parts); assert!(self.storage.as_caching_storage().is_some()); - let root_node = self.retrieve_node(&state_root)?; - let total_size = root_node.memory_usage; - let size_start = (total_size + num_parts - 1) / num_parts * part_id; - let size_end = min((total_size + num_parts - 1) / num_parts * (part_id + 1), total_size); let with_recording = self.recording_reads(); - with_recording.visit_nodes_for_size_range(&state_root, size_start, size_end)?; + with_recording.visit_nodes_for_state_part(&state_root, part_id, num_parts)?; let recorded = with_recording.recorded_storage().unwrap(); let trie_nodes = recorded.nodes; @@ -48,37 +42,48 @@ impl Trie { /// /// Creating a StatePart takes all these nodes, validating a StatePart checks that it has the /// right set of nodes. - fn visit_nodes_for_size_range( + fn visit_nodes_for_state_part( &self, root_hash: &CryptoHash, - size_start: u64, - size_end: u64, + part_id: u64, + num_parts: u64, ) -> Result<(), StorageError> { - let root_node = self.retrieve_node(&root_hash)?; - let path_begin = self.find_path(&root_node, size_start)?; - let path_end = self.find_path(&root_node, size_end)?; - - let mut iterator = TrieIterator::new(&self, root_hash)?; - let path_begin_encoded = NibbleSlice::encode_nibbles(&path_begin, false); - iterator.seek_nibble_slice(NibbleSlice::from_encoded(&path_begin_encoded[..]).0)?; - loop { - match iterator.next() { - None => break, - Some(Err(e)) => { - return Err(e); - } - Some(Ok(_)) => { - // The last iteration actually reads a value we don't need. - } - } - // TODO #1603 this is bad for large keys - if iterator.key_nibbles >= path_end { - break; + let path_begin = self.find_path_for_part_boundary(root_hash, part_id, num_parts)?; + let path_end = self.find_path_for_part_boundary(root_hash, part_id + 1, num_parts)?; + let mut iterator = self.iter(&root_hash)?; + iterator.visit_nodes_interval(&path_begin, &path_end)?; + + // Extra nodes for compatibility with the previous version of computing state parts + if part_id + 1 != num_parts { + let mut iterator = TrieIterator::new(&self, root_hash)?; + let path_end_encoded = NibbleSlice::encode_nibbles(&path_end, false); + iterator.seek_nibble_slice(NibbleSlice::from_encoded(&path_end_encoded[..]).0)?; + if let Some(item) = iterator.next() { + item?; } } + Ok(()) } + /// Part part_id has nodes with paths [ path(part_id) .. path(part_id + 1) ) + /// path is returned as nibbles, last path is vec![16], previous paths end in nodes + fn find_path_for_part_boundary( + &self, + state_root: &StateRoot, + part_id: u64, + num_parts: u64, + ) -> Result, StorageError> { + assert!(part_id <= num_parts); + if part_id == num_parts { + return Ok(vec![16]); + } + let root_node = self.retrieve_node(&state_root)?; + let total_size = root_node.memory_usage; + let size_start = (total_size + num_parts - 1) / num_parts * part_id; + self.find_path(&root_node, size_start) + } + fn find_child( &self, size_start: u64, @@ -154,20 +159,16 @@ impl Trie { state_root: &StateRoot, part_id: u64, num_parts: u64, - trie_nodes: &PartialState, + trie_nodes: PartialState, ) -> Result<(), StorageError> { assert!(part_id < num_parts); - let trie = Trie::from_recorded_storage(PartialStorage { nodes: trie_nodes.clone() }); + let num_nodes = trie_nodes.0.len(); + let trie = Trie::from_recorded_storage(PartialStorage { nodes: trie_nodes }); - let root_node = trie.retrieve_node(&state_root)?; - let total_size = root_node.memory_usage; - let size_start = (total_size + num_parts - 1) / num_parts * part_id; - let size_end = min((total_size + num_parts - 1) / num_parts * (part_id + 1), total_size); - - trie.visit_nodes_for_size_range(&state_root, size_start, size_end)?; + trie.visit_nodes_for_state_part(&state_root, part_id, num_parts)?; let storage = trie.storage.as_partial_storage().unwrap(); - if storage.visited_nodes.borrow().len() != trie_nodes.0.len() { + if storage.visited_nodes.borrow().len() != num_nodes { // TODO #1603 not actually TrieNodeMissing. // The error is that the proof has more nodes than needed. return Err(StorageError::TrieNodeMissing); @@ -175,125 +176,33 @@ impl Trie { Ok(()) } - /// on_enter is applied for nodes as well as values - fn traverse_all_nodes Result<(), StorageError>>( - &self, - root: &CryptoHash, - mut on_enter: F, - ) -> Result<(), StorageError> { - if root == &CryptoHash::default() { - return Ok(()); - } - let mut stack: Vec<(CryptoHash, TrieNodeWithSize, CrumbStatus)> = Vec::new(); - let root_node = self.retrieve_node(root)?; - stack.push((*root, root_node, CrumbStatus::Entering)); - while let Some((hash, node, position)) = stack.pop() { - if let CrumbStatus::Entering = position { - on_enter(&hash)?; - } - match &node.node { - TrieNode::Empty => { - continue; - } - TrieNode::Leaf(_, value) => { - match value { - ValueHandle::HashAndSize(_, hash) => { - on_enter(hash)?; - } - ValueHandle::InMemory(_) => unreachable!("only possible while mutating"), - } - continue; - } - TrieNode::Branch(children, value) => match position { - CrumbStatus::Entering => { - match value { - Some(ValueHandle::HashAndSize(_, hash)) => { - on_enter(hash)?; - } - _ => {} - } - stack.push((hash, node, CrumbStatus::AtChild(0))); - continue; - } - CrumbStatus::AtChild(mut i) => { - while i < 16 { - if let Some(NodeHandle::Hash(_h)) = children[i].as_ref() { - break; - } - i += 1; - } - if i < 16 { - if let Some(NodeHandle::Hash(h)) = children[i].clone() { - let child = self.retrieve_node(&h)?; - stack.push((hash, node, CrumbStatus::AtChild(i + 1))); - stack.push((h, child, CrumbStatus::Entering)); - } else { - stack.push((hash, node, CrumbStatus::Exiting)); - } - } else { - stack.push((hash, node, CrumbStatus::Exiting)); - } - } - CrumbStatus::Exiting => { - continue; - } - CrumbStatus::At => { - continue; - } - }, - TrieNode::Extension(_key, child) => { - if let CrumbStatus::Entering = position { - match child.clone() { - NodeHandle::InMemory(_) => unreachable!("only possible while mutating"), - NodeHandle::Hash(h) => { - let child = self.retrieve_node(&h)?; - stack.push((hash, node, CrumbStatus::Exiting)); - stack.push((h, child, CrumbStatus::Entering)); - } - } - } - } - } - } - Ok(()) - } - - /// Combines all parts and returns TrieChanges that can be applied to storage. - /// - /// # Input - /// parts[i] has trie nodes for part i - /// - /// # Errors - /// StorageError if data is inconsistent. Should never happen if each part was validated. - pub fn combine_state_parts( + /// Returns the storage changes for the state part. + /// Writing all storage changes gives the complete trie. + pub fn apply_state_part( state_root: &StateRoot, - parts: &Vec>>, + part_id: u64, + num_parts: u64, + part: Vec>, ) -> Result { - let nodes = parts - .iter() - .map(|part| part.iter()) - .flatten() - .map(|data| data.to_vec()) - .collect::>(); - let trie = Trie::from_recorded_storage(PartialStorage { nodes: PartialState(nodes) }); - let mut insertions = , u32)>>::new(); - trie.traverse_all_nodes(&state_root, |hash| { - if let Some((_bytes, rc)) = insertions.get_mut(hash) { - *rc += 1; - } else { - let bytes = trie.storage.retrieve_raw_bytes(hash)?; - insertions.insert(*hash, (bytes, 1)); - } - Ok(()) - })?; - let mut insertions = - insertions.into_iter().map(|(k, (v, rc))| (k, v, rc)).collect::>(); - insertions.sort(); + if state_root == &CryptoHash::default() { + return Ok(TrieChanges::empty(CryptoHash::default())); + } + let trie = Trie::from_recorded_storage(PartialStorage { nodes: PartialState(part) }); + let path_begin = trie.find_path_for_part_boundary(state_root, part_id, num_parts)?; + let path_end = trie.find_path_for_part_boundary(state_root, part_id + 1, num_parts)?; + let mut iterator = TrieIterator::new(&trie, state_root)?; + let hashes = iterator.visit_nodes_interval(&path_begin, &path_end)?; + let mut map = HashMap::new(); + for hash in hashes { + let value = trie.retrieve_raw_bytes(&hash)?; + map.entry(hash).or_insert_with(|| (value, 0)).1 += 1; + } + let (insertions, deletions) = Trie::convert_to_insertions_and_deletions(map); Ok(TrieChanges { - old_root: Default::default(), + old_root: CryptoHash::default(), new_root: *state_root, insertions, - deletions: vec![], + deletions, }) } @@ -311,19 +220,206 @@ impl Trie { mod tests { use std::collections::HashMap; + use rand::prelude::ThreadRng; use rand::Rng; use near_primitives::hash::{hash, CryptoHash}; use crate::test_utils::{create_tries, gen_changes, test_populate_trie}; + use crate::trie::iterator::CrumbStatus; + use crate::trie::ValueHandle; use super::*; - use rand::prelude::ThreadRng; + + impl Trie { + /// Combines all parts and returns TrieChanges that can be applied to storage. + /// + /// # Input + /// parts[i] has trie nodes for part i + /// + /// # Errors + /// StorageError if data is inconsistent. Should never happen if each part was validated. + pub fn combine_state_parts_naive( + state_root: &StateRoot, + parts: &Vec>>, + ) -> Result { + let nodes = parts + .iter() + .map(|part| part.iter()) + .flatten() + .map(|data| data.to_vec()) + .collect::>(); + let trie = Trie::from_recorded_storage(PartialStorage { nodes: PartialState(nodes) }); + let mut insertions = , u32)>>::new(); + trie.traverse_all_nodes(&state_root, |hash| { + if let Some((_bytes, rc)) = insertions.get_mut(hash) { + *rc += 1; + } else { + let bytes = trie.storage.retrieve_raw_bytes(hash)?; + insertions.insert(*hash, (bytes, 1)); + } + Ok(()) + })?; + let mut insertions = + insertions.into_iter().map(|(k, (v, rc))| (k, v, rc)).collect::>(); + insertions.sort(); + Ok(TrieChanges { + old_root: Default::default(), + new_root: *state_root, + insertions, + deletions: vec![], + }) + } + + /// on_enter is applied for nodes as well as values + fn traverse_all_nodes Result<(), StorageError>>( + &self, + root: &CryptoHash, + mut on_enter: F, + ) -> Result<(), StorageError> { + if root == &CryptoHash::default() { + return Ok(()); + } + let mut stack: Vec<(CryptoHash, TrieNodeWithSize, CrumbStatus)> = Vec::new(); + let root_node = self.retrieve_node(root)?; + stack.push((*root, root_node, CrumbStatus::Entering)); + while let Some((hash, node, position)) = stack.pop() { + if let CrumbStatus::Entering = position { + on_enter(&hash)?; + } + match &node.node { + TrieNode::Empty => { + continue; + } + TrieNode::Leaf(_, value) => { + match value { + ValueHandle::HashAndSize(_, hash) => { + on_enter(hash)?; + } + ValueHandle::InMemory(_) => { + unreachable!("only possible while mutating") + } + } + continue; + } + TrieNode::Branch(children, value) => match position { + CrumbStatus::Entering => { + match value { + Some(ValueHandle::HashAndSize(_, hash)) => { + on_enter(hash)?; + } + _ => {} + } + stack.push((hash, node, CrumbStatus::AtChild(0))); + continue; + } + CrumbStatus::AtChild(mut i) => { + while i < 16 { + if let Some(NodeHandle::Hash(_h)) = children[i].as_ref() { + break; + } + i += 1; + } + if i < 16 { + if let Some(NodeHandle::Hash(h)) = children[i].clone() { + let child = self.retrieve_node(&h)?; + stack.push((hash, node, CrumbStatus::AtChild(i + 1))); + stack.push((h, child, CrumbStatus::Entering)); + } else { + stack.push((hash, node, CrumbStatus::Exiting)); + } + } else { + stack.push((hash, node, CrumbStatus::Exiting)); + } + } + CrumbStatus::Exiting => { + continue; + } + CrumbStatus::At => { + continue; + } + }, + TrieNode::Extension(_key, child) => { + if let CrumbStatus::Entering = position { + match child.clone() { + NodeHandle::InMemory(_) => { + unreachable!("only possible while mutating") + } + NodeHandle::Hash(h) => { + let child = self.retrieve_node(&h)?; + stack.push((hash, node, CrumbStatus::Exiting)); + stack.push((h, child, CrumbStatus::Entering)); + } + } + } + } + } + } + Ok(()) + } + + fn visit_nodes_for_size_range_old( + &self, + root_hash: &CryptoHash, + size_start: u64, + size_end: u64, + ) -> Result<(), StorageError> { + let root_node = self.retrieve_node(&root_hash)?; + let path_begin = self.find_path(&root_node, size_start)?; + let path_end = self.find_path(&root_node, size_end)?; + + let mut iterator = TrieIterator::new(&self, root_hash)?; + let path_begin_encoded = NibbleSlice::encode_nibbles(&path_begin, false); + iterator.seek_nibble_slice(NibbleSlice::from_encoded(&path_begin_encoded[..]).0)?; + loop { + match iterator.next() { + None => break, + Some(Err(e)) => { + return Err(e); + } + Some(Ok(_item)) => { + // The last iteration actually reads a value we don't need. + } + } + // TODO #1603 this is bad for large keys + if iterator.key_nibbles >= path_end { + break; + } + } + Ok(()) + } + + pub fn get_trie_nodes_for_part_old( + &self, + part_id: u64, + num_parts: u64, + state_root: &StateRoot, + ) -> Result { + assert!(part_id < num_parts); + assert!(self.storage.as_caching_storage().is_some()); + let root_node = self.retrieve_node(&state_root)?; + let total_size = root_node.memory_usage; + let size_start = (total_size + num_parts - 1) / num_parts * part_id; + let size_end = + std::cmp::min((total_size + num_parts - 1) / num_parts * (part_id + 1), total_size); + + let with_recording = self.recording_reads(); + with_recording.visit_nodes_for_size_range_old(&state_root, size_start, size_end)?; + let recorded = with_recording.recorded_storage().unwrap(); + + let trie_nodes = recorded.nodes; + + Ok(trie_nodes) + } + } #[test] fn test_combine_empty_trie_parts() { let state_root = StateRoot::default(); - let _ = Trie::combine_state_parts(&state_root, &vec![]).unwrap(); + let _ = Trie::combine_state_parts_naive(&state_root, &vec![]).unwrap(); + let _ = + Trie::validate_trie_nodes_for_part(&state_root, 0, 1, PartialState(vec![])).unwrap(); + let _ = Trie::apply_state_part(&state_root, 0, 1, vec![]).unwrap(); } fn construct_trie_for_big_parts_1( @@ -438,26 +534,35 @@ mod tests { run_test_parts_not_huge(construct_trie_for_big_parts_2, 100_000); } + fn merge_trie_changes(changes: Vec) -> TrieChanges { + if changes.is_empty() { + return TrieChanges::empty(CryptoHash::default()); + } + let new_root = changes[0].new_root; + let mut map = HashMap::new(); + for changes_set in changes { + assert!(changes_set.deletions.is_empty(), "state parts only have insertions"); + for (key, value, rc) in changes_set.insertions { + map.entry(key).or_insert_with(|| (value, 0)).1 += rc as i32; + } + for (key, value, rc) in changes_set.deletions { + map.entry(key).or_insert_with(|| (value, 0)).1 -= rc as i32; + } + } + let (insertions, deletions) = Trie::convert_to_insertions_and_deletions(map); + TrieChanges { old_root: Default::default(), new_root, insertions, deletions } + } + #[test] - fn test_parts() { + fn test_combine_state_parts() { let mut rng = rand::thread_rng(); - for _ in 0..20 { + for _ in 0..2000 { let tries = create_tries(); let trie = tries.get_trie_for_shard(0); - let trie_changes = gen_changes(&mut rng, 500); - + let trie_changes = gen_changes(&mut rng, 20); let state_root = test_populate_trie(&tries, &Trie::empty_root(), 0, trie_changes.clone()); let root_memory_usage = trie.retrieve_root_node(&state_root).unwrap().memory_usage; - for _ in 0..100 { - // Test that creating and validating are consistent - let num_parts = rng.gen_range(1, 10); - let part_id = rng.gen_range(0, num_parts); - let trie_nodes = - trie.get_trie_nodes_for_part(part_id, num_parts, &state_root).unwrap(); - Trie::validate_trie_nodes_for_part(&state_root, part_id, num_parts, &trie_nodes) - .expect("validate ok"); - } { // Test that combining all parts gets all nodes @@ -468,7 +573,8 @@ mod tests { }) .collect::>(); - let trie_changes = Trie::combine_state_parts(&state_root, &parts).unwrap(); + let trie_changes = check_combine_state_parts(&state_root, num_parts, &parts); + let mut nodes = >>::new(); let sizes_vec = parts .iter() @@ -483,25 +589,71 @@ mod tests { let all_nodes = nodes.into_iter().map(|(_hash, node)| node).collect::>(); assert_eq!(all_nodes.len(), trie_changes.insertions.len()); let size_of_all = all_nodes.iter().map(|node| node.len()).sum::(); - Trie::validate_trie_nodes_for_part( - &state_root, - 0, - 1, - &PartialState(all_nodes.clone()), - ) - .expect("validate ok"); + let num_nodes = all_nodes.len(); + Trie::validate_trie_nodes_for_part(&state_root, 0, 1, PartialState(all_nodes)) + .expect("validate ok"); let sum_of_sizes = sizes_vec.iter().sum::(); // Manually check that sizes are reasonable println!("------------------------------"); - println!("Number of nodes: {:?}", all_nodes.len()); + println!("Number of nodes: {:?}", num_nodes); println!("Sizes of parts: {:?}", sizes_vec); println!( "All nodes size: {:?}, sum_of_sizes: {:?}, memory_usage: {:?}", size_of_all, sum_of_sizes, root_memory_usage ); // borsh serialize should be about this size - assert!(size_of_all + 8 * all_nodes.len() <= root_memory_usage as usize); + assert!(size_of_all + 8 * num_nodes <= root_memory_usage as usize); + } + } + } + + fn check_combine_state_parts( + state_root: &CryptoHash, + num_parts: u64, + parts: &Vec>>, + ) -> TrieChanges { + let trie_changes = Trie::combine_state_parts_naive(&state_root, &parts).unwrap(); + + let trie_changes_new = { + let changes = (0..num_parts) + .map(|part_id| { + Trie::apply_state_part( + &state_root, + part_id, + num_parts, + parts[part_id as usize].clone(), + ) + .unwrap() + }) + .collect::>(); + merge_trie_changes(changes) + }; + assert_eq!(trie_changes, trie_changes_new); + trie_changes + } + + #[test] + fn test_get_trie_nodes_for_part() { + let mut rng = rand::thread_rng(); + for _ in 0..20 { + let tries = create_tries(); + let trie = tries.get_trie_for_shard(0); + let trie_changes = gen_changes(&mut rng, 10); + + let state_root = + test_populate_trie(&tries, &Trie::empty_root(), 0, trie_changes.clone()); + for _ in 0..10 { + // Test that creating and validating are consistent + let num_parts = rng.gen_range(1, 10); + let part_id = rng.gen_range(0, num_parts); + let trie_nodes = + trie.get_trie_nodes_for_part(part_id, num_parts, &state_root).unwrap(); + let trie_nodes2 = + trie.get_trie_nodes_for_part_old(part_id, num_parts, &state_root).unwrap(); + assert_eq!(trie_nodes, trie_nodes2); + Trie::validate_trie_nodes_for_part(&state_root, part_id, num_parts, trie_nodes) + .expect("validate ok"); } } } diff --git a/neard/src/runtime.rs b/neard/src/runtime.rs index 8a75a8d0d82..0d2d4baa174 100644 --- a/neard/src/runtime.rs +++ b/neard/src/runtime.rs @@ -1339,35 +1339,30 @@ impl RuntimeAdapter for NightshadeRuntime { ) -> bool { assert!(part_id < num_parts); match BorshDeserialize::try_from_slice(data) { - Ok(trie_nodes) => match Trie::validate_trie_nodes_for_part( - state_root, - part_id, - num_parts, - &trie_nodes, - ) { - Ok(_) => true, - // Storage error should not happen - Err(_) => false, - }, + Ok(trie_nodes) => { + match Trie::validate_trie_nodes_for_part(state_root, part_id, num_parts, trie_nodes) + { + Ok(_) => true, + // Storage error should not happen + Err(_) => false, + } + } // Deserialization error means we've got the data from malicious peer Err(_) => false, } } - fn confirm_state( + fn apply_state_part( &self, shard_id: ShardId, state_root: &StateRoot, - data: &Vec>, + part_id: u64, + num_parts: u64, + data: &[u8], ) -> Result<(), Error> { - let mut parts = vec![]; - for part in data { - parts.push( - BorshDeserialize::try_from_slice(part) - .expect("Part was already validated earlier, so could never fail here"), - ); - } - let trie_changes = Trie::combine_state_parts(&state_root, &parts) + let part = BorshDeserialize::try_from_slice(data) + .expect("Part was already validated earlier, so could never fail here"); + let trie_changes = Trie::apply_state_part(&state_root, part_id, num_parts, part) .expect("combine_state_parts is guaranteed to succeed when each part is valid"); let tries = self.get_tries(); let (store_update, _) = @@ -2269,7 +2264,7 @@ mod test { assert!(!new_env.runtime.validate_state_root_node(&root_node_wrong, &env.state_roots[0])); assert!(!new_env.runtime.validate_state_part(&StateRoot::default(), 0, 1, &state_part)); new_env.runtime.validate_state_part(&env.state_roots[0], 0, 1, &state_part); - new_env.runtime.confirm_state(0, &env.state_roots[0], &vec![state_part]).unwrap(); + new_env.runtime.apply_state_part(0, &env.state_roots[0], 0, 1, &state_part).unwrap(); new_env.state_roots[0] = env.state_roots[0].clone(); for _ in 3..=5 { new_env.step_default(vec![]);