diff --git a/chain/chain/src/chain.rs b/chain/chain/src/chain.rs
index e7af6450e02..1be35c74e95 100644
--- a/chain/chain/src/chain.rs
+++ b/chain/chain/src/chain.rs
@@ -92,13 +92,15 @@ use near_primitives::views::{
 };
 use near_store::config::StateSnapshotType;
 use near_store::flat::{store_helper, FlatStorageReadyStatus, FlatStorageStatus};
-use near_store::get_genesis_state_roots;
+use near_store::trie::mem::resharding::RetainMode;
 use near_store::DBCol;
+use near_store::{get_genesis_state_roots, PartialStorage};
 use node_runtime::bootstrap_congestion_info;
 use rayon::iter::{IntoParallelIterator, ParallelIterator};
 use std::collections::{BTreeMap, HashMap, HashSet};
 use std::fmt::{Debug, Formatter};
 use std::num::NonZeroUsize;
+use std::str::FromStr;
 use std::sync::Arc;
 use time::ext::InstantExt as _;
 use tracing::{debug, debug_span, error, info, warn, Span};
@@ -1847,6 +1849,88 @@ impl Chain {
         });
     }
 
+    /// If shard layout changes after the given block, creates temporary
+    /// memtries for new shards to be able to process them in the next epoch.
+    /// Note this doesn't complete resharding; proper memtries are to be
+    /// created later.
+    fn process_memtrie_resharding_storage_update(
+        &mut self,
+        block: &Block,
+        shard_uid: ShardUId,
+    ) -> Result<(), Error> {
+        let block_hash = block.hash();
+        let block_height = block.header().height();
+        let prev_hash = block.header().prev_hash();
+        if !self.epoch_manager.will_shard_layout_change(prev_hash)? {
+            return Ok(());
+        }
+
+        let next_epoch_id = self.epoch_manager.get_next_epoch_id_from_prev_block(prev_hash)?;
+        let next_shard_layout = self.epoch_manager.get_shard_layout(&next_epoch_id)?;
+        let children_shard_uids =
+            next_shard_layout.get_children_shards_uids(shard_uid.shard_id()).unwrap();
+
+        // Hack to ensure this logic is not applied before ReshardingV3.
+        // TODO(#12019): proper logic.
+        if next_shard_layout.version() < 3 || children_shard_uids.len() == 1 {
+            return Ok(());
+        }
+        assert_eq!(children_shard_uids.len(), 2);
+
+        let chunk_extra = self.get_chunk_extra(block_hash, &shard_uid)?;
+        let tries = self.runtime_adapter.get_tries();
+        let Some(mem_tries) = tries.get_mem_tries(shard_uid) else {
+            // TODO(#12019): what if node doesn't have memtrie? just pause
+            // processing?
+            error!(
+                "Memtrie not loaded. Cannot process memtrie resharding storage
+                update for block {:?}, shard {:?}",
+                block_hash, shard_uid
+            );
+            return Err(Error::Other("Memtrie not loaded".to_string()));
+        };
+
+        // TODO(#12019): take proper boundary account.
+        let boundary_account = AccountId::from_str("boundary.near").unwrap();
+
+        // TODO(#12019): leave only tracked shards.
+        for (new_shard_uid, retain_mode) in [
+            (children_shard_uids[0], RetainMode::Left),
+            (children_shard_uids[1], RetainMode::Right),
+        ] {
+            let mut mem_tries = mem_tries.write().unwrap();
+            let mem_trie_update = mem_tries.update(*chunk_extra.state_root(), true)?;
+
+            let (trie_changes, _) =
+                mem_trie_update.retain_split_shard(boundary_account.clone(), retain_mode);
+            let partial_state = PartialState::default();
+            let partial_storage = PartialStorage { nodes: partial_state };
+            let mem_changes = trie_changes.mem_trie_changes.as_ref().unwrap();
+            let new_state_root = mem_tries.apply_memtrie_changes(block_height, mem_changes);
+            // TODO(#12019): set all fields of `ChunkExtra`. Consider stronger
+            // typing. Clarify where it should happen when `State` and
+            // `FlatState` update is implemented.
+            let mut child_chunk_extra = ChunkExtra::clone(&chunk_extra);
+            *child_chunk_extra.state_root_mut() = new_state_root;
+
+            let mut chain_store_update = ChainStoreUpdate::new(&mut self.chain_store);
+            chain_store_update.save_chunk_extra(block_hash, &new_shard_uid, child_chunk_extra);
+            chain_store_update.save_state_transition_data(
+                *block_hash,
+                new_shard_uid.shard_id(),
+                Some(partial_storage),
+                CryptoHash::default(),
+            );
+            chain_store_update.commit()?;
+
+            let mut store_update = self.chain_store.store().store_update();
+            tries.apply_insertions(&trie_changes, new_shard_uid, &mut store_update);
+            store_update.commit()?;
+        }
+
+        Ok(())
+    }
+
     #[tracing::instrument(level = "debug", target = "chain", "postprocess_block_only", skip_all)]
     fn postprocess_block_only(
         &mut self,
@@ -1936,20 +2020,13 @@ impl Chain {
                 true,
             );
             let care_about_shard_this_or_next_epoch = care_about_shard || will_care_about_shard;
+            let shard_uid = self.epoch_manager.shard_id_to_uid(shard_id, &epoch_id).unwrap();
             if care_about_shard_this_or_next_epoch {
-                let shard_uid = self.epoch_manager.shard_id_to_uid(shard_id, &epoch_id).unwrap();
                 shards_cares_this_or_next_epoch.push(shard_uid);
             }
 
-            // Update flat storage head to be the last final block. Note that this update happens
-            // in a separate db transaction from the update from block processing. This is intentional
-            // because flat_storage need to be locked during the update of flat head, otherwise
-            // flat_storage is in an inconsistent state that could be accessed by the other
-            // apply chunks processes. This means, the flat head is not always the same as
-            // the last final block on chain, which is OK, because in the flat storage implementation
-            // we don't assume that.
-            let need_flat_storage_update = if is_caught_up {
-                // If we already caught up this epoch, then flat storage exists for both shards which we already track
+            let need_storage_update = if is_caught_up {
+                // If we already caught up this epoch, then storage exists for both shards which we already track
                 // and shards which will be tracked in next epoch, so we can update them.
                 care_about_shard_this_or_next_epoch
            } else {
@@ -1957,9 +2034,19 @@
                // during catchup of this block.
                care_about_shard
            };
-            tracing::debug!(target: "chain", shard_id, need_flat_storage_update, "Updating flat storage");
-
-            if need_flat_storage_update {
+            tracing::debug!(target: "chain", shard_id, need_storage_update, "Updating storage");
+
+            if need_storage_update {
+                // TODO(#12019): consider adding to catchup flow.
+                self.process_memtrie_resharding_storage_update(&block, shard_uid)?;
+
+                // Update flat storage head to be the last final block. Note that this update happens
+                // in a separate db transaction from the update from block processing. This is intentional
+                // because flat_storage need to be locked during the update of flat head, otherwise
+                // flat_storage is in an inconsistent state that could be accessed by the other
+                // apply chunks processes. This means, the flat head is not always the same as
+                // the last final block on chain, which is OK, because in the flat storage implementation
+                // we don't assume that.
                 self.update_flat_storage_and_memtrie(&block, shard_id)?;
             }
         }
diff --git a/core/store/src/trie/mem/mod.rs b/core/store/src/trie/mem/mod.rs
index 01c9df427ef..03fa125e495 100644
--- a/core/store/src/trie/mem/mod.rs
+++ b/core/store/src/trie/mem/mod.rs
@@ -9,6 +9,7 @@ pub mod mem_tries;
 pub mod metrics;
 pub mod node;
 mod parallel_loader;
+pub mod resharding;
 pub mod updating;
 
 /// Check this, because in the code we conveniently assume usize is 8 bytes.
diff --git a/core/store/src/trie/mem/resharding.rs b/core/store/src/trie/mem/resharding.rs
new file mode 100644
index 00000000000..281e828abba
--- /dev/null
+++ b/core/store/src/trie/mem/resharding.rs
@@ -0,0 +1,248 @@
+use crate::{NibbleSlice, TrieChanges};
+
+use super::arena::ArenaMemory;
+use super::updating::{MemTrieUpdate, OldOrUpdatedNodeId, TrieAccesses, UpdatedMemTrieNode};
+use itertools::Itertools;
+use near_primitives::types::AccountId;
+use std::ops::Range;
+
+/// Whether to retain left or right part of trie after shard split.
+pub enum RetainMode {
+    Left,
+    Right,
+}
+
+/// Decision on the subtree exploration.
+#[derive(Debug)]
+enum RetainDecision {
+    /// Retain the whole subtree.
+    RetainAll,
+    /// Discard the whole subtree.
+    DiscardAll,
+    /// Descend into all child subtrees.
+    Descend,
+}
+
+impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> {
+    /// Splits the trie, separating entries by the boundary account.
+    /// Leaves the left or right part of the trie, depending on the retain mode.
+    ///
+    /// Returns the changes to be applied to the in-memory trie and the proof of
+    /// the split operation. Doesn't modify the trie itself; it's the caller's
+    /// responsibility to apply the changes.
+    pub fn retain_split_shard(
+        self,
+        _boundary_account: AccountId,
+        _retain_mode: RetainMode,
+    ) -> (TrieChanges, TrieAccesses) {
+        // TODO(#12074): generate intervals in nibbles.
+
+        self.retain_multi_range(&[])
+    }
+
+    /// Retains keys belonging to any of the ranges given in `intervals` from
+    /// the trie.
+    ///
+    /// Returns changes to be applied to in-memory trie and proof of the
+    /// retain operation.
+    fn retain_multi_range(mut self, intervals: &[Range<Vec<u8>>]) -> (TrieChanges, TrieAccesses) {
+        debug_assert!(intervals.iter().all(|range| range.start < range.end));
+        let intervals_nibbles = intervals
+            .iter()
+            .map(|range| {
+                NibbleSlice::new(&range.start).iter().collect_vec()
+                    ..NibbleSlice::new(&range.end).iter().collect_vec()
+            })
+            .collect_vec();
+
+        // TODO(#12074): consider handling the case when no changes are made.
+        // TODO(#12074): restore proof as well.
+        self.retain_multi_range_recursive(0, vec![], &intervals_nibbles);
+        self.to_trie_changes()
+    }
+
+    /// Recursive implementation of retaining only the keys belonging to any
+    /// of the ranges given in `intervals` from the trie. All changes are
+    /// applied in `updated_nodes`.
+    ///
+    /// `node_id` is the root of subtree being explored.
+    /// `key_nibbles` is the key corresponding to `node_id`.
+    /// `intervals_nibbles` is the list of ranges to be retained.
+    fn retain_multi_range_recursive(
+        &mut self,
+        node_id: usize,
+        key_nibbles: Vec<u8>,
+        intervals_nibbles: &[Range<Vec<u8>>],
+    ) {
+        let decision = retain_decision(&key_nibbles, intervals_nibbles);
+        match decision {
+            RetainDecision::RetainAll => return,
+            RetainDecision::DiscardAll => {
+                let _ = self.take_node(node_id);
+                self.place_node(node_id, UpdatedMemTrieNode::Empty);
+                return;
+            }
+            RetainDecision::Descend => {
+                // We need to descend into all children. The logic follows below.
+            }
+        }
+
+        let node = self.take_node(node_id);
+        match node {
+            UpdatedMemTrieNode::Empty => {
+                // Nowhere to descend.
+                self.place_node(node_id, UpdatedMemTrieNode::Empty);
+                return;
+            }
+            UpdatedMemTrieNode::Leaf { extension, value } => {
+                let full_key_nibbles =
+                    [key_nibbles, NibbleSlice::from_encoded(&extension).0.iter().collect_vec()]
+                        .concat();
+                if !intervals_nibbles.iter().any(|interval| interval.contains(&full_key_nibbles)) {
+                    self.place_node(node_id, UpdatedMemTrieNode::Empty);
+                } else {
+                    self.place_node(node_id, UpdatedMemTrieNode::Leaf { extension, value });
+                }
+            }
+            UpdatedMemTrieNode::Branch { mut children, mut value } => {
+                if !intervals_nibbles.iter().any(|interval| interval.contains(&key_nibbles)) {
+                    value = None;
+                }
+
+                for (i, child) in children.iter_mut().enumerate() {
+                    let Some(old_child_id) = child.take() else {
+                        continue;
+                    };
+
+                    let new_child_id = self.ensure_updated(old_child_id);
+                    let child_key_nibbles = [key_nibbles.clone(), vec![i as u8]].concat();
+                    self.retain_multi_range_recursive(
+                        new_child_id,
+                        child_key_nibbles,
+                        intervals_nibbles,
+                    );
+                    if self.updated_nodes[new_child_id] == Some(UpdatedMemTrieNode::Empty) {
+                        *child = None;
+                    } else {
+                        *child = Some(OldOrUpdatedNodeId::Updated(new_child_id));
+                    }
+                }
+
+                // TODO(#12074): squash the branch if needed. Consider reusing
+                // `squash_nodes`.
+
+                self.place_node(node_id, UpdatedMemTrieNode::Branch { children, value });
+            }
+            UpdatedMemTrieNode::Extension { extension, child } => {
+                let new_child_id = self.ensure_updated(child);
+                let extension_nibbles =
+                    NibbleSlice::from_encoded(&extension).0.iter().collect_vec();
+                let child_key = [key_nibbles, extension_nibbles].concat();
+                self.retain_multi_range_recursive(new_child_id, child_key, intervals_nibbles);
+
+                if self.updated_nodes[new_child_id] == Some(UpdatedMemTrieNode::Empty) {
+                    self.place_node(node_id, UpdatedMemTrieNode::Empty);
+                } else {
+                    self.place_node(
+                        node_id,
+                        UpdatedMemTrieNode::Extension {
+                            extension,
+                            child: OldOrUpdatedNodeId::Updated(new_child_id),
+                        },
+                    );
+                }
+            }
+        }
+    }
+}
+
+/// Based on the key and the intervals, makes a decision on the subtree exploration.
+fn retain_decision(key: &[u8], intervals: &[Range<Vec<u8>>]) -> RetainDecision {
+    let mut should_descend = false;
+    for interval in intervals {
+        // If key can be extended to be equal to start or end of the interval,
+        // its subtree may have keys inside the interval. At the same time,
+        // it can be extended with bytes which would fall outside the interval.
+        //
+        // For example, if key is "a" and interval is "ab".."cd", subtree may
+        // contain both "aa" which must be excluded and "ac" which must be
+        // retained.
+        if interval.start.starts_with(key) || interval.end.starts_with(key) {
+            should_descend = true;
+            continue;
+        }
+
+        // If key is not a prefix of boundaries and falls inside the interval,
+        // one can show that all the keys in the subtree are also inside the
+        // interval.
+        if interval.start.as_slice() <= key && key < interval.end.as_slice() {
+            return RetainDecision::RetainAll;
+        }
+
+        // Otherwise, all the keys in the subtree are outside the interval.
+    }
+
+    if should_descend {
+        RetainDecision::Descend
+    } else {
+        RetainDecision::DiscardAll
+    }
+}
+
+// TODO(#12074): tests for
+// - multiple retain ranges
+// - result is empty, or no changes are made
+// - removing keys one-by-one gives the same result as corresponding range retain
+// - `retain_split_shard` API
+// - all results of squashing branch
+// - checking that not-inlined nodes are not accessed
+// - proof correctness
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use itertools::Itertools;
+    use near_primitives::{shard_layout::ShardUId, types::StateRoot};
+
+    use crate::{
+        trie::{
+            mem::{iter::MemTrieIterator, mem_tries::MemTries},
+            trie_storage::TrieMemoryPartialStorage,
+        },
+        Trie,
+    };
+
+    #[test]
+    /// Applies single range retain to the trie and checks the result.
+    fn test_retain_single_range() {
+        let initial_entries = vec![
+            (b"alice".to_vec(), vec![1]),
+            (b"bob".to_vec(), vec![2]),
+            (b"charlie".to_vec(), vec![3]),
+            (b"david".to_vec(), vec![4]),
+        ];
+        let retain_range = b"amy".to_vec()..b"david".to_vec();
+        let retain_result = vec![(b"bob".to_vec(), vec![2]), (b"charlie".to_vec(), vec![3])];
+
+        let mut memtries = MemTries::new(ShardUId::single_shard());
+        let empty_state_root = StateRoot::default();
+        let mut update = memtries.update(empty_state_root, false).unwrap();
+        for (key, value) in initial_entries {
+            update.insert(&key, value);
+        }
+        let memtrie_changes = update.to_mem_trie_changes_only();
+        let state_root = memtries.apply_memtrie_changes(0, &memtrie_changes);
+
+        let update = memtries.update(state_root, true).unwrap();
+        let (mut trie_changes, _) = update.retain_multi_range(&[retain_range]);
+        let memtrie_changes = trie_changes.mem_trie_changes.take().unwrap();
+        let new_state_root = memtries.apply_memtrie_changes(1, &memtrie_changes);
+
+        let state_root_ptr = memtries.get_root(&new_state_root).unwrap();
+        let trie = Trie::new(Arc::new(TrieMemoryPartialStorage::default()), new_state_root, None);
+        let entries =
+            MemTrieIterator::new(Some(state_root_ptr), &trie).map(|e| e.unwrap()).collect_vec();
+
+        assert_eq!(entries, retain_result);
+    }
+}
diff --git a/core/store/src/trie/mem/updating.rs b/core/store/src/trie/mem/updating.rs
index 61b374367df..adfe6504bdc 100644
--- a/core/store/src/trie/mem/updating.rs
+++ b/core/store/src/trie/mem/updating.rs
@@ -43,7 +43,7 @@ pub enum UpdatedMemTrieNode {
 }
 
 /// Keeps values and internal nodes accessed on updating memtrie.
-pub(crate) struct TrieAccesses {
+pub struct TrieAccesses {
     /// Hashes and encoded trie nodes.
     pub nodes: HashMap<CryptoHash, Arc<[u8]>>,
     /// Hashes of accessed values - because values themselves are not
@@ -144,12 +144,12 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> {
     /// Internal function to take a node from the array of updated nodes, setting it
     /// to None. It is expected that place_node is then called to return the node to
     /// the same slot.
-    fn take_node(&mut self, index: UpdatedMemTrieNodeId) -> UpdatedMemTrieNode {
+    pub(crate) fn take_node(&mut self, index: UpdatedMemTrieNodeId) -> UpdatedMemTrieNode {
         self.updated_nodes.get_mut(index).unwrap().take().expect("Node taken twice")
     }
 
     /// Does the opposite of take_node; returns the node to the specified ID.
-    fn place_node(&mut self, index: UpdatedMemTrieNodeId, node: UpdatedMemTrieNode) {
+    pub(crate) fn place_node(&mut self, index: UpdatedMemTrieNodeId, node: UpdatedMemTrieNode) {
         assert!(self.updated_nodes[index].is_none(), "Node placed twice");
         self.updated_nodes[index] = Some(node);
     }
@@ -191,7 +191,7 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> {
     }
 
     /// If the ID was old, converts it to an updated one.
-    fn ensure_updated(&mut self, node: OldOrUpdatedNodeId) -> UpdatedMemTrieNodeId {
+    pub(crate) fn ensure_updated(&mut self, node: OldOrUpdatedNodeId) -> UpdatedMemTrieNodeId {
         match node {
             OldOrUpdatedNodeId::Old(node_id) => self.convert_existing_to_updated(Some(node_id)),
             OldOrUpdatedNodeId::Updated(node_id) => node_id,
@@ -689,13 +689,15 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> {
     }
 
     /// For each node in `ordered_nodes`, computes its hash and serialized data.
-    /// The `ordered_nodes` is expected to come from `post_order_traverse_updated_nodes`,
-    /// and updated_nodes are indexed by the node IDs in `ordered_nodes`.
-    fn compute_hashes_and_serialized_nodes(
+    /// `ordered_nodes` is expected to follow the post-order traversal of the
+    /// updated nodes.
+    /// `updated_nodes` must be indexed by the node IDs in `ordered_nodes`.
+    pub(crate) fn compute_hashes_and_serialized_nodes(
+        &self,
         ordered_nodes: &Vec<UpdatedMemTrieNodeId>,
         updated_nodes: &Vec<Option<UpdatedMemTrieNode>>,
-        arena: &'a M,
     ) -> Vec<(UpdatedMemTrieNodeId, CryptoHash, Vec<u8>)> {
+        let memory = self.memory;
         let mut result = Vec::<(CryptoHash, u64, Vec<u8>)>::new();
         for _ in 0..updated_nodes.len() {
             result.push((CryptoHash::default(), 0, Vec::new()));
@@ -709,7 +711,7 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> {
                     (hash, memory_usage)
                 }
                 OldOrUpdatedNodeId::Old(node_id) => {
-                    let view = node_id.as_ptr(arena).view();
+                    let view = node_id.as_ptr(memory).view();
                     (view.node_hash(), view.memory_usage())
                 }
             }
@@ -780,27 +782,23 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> {
 
     /// Converts the changes to memtrie changes. Also returns the list of new nodes inserted,
     /// in hash and serialized form.
-    fn to_mem_trie_changes_internal(
-        shard_uid: String,
-        arena: &'a M,
-        updated_nodes: Vec<Option<UpdatedMemTrieNode>>,
-    ) -> (MemTrieChanges, Vec<(CryptoHash, Vec<u8>)>) {
+    fn to_mem_trie_changes_internal(self) -> (MemTrieChanges, Vec<(CryptoHash, Vec<u8>)>) {
         MEM_TRIE_NUM_NODES_CREATED_FROM_UPDATES
-            .with_label_values(&[&shard_uid])
-            .inc_by(updated_nodes.len() as u64);
+            .with_label_values(&[&self.shard_uid])
+            .inc_by(self.updated_nodes.len() as u64);
         let mut ordered_nodes = Vec::new();
-        Self::post_order_traverse_updated_nodes(0, &updated_nodes, &mut ordered_nodes);
+        Self::post_order_traverse_updated_nodes(0, &self.updated_nodes, &mut ordered_nodes);
 
-        let nodes_hashes_and_serialized =
-            Self::compute_hashes_and_serialized_nodes(&ordered_nodes, &updated_nodes, arena);
+        let hashes_and_serialized_nodes =
+            self.compute_hashes_and_serialized_nodes(&ordered_nodes, &self.updated_nodes);
 
-        let node_ids_with_hashes = nodes_hashes_and_serialized
+        let node_ids_with_hashes = hashes_and_serialized_nodes
             .iter()
             .map(|(node_id, hash, _)| (*node_id, *hash))
             .collect();
         (
-            MemTrieChanges { node_ids_with_hashes, updated_nodes },
-            nodes_hashes_and_serialized
+            MemTrieChanges { node_ids_with_hashes, updated_nodes: self.updated_nodes },
+            hashes_and_serialized_nodes
                 .into_iter()
                 .map(|(_, hash, serialized)| (hash, serialized))
                 .collect(),
@@ -809,19 +807,19 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> {
 
     /// Converts the updates to memtrie changes only.
     pub fn to_mem_trie_changes_only(self) -> MemTrieChanges {
-        let Self { memory: arena, updated_nodes, shard_uid, .. } = self;
-        let (mem_trie_changes, _) =
-            Self::to_mem_trie_changes_internal(shard_uid, arena, updated_nodes);
+        let (mem_trie_changes, _) = self.to_mem_trie_changes_internal();
         mem_trie_changes
     }
 
     /// Converts the updates to trie changes as well as memtrie changes.
-    pub(crate) fn to_trie_changes(self) -> (TrieChanges, TrieAccesses) {
-        let Self { root, memory: arena, shard_uid, tracked_trie_changes, updated_nodes } = self;
-        let TrieChangesTracker { mut refcount_changes, accesses } =
-            tracked_trie_changes.expect("Cannot to_trie_changes for memtrie changes only");
-        let (mem_trie_changes, hashes_and_serialized) =
-            Self::to_mem_trie_changes_internal(shard_uid, arena, updated_nodes);
+    pub(crate) fn to_trie_changes(mut self) -> (TrieChanges, TrieAccesses) {
+        let old_root =
+            self.root.map(|root| root.as_ptr(self.memory).view().node_hash()).unwrap_or_default();
+        let TrieChangesTracker { mut refcount_changes, accesses } = self
+            .tracked_trie_changes
+            .take()
+            .expect("Cannot to_trie_changes for memtrie changes only");
+        let (mem_trie_changes, hashes_and_serialized) = self.to_mem_trie_changes_internal();
 
         // We've accounted for the dereferenced nodes, as well as value addition/subtractions.
         // The only thing left is to increment refcount for all new nodes.
@@ -832,9 +830,7 @@ impl<'a, M: ArenaMemory> MemTrieUpdate<'a, M> {
 
         (
             TrieChanges {
-                old_root: root
-                    .map(|root| root.as_ptr(arena).view().node_hash())
-                    .unwrap_or_default(),
+                old_root,
                 new_root: mem_trie_changes
                     .node_ids_with_hashes
                     .last()
diff --git a/core/store/src/trie/mod.rs b/core/store/src/trie/mod.rs
index fd1929b8d15..21ca9fa1dc7 100644
--- a/core/store/src/trie/mod.rs
+++ b/core/store/src/trie/mod.rs
@@ -517,8 +517,13 @@ impl TrieRefcountDeltaMap {
     }
 }
 
+/// Changes to be applied to the in-memory trie.
+/// The result is a new state root attached to the existing persistent trie structure.
 #[derive(Default, Clone, PartialEq, Eq, Debug)]
 pub struct MemTrieChanges {
+    /// Node ids with hashes of updated nodes.
+    /// Should be ordered by the post-order traversal of the updated nodes.
+    /// It implies that the root node is the last one in the list.
     node_ids_with_hashes: Vec<(UpdatedMemTrieNodeId, CryptoHash)>,
     updated_nodes: Vec<Option<UpdatedMemTrieNode>>,
 }
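
Review note, not part of the diff: `retain_split_shard` currently passes an empty interval list, so the boundary-account handling deferred to TODO(#12074) is the main missing piece. A minimal sketch of how the `Left`/`Right` semantics could map onto a single raw-key range is below. The helper name `boundary_retain_interval` and the `0xff` upper bound are hypothetical placeholders, and the real implementation has to produce per-column intervals in nibbles rather than one range over bare account ids.

```rust
use near_primitives::types::AccountId;
use near_store::trie::mem::resharding::RetainMode;
use std::ops::Range;

/// Hypothetical helper, for illustration only: the raw-key interval one child
/// shard would retain, pretending trie keys are bare account ids.
fn boundary_retain_interval(
    boundary_account: &AccountId,
    retain_mode: RetainMode,
) -> Range<Vec<u8>> {
    let boundary = boundary_account.as_str().as_bytes().to_vec();
    match retain_mode {
        // Left child keeps everything strictly below the boundary account.
        RetainMode::Left => Vec::new()..boundary,
        // Right child keeps the boundary account and everything above it;
        // `vec![0xff; 64]` is a crude stand-in for the key-space upper bound.
        RetainMode::Right => boundary..vec![0xff; 64],
    }
}
```

The resulting range would then be converted to nibble intervals and fed to `retain_multi_range`, which is what the `_boundary_account` and `_retain_mode` parameters of `retain_split_shard` are reserved for.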
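The TODO(#12074) list in `resharding.rs` also mentions a test for multiple retain ranges. A sketch of how such a test might look, modeled directly on `test_retain_single_range` and reusing the same `mod tests` imports; the expected entries assume the intended exclusive-end semantics of the ranges.

```rust
    #[test]
    /// Sketch: retains two disjoint ranges and checks that only keys inside
    /// either range survive. Assumes the same `mod tests` imports as above.
    fn test_retain_two_ranges() {
        let initial_entries = vec![
            (b"alice".to_vec(), vec![1]),
            (b"bob".to_vec(), vec![2]),
            (b"charlie".to_vec(), vec![3]),
            (b"david".to_vec(), vec![4]),
            (b"edward".to_vec(), vec![5]),
        ];
        // Keep ["alice", "bob") and ["charlie", "edward"); range ends are exclusive.
        let retain_ranges =
            vec![b"alice".to_vec()..b"bob".to_vec(), b"charlie".to_vec()..b"edward".to_vec()];
        let retain_result = vec![
            (b"alice".to_vec(), vec![1]),
            (b"charlie".to_vec(), vec![3]),
            (b"david".to_vec(), vec![4]),
        ];

        // Build the initial trie.
        let mut memtries = MemTries::new(ShardUId::single_shard());
        let mut update = memtries.update(StateRoot::default(), false).unwrap();
        for (key, value) in initial_entries {
            update.insert(&key, value);
        }
        let memtrie_changes = update.to_mem_trie_changes_only();
        let state_root = memtries.apply_memtrie_changes(0, &memtrie_changes);

        // Retain both ranges in one pass and apply the resulting changes.
        let update = memtries.update(state_root, true).unwrap();
        let (mut trie_changes, _) = update.retain_multi_range(&retain_ranges);
        let memtrie_changes = trie_changes.mem_trie_changes.take().unwrap();
        let new_state_root = memtries.apply_memtrie_changes(1, &memtrie_changes);

        // Iterate the new root and compare against the expected entries.
        let state_root_ptr = memtries.get_root(&new_state_root).unwrap();
        let trie = Trie::new(Arc::new(TrieMemoryPartialStorage::default()), new_state_root, None);
        let entries =
            MemTrieIterator::new(Some(state_root_ptr), &trie).map(|e| e.unwrap()).collect_vec();

        assert_eq!(entries, retain_result);
    }
```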