diff --git a/src/protocol/libp2p/kademlia/config.rs b/src/protocol/libp2p/kademlia/config.rs
index f4fb5f29..b3cebfa3 100644
--- a/src/protocol/libp2p/kademlia/config.rs
+++ b/src/protocol/libp2p/kademlia/config.rs
@@ -41,7 +41,7 @@ use std::{
 const DEFAULT_TTL: Duration = Duration::from_secs(36 * 60 * 60);
 
 /// Default provider record TTL.
-const DEFAULT_PROVIDER_TTL: Duration = Duration::from_secs(48 * 60 * 60);
+pub(super) const DEFAULT_PROVIDER_TTL: Duration = Duration::from_secs(48 * 60 * 60);
 
 /// Default provider republish interval.
 pub(super) const DEFAULT_PROVIDER_REFRESH_INTERVAL: Duration = Duration::from_secs(22 * 60 * 60);
diff --git a/src/protocol/libp2p/kademlia/handle.rs b/src/protocol/libp2p/kademlia/handle.rs
index ad1080f9..15903237 100644
--- a/src/protocol/libp2p/kademlia/handle.rs
+++ b/src/protocol/libp2p/kademlia/handle.rs
@@ -19,7 +19,7 @@
 // DEALINGS IN THE SOFTWARE.
 
 use crate::{
-    protocol::libp2p::kademlia::{KademliaPeer, PeerRecord, QueryId, Record, RecordKey},
+    protocol::libp2p::kademlia::{ContentProvider, PeerRecord, QueryId, Record, RecordKey},
     PeerId,
 };
 
@@ -148,9 +148,6 @@ pub(crate) enum KademliaCommand {
         /// Provided key.
         key: RecordKey,
 
-        /// Our external addresses to publish.
-        public_addresses: Vec<Multiaddr>,
-
         /// Query ID for the query.
         query_id: QueryId,
     },
@@ -210,9 +207,12 @@ pub enum KademliaEvent {
         /// Query ID.
         query_id: QueryId,
 
+        /// Provided key.
+        provided_key: RecordKey,
+
         /// Found providers with cached addresses. Returned providers are sorted by distance
         /// to the provided key.
-        providers: Vec<KademliaPeer>,
+        providers: Vec<ContentProvider>,
     },
 
     /// `PUT_VALUE` query succeeded.
@@ -240,6 +240,15 @@ pub enum KademliaEvent {
         /// Record.
         record: Record,
     },
+
+    /// Incoming `ADD_PROVIDER` request received.
+    IncomingProvider {
+        /// Provided key.
+        provided_key: RecordKey,
+
+        /// Provider.
+        provider: ContentProvider,
+    },
 }
 
 /// The type of the DHT records.
@@ -352,20 +361,9 @@ impl KademliaHandle {
     ///
-    /// Register the local peer ID & its `public_addresses` as a provider for a given `key`.
-    /// Returns [`Err`] only if `Kademlia` is terminating.
-    pub async fn start_providing(
-        &mut self,
-        key: RecordKey,
-        public_addresses: Vec<Multiaddr>,
-    ) -> QueryId {
+    /// Register the local peer ID & its public addresses as a provider for a given `key`.
+    pub async fn start_providing(&mut self, key: RecordKey) -> QueryId {
         let query_id = self.next_query_id();
-        let _ = self
-            .cmd_tx
-            .send(KademliaCommand::StartProviding {
-                key,
-                public_addresses,
-                query_id,
-            })
-            .await;
+        let _ = self.cmd_tx.send(KademliaCommand::StartProviding { key, query_id }).await;
 
         query_id
     }
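For orientation, here is a minimal consumer-side sketch of the API after this change. It assumes an already-constructed `KademliaHandle` (named `kad_handle` here) that is polled as a stream of `KademliaEvent`s, as in litep2p's examples; the surrounding setup is omitted and the names are illustrative only:

    use futures::StreamExt;

    // Public addresses are no longer passed in; the protocol now sources them
    // from its own view of the node's external addresses.
    let _query_id = kad_handle.start_providing(key.clone()).await;

    loop {
        match kad_handle.next().await {
            Some(KademliaEvent::GetProvidersSuccess { query_id, provided_key, providers }) => {
                // `providers` are `ContentProvider`s, sorted by distance to `provided_key`.
                for provider in providers {
                    println!("{query_id:?}: {provided_key:?} -> {}", provider.peer);
                }
            }
            Some(KademliaEvent::IncomingProvider { provided_key, provider }) => {
                // A remote peer announced itself as a provider for `provided_key`.
                let _ = (provided_key, provider);
            }
            Some(_) => {}
            None => break,
        }
    }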
diff --git a/src/protocol/libp2p/kademlia/message.rs b/src/protocol/libp2p/kademlia/message.rs
index f8f4965f..3c634e54 100644
--- a/src/protocol/libp2p/kademlia/message.rs
+++ b/src/protocol/libp2p/kademlia/message.rs
@@ -20,7 +20,7 @@
 
 use crate::{
     protocol::libp2p::kademlia::{
-        record::{Key as RecordKey, ProviderRecord, Record},
+        record::{ContentProvider, Key as RecordKey, Record},
         schema,
         types::{ConnectionType, KademliaPeer},
     },
@@ -172,14 +172,14 @@ impl KademliaMessage {
     }
 
     /// Create `ADD_PROVIDER` message with `provider`.
-    pub fn add_provider(provider: ProviderRecord) -> Bytes {
+    pub fn add_provider(provided_key: RecordKey, provider: ContentProvider) -> Bytes {
         let peer = KademliaPeer::new(
-            provider.provider,
+            provider.peer,
             provider.addresses,
             ConnectionType::CanConnect, // ignored by message recipient
         );
         let message = schema::kademlia::Message {
-            key: provider.key.clone().to_vec(),
+            key: provided_key.clone().to_vec(),
             cluster_level_raw: 10,
             r#type: schema::kademlia::MessageType::AddProvider.into(),
             provider_peers: std::iter::once((&peer).into()).collect(),
@@ -209,16 +209,17 @@ impl KademliaMessage {
 
     /// Create `GET_PROVIDERS` response.
     pub fn get_providers_response(
-        providers: Vec<ProviderRecord>,
+        providers: Vec<ContentProvider>,
         closer_peers: &[KademliaPeer],
     ) -> Vec<u8> {
         let provider_peers = providers
             .into_iter()
             .map(|p| {
                 KademliaPeer::new(
-                    p.provider,
+                    p.peer,
                     p.addresses,
-                    ConnectionType::CanConnect, // ignored by recipient
+                    // `ConnectionType` is ignored by a recipient
+                    ConnectionType::NotConnected,
                 )
             })
             .map(|p| (&p).into())
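As a small illustration of the reshaped internal helpers: the provided key now travels separately from the `ContentProvider`, so the same provider value can be announced under any key. A hypothetical call site (names assumed, not taken from this patch):

    let provider = ContentProvider {
        peer: local_peer_id,
        addresses: public_addresses,
    };

    // `ADD_PROVIDER` request payload.
    let request: Bytes = KademliaMessage::add_provider(provided_key.clone(), provider.clone());

    // `GET_PROVIDERS` response payload; `closer_peers` are `KademliaPeer`s.
    let response = KademliaMessage::get_providers_response(vec![provider], &closer_peers);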
diff --git a/src/protocol/libp2p/kademlia/mod.rs b/src/protocol/libp2p/kademlia/mod.rs
index e98fa6e9..40812a3a 100644
--- a/src/protocol/libp2p/kademlia/mod.rs
+++ b/src/protocol/libp2p/kademlia/mod.rs
@@ -29,7 +29,6 @@ use crate::{
         handle::KademliaCommand,
         message::KademliaMessage,
         query::{QueryAction, QueryEngine},
-        record::ProviderRecord,
         routing_table::RoutingTable,
         store::{MemoryStore, MemoryStoreAction, MemoryStoreConfig},
         types::{ConnectionType, KademliaPeer, Key},
@@ -61,7 +60,7 @@ pub use handle::{
     IncomingRecordValidationMode, KademliaEvent, KademliaHandle, Quorum, RoutingTableUpdateMode,
 };
 pub use query::QueryId;
-pub use record::{Key as RecordKey, PeerRecord, Record};
+pub use record::{ContentProvider, Key as RecordKey, PeerRecord, Record};
 
 /// Logging target for the file.
 const LOG_TARGET: &str = "litep2p::ipfs::kademlia";
@@ -165,9 +164,6 @@ pub(crate) struct Kademlia {
     /// Default record TTL.
     record_ttl: Duration,
 
-    /// Provider record TTL.
-    provider_ttl: Duration,
-
     /// Query engine.
     engine: QueryEngine,
@@ -193,6 +189,7 @@ impl Kademlia {
             local_peer_id,
             MemoryStoreConfig {
                 provider_refresh_interval: config.provider_refresh_interval,
+                provider_ttl: config.provider_ttl,
                 ..Default::default()
             },
         );
@@ -212,7 +209,6 @@ impl Kademlia {
             update_mode: config.update_mode,
             validation_mode: config.validation_mode,
             record_ttl: config.record_ttl,
-            provider_ttl: config.provider_ttl,
             replication_factor: config.replication_factor,
             engine: QueryEngine::new(local_peer_id, config.replication_factor, PARALLELISM_FACTOR),
         }
@@ -523,7 +519,7 @@ impl Kademlia {
                     ),
                 }
             }
-            KademliaMessage::AddProvider { key, providers } => {
+            KademliaMessage::AddProvider { key, mut providers } => {
                 tracing::trace!(
                     target: LOG_TARGET,
                     ?peer,
@@ -532,15 +528,27 @@
                     "handle `ADD_PROVIDER` message",
                 );
 
-                match (providers.len(), providers.first()) {
+                match (providers.len(), providers.pop()) {
                     (1, Some(provider)) =>
                         if provider.peer == peer {
-                            self.store.put_provider(ProviderRecord {
-                                key,
-                                provider: peer,
-                                addresses: provider.addresses.clone(),
-                                expires: Instant::now() + self.provider_ttl,
-                            });
+                            self.store.put_provider(
+                                key.clone(),
+                                ContentProvider {
+                                    peer,
+                                    addresses: provider.addresses.clone(),
+                                },
+                            );
+
+                            let _ = self
+                                .event_tx
+                                .send(KademliaEvent::IncomingProvider {
+                                    provided_key: key,
+                                    provider: ContentProvider {
+                                        peer: provider.peer,
+                                        addresses: provider.addresses,
+                                    },
+                                })
+                                .await;
                         } else {
                             tracing::trace!(
                                 target: LOG_TARGET,
@@ -590,10 +598,13 @@
                     "handle `GET_PROVIDERS` request",
                 );
 
-                let providers = self.store.get_providers(key);
-                // TODO: if local peer is among the providers, update its `ProviderRecord`
-                // to have up-to-date addresses.
-                // Requires https://github.com/paritytech/litep2p/issues/211.
+                let mut providers = self.store.get_providers(key);
+
+                // Make sure local provider addresses are up to date.
+                let local_peer_id = self.local_key.clone().into_preimage();
+                providers.iter_mut().find(|p| p.peer == local_peer_id).as_mut().map(|p| {
+                    p.addresses = self.service.public_addresses().get_addresses();
+                });
 
                 let closer_peers = self
                     .routing_table
@@ -787,16 +798,19 @@ impl Kademlia {
                 Ok(())
             }
-            QueryAction::AddProviderToFoundNodes { provider, peers } => {
+            QueryAction::AddProviderToFoundNodes {
+                provided_key,
+                provider,
+                peers,
+            } => {
                 tracing::trace!(
                     target: LOG_TARGET,
-                    provided_key = ?provider.key,
+                    ?provided_key,
                     num_peers = ?peers.len(),
                     "add provider record to found peers",
                 );
 
-                let provided_key = provider.key.clone();
-                let message = KademliaMessage::add_provider(provider);
+                let message = KademliaMessage::add_provider(provided_key.clone(), provider);
 
                 for peer in peers {
                     if let Err(error) = self.open_substream_or_dial(
@@ -828,12 +842,14 @@
             }
             QueryAction::GetProvidersQueryDone {
                 query_id,
+                provided_key,
                 providers,
             } => {
                 let _ = self
                     .event_tx
                     .send(KademliaEvent::GetProvidersSuccess {
                         query_id,
+                        provided_key,
                         providers,
                     })
                     .await;
@@ -1036,28 +1052,26 @@
                 }
                 Some(KademliaCommand::StartProviding {
                     key,
-                    public_addresses,
                     query_id
                 }) => {
                     tracing::debug!(
                         target: LOG_TARGET,
                         query = ?query_id,
                         ?key,
-                        ?public_addresses,
                         "register as a content provider",
                     );
 
-                    let provider = ProviderRecord {
-                        key: key.clone(),
-                        provider: self.service.local_peer_id(),
-                        addresses: public_addresses,
-                        expires: Instant::now() + self.provider_ttl,
+                    let addresses = self.service.public_addresses().get_addresses();
+                    let provider = ContentProvider {
+                        peer: self.service.local_peer_id(),
+                        addresses,
                     };
 
-                    self.store.put_provider(provider.clone());
+                    self.store.put_provider(key.clone(), provider.clone());
 
                     self.engine.start_add_provider(
                         query_id,
+                        key.clone(),
                         provider,
                         self.routing_table
                             .closest(Key::new(key), self.replication_factor)
                             .into(),
                     );
                 }
@@ -1105,12 +1119,15 @@
                 Some(KademliaCommand::GetProviders { key, query_id }) => {
                     tracing::debug!(target: LOG_TARGET, ?key, "get providers from DHT");
 
+                    let known_providers = self.store.get_providers(&key);
+
                     self.engine.start_get_providers(
                         query_id,
                         key.clone(),
                         self.routing_table
                             .closest(Key::new(key), self.replication_factor)
                             .into(),
+                        known_providers,
                     );
                 }
                 Some(KademliaCommand::AddKnownPeer { peer, addresses }) => {
@@ -1151,25 +1168,24 @@
                 }
             },
             action = self.store.next_action() => match action {
-                Some(MemoryStoreAction::RefreshProvider { mut provider }) => {
+                Some(MemoryStoreAction::RefreshProvider { provided_key, provider }) => {
                     tracing::trace!(
                         target: LOG_TARGET,
-                        key = ?provider.key,
+                        ?provided_key,
                         "republishing local provider",
                     );
 
-                    // Make sure to roll expiration time.
-                    provider.expires = Instant::now() + self.provider_ttl;
-
-                    self.store.put_provider(provider.clone());
+                    self.store.put_provider(provided_key.clone(), provider.clone());
+                    // We never update local provider addresses in the store when refreshing the
+                    // provider, as this is done anyway when replying to a `GET_PROVIDERS` request.
 
-                    let key = provider.key.clone();
                     let query_id = self.next_query_id();
 
                     self.engine.start_add_provider(
                         query_id,
+                        provided_key.clone(),
                         provider,
                         self.routing_table
-                            .closest(Key::new(key), self.replication_factor)
+                            .closest(Key::new(provided_key), self.replication_factor)
                             .into(),
                     );
                 }
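One readability note on the `GET_PROVIDERS` handler above: the `find(..).as_mut().map(..)` chain is used purely for its side effect. An equivalent formulation of the same logic with `if let`, shown here only as a sketch:

    if let Some(local_provider) = providers.iter_mut().find(|p| p.peer == local_peer_id) {
        local_provider.addresses = self.service.public_addresses().get_addresses();
    }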
diff --git a/src/protocol/libp2p/kademlia/query/get_providers.rs b/src/protocol/libp2p/kademlia/query/get_providers.rs
index d679265c..b2ae19c2 100644
--- a/src/protocol/libp2p/kademlia/query/get_providers.rs
+++ b/src/protocol/libp2p/kademlia/query/get_providers.rs
@@ -24,8 +24,8 @@ use crate::{
     protocol::libp2p::kademlia::{
         message::KademliaMessage,
         query::{QueryAction, QueryId},
-        record::Key as RecordKey,
-        types::{ConnectionType, Distance, KademliaPeer, Key},
+        record::{ContentProvider, Key as RecordKey},
+        types::{Distance, KademliaPeer, Key},
     },
     types::multiaddr::Multiaddr,
     PeerId,
@@ -50,6 +50,9 @@ pub struct GetProvidersConfig {
 
     /// Target key.
     pub target: Key<RecordKey>,
+
+    /// Known providers from the local store.
+    pub known_providers: Vec<KademliaPeer>,
 }
 
 #[derive(Debug)]
@@ -100,30 +103,35 @@ impl GetProvidersContext {
     }
 
     /// Get the found providers.
-    pub fn found_providers(self) -> Vec<KademliaPeer> {
+    pub fn found_providers(self) -> Vec<ContentProvider> {
+        Self::merge_and_sort_providers(
+            self.config.known_providers.into_iter().chain(self.found_providers),
+            self.config.target,
+        )
+    }
+
+    fn merge_and_sort_providers(
+        found_providers: impl IntoIterator<Item = KademliaPeer>,
+        target: Key<RecordKey>,
+    ) -> Vec<ContentProvider> {
         // Merge addresses of different provider records of the same peer.
         let mut providers = HashMap::<PeerId, HashSet<Multiaddr>>::new();
-        self.found_providers.into_iter().for_each(|provider| {
-            providers
-                .entry(provider.peer)
-                .or_default()
-                .extend(provider.addresses.into_iter())
+        found_providers.into_iter().for_each(|provider| {
+            providers.entry(provider.peer).or_default().extend(provider.addresses)
         });
 
         // Convert into `Vec<ContentProvider>`.
         let mut providers = providers
             .into_iter()
-            .map(|(peer, addresses)| KademliaPeer {
-                key: Key::from(peer.clone()),
+            .map(|(peer, addresses)| ContentProvider {
                 peer,
                 addresses: addresses.into_iter().collect(),
-                connection: ConnectionType::NotConnected,
             })
             .collect::<Vec<_>>();
 
         // Sort by the provider distance to the target key.
         providers.sort_unstable_by(|p1, p2| {
-            p1.key.distance(&self.config.target).cmp(&p2.key.distance(&self.config.target))
+            Key::from(p1.peer).distance(&target).cmp(&Key::from(p2.peer).distance(&target))
         });
 
         providers
@@ -168,7 +176,7 @@ impl GetProvidersContext {
             return;
         };
 
-        self.found_providers.extend(providers.into_iter());
+        self.found_providers.extend(providers);
 
         // Add the queried peer to `queried` and all new peers which haven't been
         // queried to `candidates`
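To make the merge-and-sort semantics concrete, here is a self-contained model of what `merge_and_sort_providers` does. The `Provider` type and 4-byte peer IDs below are simplified stand-ins for the crate's types, not its actual API:

    use std::collections::{HashMap, HashSet};

    #[derive(Debug)]
    struct Provider {
        peer: [u8; 4],
        addresses: Vec<String>,
    }

    /// XOR distance between two 4-byte identifiers, as an integer.
    fn xor_distance(a: &[u8; 4], b: &[u8; 4]) -> u32 {
        u32::from_be_bytes([a[0] ^ b[0], a[1] ^ b[1], a[2] ^ b[2], a[3] ^ b[3]])
    }

    fn merge_and_sort(found: Vec<Provider>, target: [u8; 4]) -> Vec<Provider> {
        // Merge addresses of provider records that refer to the same peer,
        // deduplicating addresses via the `HashSet`.
        let mut merged = HashMap::<[u8; 4], HashSet<String>>::new();
        for provider in found {
            merged.entry(provider.peer).or_default().extend(provider.addresses);
        }

        // Rebuild provider entries and sort them by distance to the target
        // key, closest first.
        let mut providers: Vec<_> = merged
            .into_iter()
            .map(|(peer, addresses)| Provider {
                peer,
                addresses: addresses.into_iter().collect(),
            })
            .collect();
        providers.sort_unstable_by_key(|p| xor_distance(&p.peer, &target));
        providers
    }

The real implementation additionally chains `known_providers` from the local store in front of the network results before merging, which is why locally known addresses survive into the query result.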
diff --git a/src/protocol/libp2p/kademlia/query/mod.rs b/src/protocol/libp2p/kademlia/query/mod.rs
index 8a247e45..b933ec5b 100644
--- a/src/protocol/libp2p/kademlia/query/mod.rs
+++ b/src/protocol/libp2p/kademlia/query/mod.rs
@@ -26,7 +26,7 @@ use crate::{
             get_providers::{GetProvidersConfig, GetProvidersContext},
             get_record::{GetRecordConfig, GetRecordContext},
         },
-        record::{Key as RecordKey, ProviderRecord, Record},
+        record::{ContentProvider, Key as RecordKey, Record},
         types::{KademliaPeer, Key},
         PeerRecord, Quorum,
     },
@@ -86,8 +86,11 @@ enum QueryType {
 
     /// `ADD_PROVIDER` query.
     AddProvider {
+        /// Provided key.
+        provided_key: RecordKey,
+
         /// Provider record that needs to be stored.
-        provider: ProviderRecord,
+        provider: ContentProvider,
 
         /// Context for the `FIND_NODE` query.
         context: FindNodeContext,
@@ -139,8 +142,11 @@ pub enum QueryAction {
 
     /// Add the provider record to nodes closest to the target key.
     AddProviderToFoundNodes {
+        /// Provided key.
+        provided_key: RecordKey,
+
         /// Provider record.
-        provider: ProviderRecord,
+        provider: ContentProvider,
 
         /// Peers to whom the `ADD_PROVIDER` message must be sent.
         peers: Vec<KademliaPeer>,
@@ -160,8 +166,11 @@ pub enum QueryAction {
 
         /// Query ID.
         query_id: QueryId,
 
+        /// Provided key.
+        provided_key: RecordKey,
+
         /// Found providers.
-        providers: Vec<KademliaPeer>,
+        providers: Vec<ContentProvider>,
     },
 
     /// Query succeeded.
@@ -344,7 +353,8 @@ impl QueryEngine {
     pub fn start_add_provider(
         &mut self,
         query_id: QueryId,
-        provider: ProviderRecord,
+        provided_key: RecordKey,
+        provider: ContentProvider,
         candidates: VecDeque<KademliaPeer>,
     ) -> QueryId {
         tracing::debug!(
@@ -355,18 +365,18 @@
             "start `ADD_PROVIDER` query",
         );
 
-        let target = Key::new(provider.key.clone());
         let config = FindNodeConfig {
             local_peer_id: self.local_peer_id,
             replication_factor: self.replication_factor,
             parallelism_factor: self.parallelism_factor,
             query: query_id,
-            target,
+            target: Key::new(provided_key.clone()),
         };
 
         self.queries.insert(
             query_id,
             QueryType::AddProvider {
+                provided_key,
                 provider,
                 context: FindNodeContext::new(config, candidates),
             },
@@ -381,6 +391,7 @@
         query_id: QueryId,
         key: RecordKey,
         candidates: VecDeque<KademliaPeer>,
+        known_providers: Vec<ContentProvider>,
     ) -> QueryId {
         tracing::debug!(
             target: LOG_TARGET,
@@ -396,6 +407,7 @@
             parallelism_factor: self.parallelism_factor,
             query: query_id,
             target,
+            known_providers: known_providers.into_iter().map(Into::into).collect(),
         };
 
         self.queries.insert(
@@ -527,12 +539,18 @@
                 query_id: context.config.query,
                 records: context.found_records(),
             },
-            QueryType::AddProvider { provider, context } => QueryAction::AddProviderToFoundNodes {
+            QueryType::AddProvider {
+                provided_key,
+                provider,
+                context,
+            } => QueryAction::AddProviderToFoundNodes {
+                provided_key,
                 provider,
                 peers: context.responses.into_values().collect::<Vec<_>>(),
             },
             QueryType::GetProviders { context } => QueryAction::GetProvidersQueryDone {
                 query_id: context.config.query,
+                provided_key: context.config.target.clone().into_preimage(),
                 providers: context.found_providers(),
            },
         }
diff --git a/src/protocol/libp2p/kademlia/record.rs b/src/protocol/libp2p/kademlia/record.rs
index 0c23764a..f8661d20 100644
--- a/src/protocol/libp2p/kademlia/record.rs
+++ b/src/protocol/libp2p/kademlia/record.rs
@@ -20,7 +20,9 @@
 // DEALINGS IN THE SOFTWARE.
 
 use crate::{
-    protocol::libp2p::kademlia::types::{Distance, Key as KademliaKey},
+    protocol::libp2p::kademlia::types::{
+        ConnectionType, Distance, KademliaPeer, Key as KademliaKey,
+    },
     Multiaddr, PeerId,
 };
 
@@ -152,3 +154,24 @@ impl ProviderRecord {
         now >= self.expires
     }
 }
+
+/// A user-facing provider type.
+#[derive(Clone, Debug, Eq, PartialEq, Hash)]
+pub struct ContentProvider {
+    /// Peer ID of the provider.
+    pub peer: PeerId,
+
+    /// Cached addresses of the provider.
+    pub addresses: Vec<Multiaddr>,
+}
+
+impl From<ContentProvider> for KademliaPeer {
+    fn from(provider: ContentProvider) -> Self {
+        Self {
+            key: KademliaKey::from(provider.peer),
+            peer: provider.peer,
+            addresses: provider.addresses,
+            connection: ConnectionType::NotConnected,
+        }
+    }
+}
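The `From` impl exists so that providers already known to the local store can be handed to the query engine, which traffics in `KademliaPeer`s. A short usage sketch, assuming some `ContentProvider` value named `provider`:

    let peer: KademliaPeer = provider.clone().into();
    assert_eq!(peer.peer, provider.peer);
    assert!(matches!(peer.connection, ConnectionType::NotConnected));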
diff --git a/src/protocol/libp2p/kademlia/store.rs b/src/protocol/libp2p/kademlia/store.rs
index 8d64f6db..6453bdc8 100644
--- a/src/protocol/libp2p/kademlia/store.rs
+++ b/src/protocol/libp2p/kademlia/store.rs
@@ -22,9 +22,9 @@
 
 use crate::{
     protocol::libp2p::kademlia::{
-        config::DEFAULT_PROVIDER_REFRESH_INTERVAL,
+        config::{DEFAULT_PROVIDER_REFRESH_INTERVAL, DEFAULT_PROVIDER_TTL},
         futures_stream::FuturesStream,
-        record::{Key, ProviderRecord, Record},
+        record::{ContentProvider, Key, ProviderRecord, Record},
         types::Key as KademliaKey,
     },
     PeerId,
@@ -41,7 +41,10 @@ const LOG_TARGET: &str = "litep2p::ipfs::kademlia::store";
 
 /// Memory store events.
 pub enum MemoryStoreAction {
-    RefreshProvider { provider: ProviderRecord },
+    RefreshProvider {
+        provided_key: Key,
+        provider: ContentProvider,
+    },
 }
 
 /// Memory store.
@@ -55,7 +58,7 @@ pub struct MemoryStore {
     /// Provider records.
     provider_keys: HashMap<Key, Vec<ProviderRecord>>,
     /// Local providers.
-    local_providers: HashMap<Key, ProviderRecord>,
+    local_providers: HashMap<Key, ContentProvider>,
     /// Futures to signal it's time to republish a local provider.
     pending_provider_refresh: FuturesStream<BoxFuture<'static, Key>>,
 }
@@ -148,7 +151,7 @@ impl MemoryStore {
     /// Try to get providers from local store for `key`.
     ///
     /// Returns a non-empty list of providers, if any.
-    pub fn get_providers(&mut self, key: &Key) -> Vec<ProviderRecord> {
+    pub fn get_providers(&mut self, key: &Key) -> Vec<ContentProvider> {
         let drop_key = self.provider_keys.get_mut(key).map_or(false, |providers| {
             let now = std::time::Instant::now();
             providers.retain(|p| !p.is_expired(now));
@@ -161,7 +164,16 @@
 
             Vec::default()
         } else {
-            self.provider_keys.get(key).cloned().unwrap_or_else(Vec::default)
+            self.provider_keys
+                .get(key)
+                .cloned()
+                .unwrap_or_default()
+                .into_iter()
+                .map(|p| ContentProvider {
+                    peer: p.provider,
+                    addresses: p.addresses,
+                })
+                .collect()
         }
     }
 
@@ -170,12 +182,18 @@
     /// the furthest already inserted provider. The furthest provider is then discarded.
     ///
    /// Returns `true` if the provider was added, `false` otherwise.
-    pub fn put_provider(&mut self, provider_record: ProviderRecord) -> bool {
+    pub fn put_provider(&mut self, key: Key, provider: ContentProvider) -> bool {
         // Helper to schedule local provider refresh.
         let mut schedule_local_provider_refresh = |provider_record: ProviderRecord| {
             let key = provider_record.key.clone();
             let refresh_interval = self.config.provider_refresh_interval;
-            self.local_providers.insert(key.clone(), provider_record);
+            self.local_providers.insert(
+                key.clone(),
+                ContentProvider {
+                    peer: provider_record.provider,
+                    addresses: provider_record.addresses,
+                },
+            );
             self.pending_provider_refresh.push(Box::pin(async move {
                 tokio::time::sleep(refresh_interval).await;
                 key
            }));
         };
 
         // Make sure we have no more than `max_provider_addresses`.
         let provider_record = {
-            let mut record = provider_record;
+            let mut record = ProviderRecord {
+                key,
+                provider: provider.peer,
+                addresses: provider.addresses,
+                expires: std::time::Instant::now() + self.config.provider_ttl,
+            };
             record.addresses.truncate(self.config.max_provider_addresses);
             record
         };
@@ -272,14 +295,13 @@
             Entry::Vacant(_) => {
                 tracing::error!(?key, "local provider key not found during removal",);
                 debug_assert!(false);
-                return;
             }
             Entry::Occupied(mut entry) => {
                 let providers = entry.get_mut();
 
                 // Providers are sorted by distance.
-                let local_provider_distance = KademliaKey::from(self.local_peer_id.clone())
-                    .distance(&KademliaKey::new(key.clone()));
+                let local_provider_distance =
+                    KademliaKey::from(self.local_peer_id).distance(&KademliaKey::new(key.clone()));
 
                 let provider_position =
                     providers.binary_search_by(|p| p.distance().cmp(&local_provider_distance));
@@ -303,30 +325,29 @@
 
     /// Poll next action from the store.
     pub async fn next_action(&mut self) -> Option<MemoryStoreAction> {
-        // [`FuturesStream`] never terminates, so `map()` below is always triggered.
-        self.pending_provider_refresh
-            .next()
-            .await
-            .map(|key| {
-                if let Some(provider) = self.local_providers.get(&key).cloned() {
-                    tracing::trace!(
-                        target: LOG_TARGET,
-                        ?key,
-                        "refresh provider"
-                    );
-
-                    Some(MemoryStoreAction::RefreshProvider { provider })
-                } else {
-                    tracing::trace!(
-                        target: LOG_TARGET,
-                        ?key,
-                        "it's time to refresh a provider, but we do not provide this key anymore",
-                    );
-
-                    None
-                }
-            })
-            .flatten()
+        // [`FuturesStream`] never terminates, so `and_then()` below is always triggered.
+        self.pending_provider_refresh.next().await.and_then(|key| {
+            if let Some(provider) = self.local_providers.get(&key).cloned() {
+                tracing::trace!(
+                    target: LOG_TARGET,
+                    ?key,
+                    "refresh provider"
+                );
+
+                Some(MemoryStoreAction::RefreshProvider {
+                    provided_key: key,
+                    provider,
+                })
+            } else {
+                tracing::trace!(
+                    target: LOG_TARGET,
+                    ?key,
+                    "it's time to refresh a provider, but we do not provide this key anymore",
+                );
+
+                None
+            }
+        })
     }
 }
 
@@ -349,6 +370,9 @@ pub struct MemoryStoreConfig {
 
     /// Local providers republish interval.
     pub provider_refresh_interval: Duration,
+
+    /// Provider record TTL.
+    pub provider_ttl: Duration,
 }
 
 impl Default for MemoryStoreConfig {
@@ -360,6 +384,7 @@
             max_provider_addresses: 30,
             max_providers_per_key: 20,
             provider_refresh_interval: DEFAULT_PROVIDER_REFRESH_INTERVAL,
+            provider_ttl: DEFAULT_PROVIDER_TTL,
         }
     }
 }
@@ -367,7 +392,7 @@
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::PeerId;
+    use crate::{protocol::libp2p::kademlia::types::Key as KademliaKey, PeerId};
     use multiaddr::multiaddr;
 
     #[test]
@@ -468,36 +493,31 @@
     #[test]
     fn put_get_provider() {
         let mut store = MemoryStore::new(PeerId::random());
-        let provider = ProviderRecord {
-            key: Key::from(vec![1, 2, 3]),
-            provider: PeerId::random(),
+        let key = Key::from(vec![1, 2, 3]);
+        let provider = ContentProvider {
+            peer: PeerId::random(),
             addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10000u16))],
-            expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
         };
 
-        store.put_provider(provider.clone());
-        assert_eq!(store.get_providers(&provider.key), vec![provider]);
+        store.put_provider(key.clone(), provider.clone());
+        assert_eq!(store.get_providers(&key), vec![provider]);
     }
 
     #[test]
     fn multiple_providers_per_key() {
         let mut store = MemoryStore::new(PeerId::random());
         let key = Key::from(vec![1, 2, 3]);
-        let provider1 = ProviderRecord {
-            key: key.clone(),
-            provider: PeerId::random(),
+        let provider1 = ContentProvider {
+            peer: PeerId::random(),
             addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10000u16))],
-            expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
         };
-        let provider2 = ProviderRecord {
-            key: key.clone(),
-            provider: PeerId::random(),
+        let provider2 = ContentProvider {
+            peer: PeerId::random(),
             addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10000u16))],
-            expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
         };
 
-        store.put_provider(provider1.clone());
-        store.put_provider(provider2.clone());
+        store.put_provider(key.clone(), provider1.clone());
+        store.put_provider(key.clone(), provider2.clone());
 
         let got_providers = store.get_providers(&key);
         assert_eq!(got_providers.len(), 2);
@@ -510,21 +530,24 @@
         let mut store = MemoryStore::new(PeerId::random());
         let key = Key::from(vec![1, 2, 3]);
         let providers = (0..10)
-            .map(|_| ProviderRecord {
-                key: key.clone(),
-                provider: PeerId::random(),
+            .map(|_| ContentProvider {
+                peer: PeerId::random(),
                 addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10000u16))],
-                expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
             })
             .collect::<Vec<_>>();
 
         providers.iter().for_each(|p| {
-            store.put_provider(p.clone());
+            store.put_provider(key.clone(), p.clone());
         });
 
         let sorted_providers = {
+            let target = KademliaKey::new(key.clone());
             let mut providers = providers;
-            providers.sort_unstable_by_key(ProviderRecord::distance);
+            providers.sort_by(|p1, p2| {
+                KademliaKey::from(p1.peer.clone())
+                    .distance(&target)
+                    .cmp(&KademliaKey::from(p2.peer.clone()).distance(&target))
+            });
             providers
         };
@@ -542,16 +565,14 @@
         );
         let key = Key::from(vec![1, 2, 3]);
         let providers = (0..20)
-            .map(|_| ProviderRecord {
-                key: key.clone(),
-                provider: PeerId::random(),
+            .map(|_| ContentProvider {
+                peer: PeerId::random(),
                 addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10000u16))],
-                expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
             })
             .collect::<Vec<_>>();
 
         providers.iter().for_each(|p| {
-            store.put_provider(p.clone());
+            store.put_provider(key.clone(), p.clone());
         });
 
         assert_eq!(store.get_providers(&key).len(), 10);
     }
@@ -567,21 +588,24 @@
         );
         let key = Key::from(vec![1, 2, 3]);
         let providers = (0..20)
-            .map(|_| ProviderRecord {
-                key: key.clone(),
-                provider: PeerId::random(),
+            .map(|_| ContentProvider {
+                peer: PeerId::random(),
                 addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10000u16))],
-                expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
             })
            .collect::<Vec<_>>();
 
         providers.iter().for_each(|p| {
-            store.put_provider(p.clone());
+            store.put_provider(key.clone(), p.clone());
         });
 
         let closest_providers = {
+            let target = KademliaKey::new(key.clone());
             let mut providers = providers;
-            providers.sort_unstable_by_key(ProviderRecord::distance);
+            providers.sort_by(|p1, p2| {
+                KademliaKey::from(p1.peer.clone())
+                    .distance(&target)
+                    .cmp(&KademliaKey::from(p2.peer.clone()).distance(&target))
+            });
             providers.truncate(10);
             providers
         };
@@ -600,28 +624,31 @@
         );
         let key = Key::from(vec![1, 2, 3]);
         let providers = (0..11)
-            .map(|_| ProviderRecord {
-                key: key.clone(),
-                provider: PeerId::random(),
+            .map(|_| ContentProvider {
+                peer: PeerId::random(),
                 addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10000u16))],
-                expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
             })
             .collect::<Vec<_>>();
 
         let sorted_providers = {
+            let target = KademliaKey::new(key.clone());
             let mut providers = providers;
-            providers.sort_unstable_by_key(ProviderRecord::distance);
+            providers.sort_by(|p1, p2| {
+                KademliaKey::from(p1.peer.clone())
+                    .distance(&target)
+                    .cmp(&KademliaKey::from(p2.peer.clone()).distance(&target))
+            });
             providers
         };
 
         // First 10 providers are inserted.
         for i in 0..10 {
-            assert!(store.put_provider(sorted_providers[i].clone()));
+            assert!(store.put_provider(key.clone(), sorted_providers[i].clone()));
         }
         assert_eq!(store.get_providers(&key), sorted_providers[..10]);
 
         // The furthest provider doesn't fit.
-        assert!(!store.put_provider(sorted_providers[10].clone()));
+        assert!(!store.put_provider(key.clone(), sorted_providers[10].clone()));
         assert_eq!(store.get_providers(&key), sorted_providers[..10]);
     }
@@ -639,40 +666,41 @@
         let peer_id0 = peer_ids[0];
         let providers = peer_ids
             .iter()
-            .map(|peer_id| ProviderRecord {
-                key: key.clone(),
-                provider: *peer_id,
+            .map(|peer_id| ContentProvider {
+                peer: *peer_id,
                 addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10000u16))],
-                expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
             })
             .collect::<Vec<_>>();
 
         providers.iter().for_each(|p| {
-            store.put_provider(p.clone());
+            store.put_provider(key.clone(), p.clone());
         });
 
         let sorted_providers = {
+            let target = KademliaKey::new(key.clone());
             let mut providers = providers;
-            providers.sort_unstable_by_key(ProviderRecord::distance);
+            providers.sort_by(|p1, p2| {
+                KademliaKey::from(p1.peer.clone())
+                    .distance(&target)
+                    .cmp(&KademliaKey::from(p2.peer.clone()).distance(&target))
+            });
             providers
         };
 
         assert_eq!(store.get_providers(&key), sorted_providers);
 
-        let provider0_new = ProviderRecord {
-            key: key.clone(),
-            provider: peer_id0,
+        let provider0_new = ContentProvider {
+            peer: peer_id0,
             addresses: vec![multiaddr!(Ip4([192, 168, 0, 1]), Tcp(20000u16))],
-            expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
         };
 
         // Provider is updated in place.
-        assert!(store.put_provider(provider0_new.clone()));
+        assert!(store.put_provider(key.clone(), provider0_new.clone()));
 
         let providers_new = sorted_providers
             .into_iter()
             .map(|p| {
-                if p.provider == peer_id0 {
+                if p.peer == peer_id0 {
                     provider0_new.clone()
                 } else {
                     p
                 }
             })
            .collect::<Vec<_>>();
 
         assert_eq!(store.get_providers(&key), providers_new);
     }
 
-    #[test]
-    fn provider_record_expires() {
-        let mut store = MemoryStore::new(PeerId::random());
-        let provider = ProviderRecord {
-            key: Key::from(vec![1, 2, 3]),
-            provider: PeerId::random(),
+    #[tokio::test]
+    async fn provider_record_expires() {
+        let mut store = MemoryStore::with_config(
+            PeerId::random(),
+            MemoryStoreConfig {
+                provider_ttl: std::time::Duration::from_secs(1),
+                ..Default::default()
+            },
+        );
+        let key = Key::from(vec![1, 2, 3]);
+        let provider = ContentProvider {
+            peer: PeerId::random(),
             addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10000u16))],
-            expires: std::time::Instant::now() - std::time::Duration::from_secs(5),
         };
 
-        // Provider record is already expired.
-        assert!(provider.is_expired(std::time::Instant::now()));
+        store.put_provider(key.clone(), provider.clone());
+
+        // Provider does not instantly expire.
+        assert_eq!(store.get_providers(&key), vec![provider]);
 
-        store.put_provider(provider.clone());
-        assert!(store.get_providers(&provider.key).is_empty());
+        // Provider expires after 2 seconds.
+        tokio::time::sleep(Duration::from_secs(2)).await;
+        assert_eq!(store.get_providers(&key), vec![]);
     }
 
-    #[test]
-    fn individual_provider_record_expires() {
-        let mut store = MemoryStore::new(PeerId::random());
+    #[tokio::test]
+    async fn individual_provider_record_expires() {
+        let mut store = MemoryStore::with_config(
+            PeerId::random(),
+            MemoryStoreConfig {
+                provider_ttl: std::time::Duration::from_secs(8),
+                ..Default::default()
+            },
+        );
         let key = Key::from(vec![1, 2, 3]);
-        let provider1 = ProviderRecord {
-            key: key.clone(),
-            provider: PeerId::random(),
+        let provider1 = ContentProvider {
+            peer: PeerId::random(),
             addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10000u16))],
-            expires: std::time::Instant::now() - std::time::Duration::from_secs(5),
         };
-        let provider2 = ProviderRecord {
-            key: key.clone(),
-            provider: PeerId::random(),
+        let provider2 = ContentProvider {
+            peer: PeerId::random(),
             addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10000u16))],
-            expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
         };
 
-        assert!(provider1.is_expired(std::time::Instant::now()));
+        store.put_provider(key.clone(), provider1.clone());
+        tokio::time::sleep(Duration::from_secs(4)).await;
+        store.put_provider(key.clone(), provider2.clone());
 
-        store.put_provider(provider1.clone());
-        store.put_provider(provider2.clone());
+        // Providers do not instantly expire.
+        let got_providers = store.get_providers(&key);
+        assert_eq!(got_providers.len(), 2);
+        assert!(got_providers.contains(&provider1));
+        assert!(got_providers.contains(&provider2));
 
+        // First provider expires.
+        tokio::time::sleep(Duration::from_secs(6)).await;
         assert_eq!(store.get_providers(&key), vec![provider2]);
+
+        // Second provider expires.
+        tokio::time::sleep(Duration::from_secs(4)).await;
+        assert_eq!(store.get_providers(&key), vec![]);
     }
@@ -735,9 +784,8 @@
             },
         );
         let key = Key::from(vec![1, 2, 3]);
-        let provider = ProviderRecord {
-            key: Key::from(vec![1, 2, 3]),
-            provider: PeerId::random(),
+        let provider = ContentProvider {
+            peer: PeerId::random(),
             addresses: vec![
                 multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10000u16)),
                 multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10001u16)),
                 multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10002u16)),
                 multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10003u16)),
                 multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10004u16)),
             ],
-            expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
         };
 
-        store.put_provider(provider);
+        store.put_provider(key.clone(), provider);
 
         let got_providers = store.get_providers(&key);
         assert_eq!(got_providers.len(), 1);
-        assert_eq!(got_providers.first().unwrap().key, key);
         assert_eq!(got_providers.first().unwrap().addresses.len(), 2);
     }
@@ -766,32 +812,29 @@
             },
         );
 
-        let provider1 = ProviderRecord {
-            key: Key::from(vec![1, 2, 3]),
-            provider: PeerId::random(),
+        let key1 = Key::from(vec![1, 1, 1]);
+        let provider1 = ContentProvider {
+            peer: PeerId::random(),
             addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10001u16))],
-            expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
         };
-        let provider2 = ProviderRecord {
-            key: Key::from(vec![4, 5, 6]),
-            provider: PeerId::random(),
+        let key2 = Key::from(vec![2, 2, 2]);
+        let provider2 = ContentProvider {
+            peer: PeerId::random(),
             addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10002u16))],
-            expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
         };
-        let provider3 = ProviderRecord {
-            key: Key::from(vec![7, 8, 9]),
-            provider: PeerId::random(),
+        let key3 = Key::from(vec![3, 3, 3]);
+        let provider3 = ContentProvider {
+            peer: PeerId::random(),
             addresses: vec![multiaddr!(Ip4([127, 0, 0, 1]), Tcp(10003u16))],
-            expires: std::time::Instant::now() + std::time::Duration::from_secs(3600),
         };
 
-        assert!(store.put_provider(provider1.clone()));
-        assert!(store.put_provider(provider2.clone()));
-        assert!(!store.put_provider(provider3.clone()));
+        assert!(store.put_provider(key1.clone(), provider1.clone()));
+        assert!(store.put_provider(key2.clone(), provider2.clone()));
+        assert!(!store.put_provider(key3.clone(), provider3.clone()));
 
-        assert_eq!(store.get_providers(&provider1.key), vec![provider1]);
-        assert_eq!(store.get_providers(&provider2.key), vec![provider2]);
-        assert_eq!(store.get_providers(&provider3.key), vec![]);
+        assert_eq!(store.get_providers(&key1), vec![provider1]);
+        assert_eq!(store.get_providers(&key2), vec![provider2]);
+        assert_eq!(store.get_providers(&key3), vec![]);
     }
 
     // TODO: test local providers.
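Taken together, expiry bookkeeping now lives entirely inside the store: callers hand over only the key and a `ContentProvider`, and the TTL is stamped on from `MemoryStoreConfig`. A minimal sketch of that contract (values illustrative):

    let mut store = MemoryStore::with_config(
        PeerId::random(),
        MemoryStoreConfig {
            // Matches `DEFAULT_PROVIDER_TTL`; any `Duration` works here.
            provider_ttl: Duration::from_secs(48 * 60 * 60),
            ..Default::default()
        },
    );

    let key = Key::from(vec![1, 2, 3]);
    let provider = ContentProvider {
        peer: PeerId::random(),
        addresses: Vec::new(),
    };

    // The store computes `expires = now + provider_ttl` internally.
    assert!(store.put_provider(key.clone(), provider.clone()));
    assert_eq!(store.get_providers(&key), vec![provider]);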