From 608c8adfd575ffb26d4485ebd74f42aabc64c6ad Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Thu, 19 Jan 2023 04:41:02 +0000 Subject: [PATCH 1/7] Update the lightning graph snapshot used in benchmarks The previous copy was more than one and a half years old, the lightning network has changed a lot since! As of this commit, performance on my Xeon W-10885M with a SK hynix Gold P31 storing a BTRFS volume is as follows: ``` test ln::channelmanager::bench::bench_sends ... bench: 5,896,492 ns/iter (+/- 512,421) test routing::gossip::benches::read_network_graph ... bench: 1,645,740,604 ns/iter (+/- 47,611,514) test routing::gossip::benches::write_network_graph ... bench: 234,870,775 ns/iter (+/- 8,301,775) test routing::router::benches::generate_mpp_routes_with_probabilistic_scorer ... bench: 166,155,032 ns/iter (+/- 30,206,162) test routing::router::benches::generate_mpp_routes_with_zero_penalty_scorer ... bench: 136,843,661 ns/iter (+/- 67,111,218) test routing::router::benches::generate_routes_with_probabilistic_scorer ... bench: 52,954,598 ns/iter (+/- 11,360,547) test routing::router::benches::generate_routes_with_zero_penalty_scorer ... bench: 37,598,126 ns/iter (+/- 17,262,519) test bench::bench_sends ... bench: 37,760,922 ns/iter (+/- 5,179,123) test bench::bench_reading_full_graph_from_file ... bench: 25,615 ns/iter (+/- 1,149) ``` --- .github/workflows/build.yml | 12 ++++++------ lightning/src/routing/router.rs | 8 ++++---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f27a2ccf863..340b7f898d9 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -242,19 +242,19 @@ jobs: id: cache-graph uses: actions/cache@v3 with: - path: lightning/net_graph-2021-05-31.bin - key: ldk-net_graph-v0.0.15-2021-05-31.bin + path: lightning/net_graph-2023-01-18.bin + key: ldk-net_graph-v0.0.113-2023-01-18.bin - name: Fetch routing graph snapshot if: steps.cache-graph.outputs.cache-hit != 'true' run: | - curl --verbose -L -o lightning/net_graph-2021-05-31.bin https://bitcoin.ninja/ldk-net_graph-v0.0.15-2021-05-31.bin - echo "Sha sum: $(sha256sum lightning/net_graph-2021-05-31.bin | awk '{ print $1 }')" - if [ "$(sha256sum lightning/net_graph-2021-05-31.bin | awk '{ print $1 }')" != "${EXPECTED_ROUTING_GRAPH_SNAPSHOT_SHASUM}" ]; then + curl --verbose -L -o lightning/net_graph-2023-01-18.bin https://bitcoin.ninja/ldk-net_graph-v0.0.113-2023-01-18.bin + echo "Sha sum: $(sha256sum lightning/net_graph-2023-01-18.bin | awk '{ print $1 }')" + if [ "$(sha256sum lightning/net_graph-2023-01-18.bin | awk '{ print $1 }')" != "${EXPECTED_ROUTING_GRAPH_SNAPSHOT_SHASUM}" ]; then echo "Bad hash" exit 1 fi env: - EXPECTED_ROUTING_GRAPH_SNAPSHOT_SHASUM: 05a5361278f68ee2afd086cc04a1f927a63924be451f3221d380533acfacc303 + EXPECTED_ROUTING_GRAPH_SNAPSHOT_SHASUM: da6066f2bddcddbe7d8a6debbd53545697137b310bbb8c4911bc8c81fc5ff48c - name: Fetch rapid graph sync reference input run: | curl --verbose -L -o lightning-rapid-gossip-sync/res/full_graph.lngossip https://bitcoin.ninja/ldk-compressed_graph-285cb27df79-2022-07-21.bin diff --git a/lightning/src/routing/router.rs b/lightning/src/routing/router.rs index c15b612d939..55d7f01494c 100644 --- a/lightning/src/routing/router.rs +++ b/lightning/src/routing/router.rs @@ -5639,8 +5639,8 @@ pub(crate) mod bench_utils { use std::fs::File; /// Tries to open a network graph file, or panics with a URL to fetch it. pub(crate) fn get_route_file() -> Result { - let res = File::open("net_graph-2021-05-31.bin") // By default we're run in RL/lightning - .or_else(|_| File::open("lightning/net_graph-2021-05-31.bin")) // We may be run manually in RL/ + let res = File::open("net_graph-2023-01-18.bin") // By default we're run in RL/lightning + .or_else(|_| File::open("lightning/net_graph-2023-01-18.bin")) // We may be run manually in RL/ .or_else(|_| { // Fall back to guessing based on the binary location // path is likely something like .../rust-lightning/target/debug/deps/lightning-... let mut path = std::env::current_exe().unwrap(); @@ -5649,11 +5649,11 @@ pub(crate) mod bench_utils { path.pop(); // debug path.pop(); // target path.push("lightning"); - path.push("net_graph-2021-05-31.bin"); + path.push("net_graph-2023-01-18.bin"); eprintln!("{}", path.to_str().unwrap()); File::open(path) }) - .map_err(|_| "Please fetch https://bitcoin.ninja/ldk-net_graph-v0.0.15-2021-05-31.bin and place it at lightning/net_graph-2021-05-31.bin"); + .map_err(|_| "Please fetch https://bitcoin.ninja/ldk-net_graph-v0.0.113-2023-01-18.bin and place it at lightning/net_graph-2023-01-18.bin"); #[cfg(require_route_graph_test)] return Ok(res.unwrap()); #[cfg(not(require_route_graph_test))] From efdd2217b78b1b5b3cbbcb8f4bce26f1c6a0cf87 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Thu, 19 Jan 2023 18:24:30 +0000 Subject: [PATCH 2/7] Update min-inbound-fee values on `NetworkGraph` load Historically we've had various bugs in keeping the `lowest_inbound_channel_fees` field in `NodeInfo` up-to-date as we go. This leaves the A* routing less efficient as it can't prune hops as aggressively. In order to get accurate benchmarks, this commit updates the minimum-inbound-fees field on load. This is not the most efficient way of doing so, but suffices for fetching benchmarks and will be removed in the coming commits. Note that this is *slower* than the non-updating version in the previous commit. While I haven't dug into this incredibly deeply, the graph snapshot in use has min-fee info for only 9,618 of 20,818 nodes. Thus, it is my guess that with the graph snapshot as-is the branch predictor is able to largely remove the A* heuristic lookups, but with this change it is forced to wait for A* heuristic map lookups to complete, causing a performance regression. ``` test routing::router::benches::generate_mpp_routes_with_probabilistic_scorer ... bench: 182,980,059 ns/iter (+/- 32,662,047) test routing::router::benches::generate_mpp_routes_with_zero_penalty_scorer ... bench: 151,170,457 ns/iter (+/- 75,351,011) test routing::router::benches::generate_routes_with_probabilistic_scorer ... bench: 58,187,277 ns/iter (+/- 11,606,440) test routing::router::benches::generate_routes_with_zero_penalty_scorer ... bench: 41,210,193 ns/iter (+/- 18,103,320) ``` --- lightning/src/routing/gossip.rs | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/lightning/src/routing/gossip.rs b/lightning/src/routing/gossip.rs index 1a5b978502c..065472aa3c1 100644 --- a/lightning/src/routing/gossip.rs +++ b/lightning/src/routing/gossip.rs @@ -1156,14 +1156,14 @@ impl ReadableArgs for NetworkGraph where L::Target: Logger { let genesis_hash: BlockHash = Readable::read(reader)?; let channels_count: u64 = Readable::read(reader)?; - let mut channels = BTreeMap::new(); + let mut channels: BTreeMap = BTreeMap::new(); for _ in 0..channels_count { let chan_id: u64 = Readable::read(reader)?; let chan_info = Readable::read(reader)?; channels.insert(chan_id, chan_info); } let nodes_count: u64 = Readable::read(reader)?; - let mut nodes = BTreeMap::new(); + let mut nodes: BTreeMap = BTreeMap::new(); for _ in 0..nodes_count { let node_id = Readable::read(reader)?; let node_info = Readable::read(reader)?; @@ -1175,6 +1175,22 @@ impl ReadableArgs for NetworkGraph where L::Target: Logger { (1, last_rapid_gossip_sync_timestamp, option), }); + // Regenerate inbound fees for all channels. The live-updating of these has been broken in + // various ways historically, so this ensures that we have up-to-date limits. + for (node_id, node) in nodes.iter_mut() { + let mut best_fees = RoutingFees { base_msat: u32::MAX, proportional_millionths: u32::MAX }; + for channel in node.channels.iter() { + if let Some(chan) = channels.get(channel) { + let dir_opt = if *node_id == chan.node_one { &chan.two_to_one } else { &chan.one_to_two }; + if let Some(dir) = dir_opt { + best_fees.base_msat = cmp::min(best_fees.base_msat, dir.fees.base_msat); + best_fees.proportional_millionths = cmp::min(best_fees.proportional_millionths, dir.fees.proportional_millionths); + } + } else { return Err(DecodeError::InvalidValue); } + } + node.lowest_inbound_channel_fees = Some(best_fees); + } + Ok(NetworkGraph { secp_ctx: Secp256k1::verification_only(), genesis_hash, From a3f7b790b45698ccc33d9f08265a1863688c08fe Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Tue, 25 Oct 2022 03:15:03 +0000 Subject: [PATCH 3/7] Drop A* implementation in the router for simple Dijkstra's As evidenced by the previous commit, it appears our A* router does worse than a more naive approach. This isn't super surpsising, as the A* heuristic calculation requires a map lookup, which is relatively expensive. ``` test routing::router::benches::generate_mpp_routes_with_probabilistic_scorer ... bench: 169,991,943 ns/iter (+/- 30,838,048) test routing::router::benches::generate_mpp_routes_with_zero_penalty_scorer ... bench: 122,144,987 ns/iter (+/- 61,708,911) test routing::router::benches::generate_routes_with_probabilistic_scorer ... bench: 48,546,068 ns/iter (+/- 10,379,642) test routing::router::benches::generate_routes_with_zero_penalty_scorer ... bench: 32,898,557 ns/iter (+/- 14,157,641) ``` --- lightning/src/routing/gossip.rs | 83 ++++----------------------------- lightning/src/routing/router.rs | 49 +++---------------- 2 files changed, 16 insertions(+), 116 deletions(-) diff --git a/lightning/src/routing/gossip.rs b/lightning/src/routing/gossip.rs index 065472aa3c1..a12b3d563da 100644 --- a/lightning/src/routing/gossip.rs +++ b/lightning/src/routing/gossip.rs @@ -1054,10 +1054,6 @@ impl Readable for NodeAlias { pub struct NodeInfo { /// All valid channels a node has announced pub channels: Vec, - /// Lowest fees enabling routing via any of the enabled, known channels to a node. - /// The two fields (flat and proportional fee) are independent, - /// meaning they don't have to refer to the same channel. - pub lowest_inbound_channel_fees: Option, /// More information about a node from node_announcement. /// Optional because we store a Node entry after learning about it from /// a channel announcement, but before receiving a node announcement. @@ -1066,8 +1062,8 @@ pub struct NodeInfo { impl fmt::Display for NodeInfo { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { - write!(f, "lowest_inbound_channel_fees: {:?}, channels: {:?}, announcement_info: {:?}", - self.lowest_inbound_channel_fees, &self.channels[..], self.announcement_info)?; + write!(f, " channels: {:?}, announcement_info: {:?}", + &self.channels[..], self.announcement_info)?; Ok(()) } } @@ -1075,7 +1071,7 @@ impl fmt::Display for NodeInfo { impl Writeable for NodeInfo { fn write(&self, writer: &mut W) -> Result<(), io::Error> { write_tlv_fields!(writer, { - (0, self.lowest_inbound_channel_fees, option), + // Note that older versions of LDK wrote the lowest inbound fees here at type 0 (2, self.announcement_info, option), (4, self.channels, vec_type), }); @@ -1103,18 +1099,22 @@ impl MaybeReadable for NodeAnnouncementInfoDeserWrapper { impl Readable for NodeInfo { fn read(reader: &mut R) -> Result { - _init_tlv_field_var!(lowest_inbound_channel_fees, option); + // Historically, we tracked the lowest inbound fees for any node in order to use it as an + // A* heuristic when routing. Sadly, these days many, many nodes have at least one channel + // with zero inbound fees, causing that heuristic to provide little gain. Worse, because it + // requires additional complexity and lookups during routing, it ends up being a + // performance loss. Thus, we simply ignore the old field here and no longer track it. + let mut _lowest_inbound_channel_fees: Option = None; let mut announcement_info_wrap: Option = None; _init_tlv_field_var!(channels, vec_type); read_tlv_fields!(reader, { - (0, lowest_inbound_channel_fees, option), + (0, _lowest_inbound_channel_fees, option), (2, announcement_info_wrap, ignorable), (4, channels, vec_type), }); Ok(NodeInfo { - lowest_inbound_channel_fees: _init_tlv_based_struct_field!(lowest_inbound_channel_fees, option), announcement_info: announcement_info_wrap.map(|w| w.0), channels: _init_tlv_based_struct_field!(channels, vec_type), }) @@ -1175,22 +1175,6 @@ impl ReadableArgs for NetworkGraph where L::Target: Logger { (1, last_rapid_gossip_sync_timestamp, option), }); - // Regenerate inbound fees for all channels. The live-updating of these has been broken in - // various ways historically, so this ensures that we have up-to-date limits. - for (node_id, node) in nodes.iter_mut() { - let mut best_fees = RoutingFees { base_msat: u32::MAX, proportional_millionths: u32::MAX }; - for channel in node.channels.iter() { - if let Some(chan) = channels.get(channel) { - let dir_opt = if *node_id == chan.node_one { &chan.two_to_one } else { &chan.one_to_two }; - if let Some(dir) = dir_opt { - best_fees.base_msat = cmp::min(best_fees.base_msat, dir.fees.base_msat); - best_fees.proportional_millionths = cmp::min(best_fees.proportional_millionths, dir.fees.proportional_millionths); - } - } else { return Err(DecodeError::InvalidValue); } - } - node.lowest_inbound_channel_fees = Some(best_fees); - } - Ok(NetworkGraph { secp_ctx: Secp256k1::verification_only(), genesis_hash, @@ -1430,7 +1414,6 @@ impl NetworkGraph where L::Target: Logger { BtreeEntry::Vacant(node_entry) => { node_entry.insert(NodeInfo { channels: vec!(short_channel_id), - lowest_inbound_channel_fees: None, announcement_info: None, }); } @@ -1731,9 +1714,7 @@ impl NetworkGraph where L::Target: Logger { } fn update_channel_intern(&self, msg: &msgs::UnsignedChannelUpdate, full_msg: Option<&msgs::ChannelUpdate>, sig: Option<&secp256k1::ecdsa::Signature>) -> Result<(), LightningError> { - let dest_node_id; let chan_enabled = msg.flags & (1 << 1) != (1 << 1); - let chan_was_enabled; #[cfg(all(feature = "std", not(test), not(feature = "_test_utils")))] { @@ -1781,9 +1762,6 @@ impl NetworkGraph where L::Target: Logger { } else if existing_chan_info.last_update == msg.timestamp { return Err(LightningError{err: "Update had same timestamp as last processed update".to_owned(), action: ErrorAction::IgnoreDuplicateGossip}); } - chan_was_enabled = existing_chan_info.enabled; - } else { - chan_was_enabled = false; } } } @@ -1811,7 +1789,6 @@ impl NetworkGraph where L::Target: Logger { let msg_hash = hash_to_message!(&Sha256dHash::hash(&msg.encode()[..])[..]); if msg.flags & 1 == 1 { - dest_node_id = channel.node_one.clone(); check_update_latest!(channel.two_to_one); if let Some(sig) = sig { secp_verify_sig!(self.secp_ctx, &msg_hash, &sig, &PublicKey::from_slice(channel.node_two.as_slice()).map_err(|_| LightningError{ @@ -1821,7 +1798,6 @@ impl NetworkGraph where L::Target: Logger { } channel.two_to_one = get_new_channel_info!(); } else { - dest_node_id = channel.node_two.clone(); check_update_latest!(channel.one_to_two); if let Some(sig) = sig { secp_verify_sig!(self.secp_ctx, &msg_hash, &sig, &PublicKey::from_slice(channel.node_one.as_slice()).map_err(|_| LightningError{ @@ -1834,44 +1810,6 @@ impl NetworkGraph where L::Target: Logger { } } - let mut nodes = self.nodes.write().unwrap(); - if chan_enabled { - let node = nodes.get_mut(&dest_node_id).unwrap(); - let mut base_msat = msg.fee_base_msat; - let mut proportional_millionths = msg.fee_proportional_millionths; - if let Some(fees) = node.lowest_inbound_channel_fees { - base_msat = cmp::min(base_msat, fees.base_msat); - proportional_millionths = cmp::min(proportional_millionths, fees.proportional_millionths); - } - node.lowest_inbound_channel_fees = Some(RoutingFees { - base_msat, - proportional_millionths - }); - } else if chan_was_enabled { - let node = nodes.get_mut(&dest_node_id).unwrap(); - let mut lowest_inbound_channel_fees = None; - - for chan_id in node.channels.iter() { - let chan = channels.get(chan_id).unwrap(); - let chan_info_opt; - if chan.node_one == dest_node_id { - chan_info_opt = chan.two_to_one.as_ref(); - } else { - chan_info_opt = chan.one_to_two.as_ref(); - } - if let Some(chan_info) = chan_info_opt { - if chan_info.enabled { - let fees = lowest_inbound_channel_fees.get_or_insert(RoutingFees { - base_msat: u32::max_value(), proportional_millionths: u32::max_value() }); - fees.base_msat = cmp::min(fees.base_msat, chan_info.fees.base_msat); - fees.proportional_millionths = cmp::min(fees.proportional_millionths, chan_info.fees.proportional_millionths); - } - } - } - - node.lowest_inbound_channel_fees = lowest_inbound_channel_fees; - } - Ok(()) } @@ -3291,7 +3229,6 @@ mod tests { // 2. Check we can read a NodeInfo anyways, but set the NodeAnnouncementInfo to None if invalid let valid_node_info = NodeInfo { channels: Vec::new(), - lowest_inbound_channel_fees: None, announcement_info: Some(valid_node_ann_info), }; diff --git a/lightning/src/routing/router.rs b/lightning/src/routing/router.rs index 55d7f01494c..e4b95a90d3b 100644 --- a/lightning/src/routing/router.rs +++ b/lightning/src/routing/router.rs @@ -582,7 +582,6 @@ impl_writeable_tlv_based!(RouteHintHop, { #[derive(Eq, PartialEq)] struct RouteGraphNode { node_id: NodeId, - lowest_fee_to_peer_through_node: u64, lowest_fee_to_node: u64, total_cltv_delta: u32, // The maximum value a yet-to-be-constructed payment path might flow through this node. @@ -603,9 +602,9 @@ struct RouteGraphNode { impl cmp::Ord for RouteGraphNode { fn cmp(&self, other: &RouteGraphNode) -> cmp::Ordering { - let other_score = cmp::max(other.lowest_fee_to_peer_through_node, other.path_htlc_minimum_msat) + let other_score = cmp::max(other.lowest_fee_to_node, other.path_htlc_minimum_msat) .saturating_add(other.path_penalty_msat); - let self_score = cmp::max(self.lowest_fee_to_peer_through_node, self.path_htlc_minimum_msat) + let self_score = cmp::max(self.lowest_fee_to_node, self.path_htlc_minimum_msat) .saturating_add(self.path_penalty_msat); other_score.cmp(&self_score).then_with(|| other.node_id.cmp(&self.node_id)) } @@ -729,8 +728,6 @@ struct PathBuildingHop<'a> { candidate: CandidateRouteHop<'a>, fee_msat: u64, - /// Minimal fees required to route to the source node of the current hop via any of its inbound channels. - src_lowest_inbound_fees: RoutingFees, /// All the fees paid *after* this channel on the way to the destination next_hops_fee_msat: u64, /// Fee paid for the use of the current channel (see candidate.fees()). @@ -1007,9 +1004,8 @@ where L::Target: Logger { // 8. If our maximum channel saturation limit caused us to pick two identical paths, combine // them so that we're not sending two HTLCs along the same path. - // As for the actual search algorithm, - // we do a payee-to-payer pseudo-Dijkstra's sorting by each node's distance from the payee - // plus the minimum per-HTLC fee to get from it to another node (aka "shitty pseudo-A*"). + // As for the actual search algorithm, we do a payee-to-payer Dijkstra's sorting by each node's + // distance from the payee // // We are not a faithful Dijkstra's implementation because we can change values which impact // earlier nodes while processing later nodes. Specifically, if we reach a channel with a lower @@ -1044,10 +1040,6 @@ where L::Target: Logger { // runtime for little gain. Specifically, the current algorithm rather efficiently explores the // graph for candidate paths, calculating the maximum value which can realistically be sent at // the same time, remaining generic across different payment values. - // - // TODO: There are a few tweaks we could do, including possibly pre-calculating more stuff - // to use as the A* heuristic beyond just the cost to get one node further than the current - // one. let network_channels = network_graph.channels(); let network_nodes = network_graph.nodes(); @@ -1097,7 +1089,7 @@ where L::Target: Logger { } } - // The main heap containing all candidate next-hops sorted by their score (max(A* fee, + // The main heap containing all candidate next-hops sorted by their score (max(fee, // htlc_minimum)). Ideally this would be a heap which allowed cheap score reduction instead of // adding duplicate entries when we find a better path to a given node. let mut targets: BinaryHeap = BinaryHeap::new(); @@ -1273,20 +1265,10 @@ where L::Target: Logger { // semi-dummy record just to compute the fees to reach the source node. // This will affect our decision on selecting short_channel_id // as a way to reach the $dest_node_id. - let mut fee_base_msat = 0; - let mut fee_proportional_millionths = 0; - if let Some(Some(fees)) = network_nodes.get(&$src_node_id).map(|node| node.lowest_inbound_channel_fees) { - fee_base_msat = fees.base_msat; - fee_proportional_millionths = fees.proportional_millionths; - } PathBuildingHop { node_id: $dest_node_id.clone(), candidate: $candidate.clone(), fee_msat: 0, - src_lowest_inbound_fees: RoutingFees { - base_msat: fee_base_msat, - proportional_millionths: fee_proportional_millionths, - }, next_hops_fee_msat: u64::max_value(), hop_use_fee_msat: u64::max_value(), total_fee_msat: u64::max_value(), @@ -1321,24 +1303,6 @@ where L::Target: Logger { Some(fee_msat) => { hop_use_fee_msat = fee_msat; total_fee_msat += hop_use_fee_msat; - // When calculating the lowest inbound fees to a node, we - // calculate fees here not based on the actual value we think - // will flow over this channel, but on the minimum value that - // we'll accept flowing over it. The minimum accepted value - // is a constant through each path collection run, ensuring - // consistent basis. Otherwise we may later find a - // different path to the source node that is more expensive, - // but which we consider to be cheaper because we are capacity - // constrained and the relative fee becomes lower. - match compute_fees(minimal_value_contribution_msat, old_entry.src_lowest_inbound_fees) - .map(|a| a.checked_add(total_fee_msat)) { - Some(Some(v)) => { - total_fee_msat = v; - }, - _ => { - total_fee_msat = u64::max_value(); - } - }; } } } @@ -1355,8 +1319,7 @@ where L::Target: Logger { .saturating_add(channel_penalty_msat); let new_graph_node = RouteGraphNode { node_id: $src_node_id, - lowest_fee_to_peer_through_node: total_fee_msat, - lowest_fee_to_node: $next_hops_fee_msat as u64 + hop_use_fee_msat, + lowest_fee_to_node: total_fee_msat, total_cltv_delta: hop_total_cltv_delta, value_contribution_msat, path_htlc_minimum_msat, From 1bd35367d8495d9a8b90f5de0f02b68014523e3b Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Tue, 25 Oct 2022 03:50:07 +0000 Subject: [PATCH 4/7] Add a new `IndexedMap` type and use it in network graph storage Our network graph has to be iterable in a deterministic order and with the ability to iterate over a specific range. Thus, historically, we've used a `BTreeMap` to do the iteration. This is fine, except our map needs to also provide high performance lookups in order to make route-finding fast. Sadly, `BTreeMap`s are quite slow due to the branching penalty. Here we replace the `BTreeMap`s in the scorer with a dummy wrapper. In the next commit the internals thereof will be replaced with a `HashMap`-based implementation. --- lightning/src/routing/gossip.rs | 56 +++++------ lightning/src/routing/router.rs | 12 +-- lightning/src/util/indexed_map.rs | 159 ++++++++++++++++++++++++++++++ lightning/src/util/mod.rs | 2 + 4 files changed, 195 insertions(+), 34 deletions(-) create mode 100644 lightning/src/util/indexed_map.rs diff --git a/lightning/src/routing/gossip.rs b/lightning/src/routing/gossip.rs index a12b3d563da..24a21d795b2 100644 --- a/lightning/src/routing/gossip.rs +++ b/lightning/src/routing/gossip.rs @@ -32,11 +32,11 @@ use crate::util::logger::{Logger, Level}; use crate::util::events::{MessageSendEvent, MessageSendEventsProvider}; use crate::util::scid_utils::{block_from_scid, scid_from_parts, MAX_SCID_BLOCK}; use crate::util::string::PrintableString; +use crate::util::indexed_map::{IndexedMap, Entry as IndexedMapEntry}; use crate::io; use crate::io_extras::{copy, sink}; use crate::prelude::*; -use alloc::collections::{BTreeMap, btree_map::Entry as BtreeEntry}; use core::{cmp, fmt}; use crate::sync::{RwLock, RwLockReadGuard}; #[cfg(feature = "std")] @@ -133,8 +133,8 @@ pub struct NetworkGraph where L::Target: Logger { genesis_hash: BlockHash, logger: L, // Lock order: channels -> nodes - channels: RwLock>, - nodes: RwLock>, + channels: RwLock>, + nodes: RwLock>, // Lock order: removed_channels -> removed_nodes // // NOTE: In the following `removed_*` maps, we use seconds since UNIX epoch to track time instead @@ -158,8 +158,8 @@ pub struct NetworkGraph where L::Target: Logger { /// A read-only view of [`NetworkGraph`]. pub struct ReadOnlyNetworkGraph<'a> { - channels: RwLockReadGuard<'a, BTreeMap>, - nodes: RwLockReadGuard<'a, BTreeMap>, + channels: RwLockReadGuard<'a, IndexedMap>, + nodes: RwLockReadGuard<'a, IndexedMap>, } /// Update to the [`NetworkGraph`] based on payment failure information conveyed via the Onion @@ -1131,13 +1131,13 @@ impl Writeable for NetworkGraph where L::Target: Logger { self.genesis_hash.write(writer)?; let channels = self.channels.read().unwrap(); (channels.len() as u64).write(writer)?; - for (ref chan_id, ref chan_info) in channels.iter() { + for (ref chan_id, ref chan_info) in channels.unordered_iter() { (*chan_id).write(writer)?; chan_info.write(writer)?; } let nodes = self.nodes.read().unwrap(); (nodes.len() as u64).write(writer)?; - for (ref node_id, ref node_info) in nodes.iter() { + for (ref node_id, ref node_info) in nodes.unordered_iter() { node_id.write(writer)?; node_info.write(writer)?; } @@ -1156,14 +1156,14 @@ impl ReadableArgs for NetworkGraph where L::Target: Logger { let genesis_hash: BlockHash = Readable::read(reader)?; let channels_count: u64 = Readable::read(reader)?; - let mut channels: BTreeMap = BTreeMap::new(); + let mut channels = IndexedMap::new(); for _ in 0..channels_count { let chan_id: u64 = Readable::read(reader)?; let chan_info = Readable::read(reader)?; channels.insert(chan_id, chan_info); } let nodes_count: u64 = Readable::read(reader)?; - let mut nodes: BTreeMap = BTreeMap::new(); + let mut nodes = IndexedMap::new(); for _ in 0..nodes_count { let node_id = Readable::read(reader)?; let node_info = Readable::read(reader)?; @@ -1191,11 +1191,11 @@ impl ReadableArgs for NetworkGraph where L::Target: Logger { impl fmt::Display for NetworkGraph where L::Target: Logger { fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { writeln!(f, "Network map\n[Channels]")?; - for (key, val) in self.channels.read().unwrap().iter() { + for (key, val) in self.channels.read().unwrap().unordered_iter() { writeln!(f, " {}: {}", key, val)?; } writeln!(f, "[Nodes]")?; - for (&node_id, val) in self.nodes.read().unwrap().iter() { + for (&node_id, val) in self.nodes.read().unwrap().unordered_iter() { writeln!(f, " {}: {}", log_bytes!(node_id.as_slice()), val)?; } Ok(()) @@ -1218,8 +1218,8 @@ impl NetworkGraph where L::Target: Logger { secp_ctx: Secp256k1::verification_only(), genesis_hash, logger, - channels: RwLock::new(BTreeMap::new()), - nodes: RwLock::new(BTreeMap::new()), + channels: RwLock::new(IndexedMap::new()), + nodes: RwLock::new(IndexedMap::new()), last_rapid_gossip_sync_timestamp: Mutex::new(None), removed_channels: Mutex::new(HashMap::new()), removed_nodes: Mutex::new(HashMap::new()), @@ -1252,7 +1252,7 @@ impl NetworkGraph where L::Target: Logger { /// purposes. #[cfg(test)] pub fn clear_nodes_announcement_info(&self) { - for node in self.nodes.write().unwrap().iter_mut() { + for node in self.nodes.write().unwrap().unordered_iter_mut() { node.1.announcement_info = None; } } @@ -1382,7 +1382,7 @@ impl NetworkGraph where L::Target: Logger { let node_id_b = channel_info.node_two.clone(); match channels.entry(short_channel_id) { - BtreeEntry::Occupied(mut entry) => { + IndexedMapEntry::Occupied(mut entry) => { //TODO: because asking the blockchain if short_channel_id is valid is only optional //in the blockchain API, we need to handle it smartly here, though it's unclear //exactly how... @@ -1401,17 +1401,17 @@ impl NetworkGraph where L::Target: Logger { return Err(LightningError{err: "Already have knowledge of channel".to_owned(), action: ErrorAction::IgnoreDuplicateGossip}); } }, - BtreeEntry::Vacant(entry) => { + IndexedMapEntry::Vacant(entry) => { entry.insert(channel_info); } }; for current_node_id in [node_id_a, node_id_b].iter() { match nodes.entry(current_node_id.clone()) { - BtreeEntry::Occupied(node_entry) => { + IndexedMapEntry::Occupied(node_entry) => { node_entry.into_mut().channels.push(short_channel_id); }, - BtreeEntry::Vacant(node_entry) => { + IndexedMapEntry::Vacant(node_entry) => { node_entry.insert(NodeInfo { channels: vec!(short_channel_id), announcement_info: None, @@ -1585,7 +1585,7 @@ impl NetworkGraph where L::Target: Logger { for scid in node.channels.iter() { if let Some(chan_info) = channels.remove(scid) { let other_node_id = if node_id == chan_info.node_one { chan_info.node_two } else { chan_info.node_one }; - if let BtreeEntry::Occupied(mut other_node_entry) = nodes.entry(other_node_id) { + if let IndexedMapEntry::Occupied(mut other_node_entry) = nodes.entry(other_node_id) { other_node_entry.get_mut().channels.retain(|chan_id| { *scid != *chan_id }); @@ -1644,7 +1644,7 @@ impl NetworkGraph where L::Target: Logger { // Sadly BTreeMap::retain was only stabilized in 1.53 so we can't switch to it for some // time. let mut scids_to_remove = Vec::new(); - for (scid, info) in channels.iter_mut() { + for (scid, info) in channels.unordered_iter_mut() { if info.one_to_two.is_some() && info.one_to_two.as_ref().unwrap().last_update < min_time_unix { info.one_to_two = None; } @@ -1813,10 +1813,10 @@ impl NetworkGraph where L::Target: Logger { Ok(()) } - fn remove_channel_in_nodes(nodes: &mut BTreeMap, chan: &ChannelInfo, short_channel_id: u64) { + fn remove_channel_in_nodes(nodes: &mut IndexedMap, chan: &ChannelInfo, short_channel_id: u64) { macro_rules! remove_from_node { ($node_id: expr) => { - if let BtreeEntry::Occupied(mut entry) = nodes.entry($node_id) { + if let IndexedMapEntry::Occupied(mut entry) = nodes.entry($node_id) { entry.get_mut().channels.retain(|chan_id| { short_channel_id != *chan_id }); @@ -1837,8 +1837,8 @@ impl NetworkGraph where L::Target: Logger { impl ReadOnlyNetworkGraph<'_> { /// Returns all known valid channels' short ids along with announced channel info. /// - /// (C-not exported) because we have no mapping for `BTreeMap`s - pub fn channels(&self) -> &BTreeMap { + /// (C-not exported) because we don't want to return lifetime'd references + pub fn channels(&self) -> &IndexedMap { &*self.channels } @@ -1850,13 +1850,13 @@ impl ReadOnlyNetworkGraph<'_> { #[cfg(c_bindings)] // Non-bindings users should use `channels` /// Returns the list of channels in the graph pub fn list_channels(&self) -> Vec { - self.channels.keys().map(|c| *c).collect() + self.channels.unordered_keys().map(|c| *c).collect() } /// Returns all known nodes' public keys along with announced node info. /// - /// (C-not exported) because we have no mapping for `BTreeMap`s - pub fn nodes(&self) -> &BTreeMap { + /// (C-not exported) because we don't want to return lifetime'd references + pub fn nodes(&self) -> &IndexedMap { &*self.nodes } @@ -1868,7 +1868,7 @@ impl ReadOnlyNetworkGraph<'_> { #[cfg(c_bindings)] // Non-bindings users should use `nodes` /// Returns the list of nodes in the graph pub fn list_nodes(&self) -> Vec { - self.nodes.keys().map(|n| *n).collect() + self.nodes.unordered_keys().map(|n| *n).collect() } /// Get network addresses by node id. diff --git a/lightning/src/routing/router.rs b/lightning/src/routing/router.rs index e4b95a90d3b..eb6eede0e82 100644 --- a/lightning/src/routing/router.rs +++ b/lightning/src/routing/router.rs @@ -5507,9 +5507,9 @@ mod tests { 'load_endpoints: for _ in 0..10 { loop { seed = seed.overflowing_mul(0xdeadbeef).0; - let src = &PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); + let src = &PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); seed = seed.overflowing_mul(0xdeadbeef).0; - let dst = PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); + let dst = PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); let payment_params = PaymentParameters::from_node_id(dst); let amt = seed as u64 % 200_000_000; let params = ProbabilisticScoringParameters::default(); @@ -5545,9 +5545,9 @@ mod tests { 'load_endpoints: for _ in 0..10 { loop { seed = seed.overflowing_mul(0xdeadbeef).0; - let src = &PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); + let src = &PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); seed = seed.overflowing_mul(0xdeadbeef).0; - let dst = PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); + let dst = PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); let payment_params = PaymentParameters::from_node_id(dst).with_features(channelmanager::provided_invoice_features(&config)); let amt = seed as u64 % 200_000_000; let params = ProbabilisticScoringParameters::default(); @@ -5745,9 +5745,9 @@ mod benches { 'load_endpoints: for _ in 0..150 { loop { seed *= 0xdeadbeef; - let src = PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); + let src = PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); seed *= 0xdeadbeef; - let dst = PublicKey::from_slice(nodes.keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); + let dst = PublicKey::from_slice(nodes.unordered_keys().skip(seed % nodes.len()).next().unwrap().as_slice()).unwrap(); let params = PaymentParameters::from_node_id(dst).with_features(features.clone()); let first_hop = first_hop(src); let amt = seed as u64 % 1_000_000; diff --git a/lightning/src/util/indexed_map.rs b/lightning/src/util/indexed_map.rs new file mode 100644 index 00000000000..841659714c6 --- /dev/null +++ b/lightning/src/util/indexed_map.rs @@ -0,0 +1,159 @@ +//! This module has a map which can be iterated in a deterministic order. See the [`IndexedMap`]. + +use crate::prelude::HashMap; +use alloc::collections::{BTreeMap, btree_map}; +use core::cmp::Ord; +use core::ops::RangeBounds; + +/// A map which can be iterated in a deterministic order. +/// +/// This would traditionally be accomplished by simply using a [`BTreeMap`], however B-Trees +/// generally have very slow lookups. Because we use a nodes+channels map while finding routes +/// across the network graph, our network graph backing map must be as performant as possible. +/// However, because peers expect to sync the network graph from us (and we need to support that +/// without holding a lock on the graph for the duration of the sync or dumping the entire graph +/// into our outbound message queue), we need an iterable map with a consistent iteration order we +/// can jump to a starting point on. +/// +/// Thus, we have a custom data structure here - its API mimics that of Rust's [`BTreeMap`], but is +/// actually backed by a [`HashMap`], with some additional tracking to ensure we can iterate over +/// keys in the order defined by [`Ord`]. +/// +/// [`BTreeMap`]: alloc::collections::BTreeMap +#[derive(Clone, PartialEq, Eq)] +pub struct IndexedMap { + map: BTreeMap, +} + +impl IndexedMap { + /// Constructs a new, empty map + pub fn new() -> Self { + Self { + map: BTreeMap::new(), + } + } + + #[inline(always)] + /// Fetches the element with the given `key`, if one exists. + pub fn get(&self, key: &K) -> Option<&V> { + self.map.get(key) + } + + /// Fetches a mutable reference to the element with the given `key`, if one exists. + pub fn get_mut(&mut self, key: &K) -> Option<&mut V> { + self.map.get_mut(key) + } + + #[inline] + /// Returns true if an element with the given `key` exists in the map. + pub fn contains_key(&self, key: &K) -> bool { + self.map.contains_key(key) + } + + /// Removes the element with the given `key`, returning it, if one exists. + pub fn remove(&mut self, key: &K) -> Option { + self.map.remove(key) + } + + /// Inserts the given `key`/`value` pair into the map, returning the element that was + /// previously stored at the given `key`, if one exists. + pub fn insert(&mut self, key: K, value: V) -> Option { + self.map.insert(key, value) + } + + /// Returns an [`Entry`] for the given `key` in the map, allowing access to the value. + pub fn entry(&mut self, key: K) -> Entry<'_, K, V> { + match self.map.entry(key) { + btree_map::Entry::Vacant(entry) => { + Entry::Vacant(VacantEntry { + underlying_entry: entry + }) + }, + btree_map::Entry::Occupied(entry) => { + Entry::Occupied(OccupiedEntry { + underlying_entry: entry + }) + } + } + } + + /// Returns an iterator which iterates over the keys in the map, in a random order. + pub fn unordered_keys(&self) -> impl Iterator { + self.map.keys() + } + + /// Returns an iterator which iterates over the `key`/`value` pairs in a random order. + pub fn unordered_iter(&self) -> impl Iterator { + self.map.iter() + } + + /// Returns an iterator which iterates over the `key`s and mutable references to `value`s in a + /// random order. + pub fn unordered_iter_mut(&mut self) -> impl Iterator { + self.map.iter_mut() + } + + /// Returns an iterator which iterates over the `key`/`value` pairs in a given range. + pub fn range>(&self, range: R) -> btree_map::Range { + self.map.range(range) + } + + /// Returns the number of `key`/`value` pairs in the map + pub fn len(&self) -> usize { + self.map.len() + } + + /// Returns true if there are no elements in the map + pub fn is_empty(&self) -> bool { + self.map.is_empty() + } +} + +/// An [`Entry`] for a key which currently has no value +pub struct VacantEntry<'a, K: Ord, V> { + underlying_entry: btree_map::VacantEntry<'a, K, V>, +} + +/// An [`Entry`] for an existing key-value pair +pub struct OccupiedEntry<'a, K: Ord, V> { + underlying_entry: btree_map::OccupiedEntry<'a, K, V>, +} + +/// A mutable reference to a position in the map. This can be used to reference, add, or update the +/// value at a fixed key. +pub enum Entry<'a, K: Ord, V> { + /// A mutable reference to a position within the map where there is no value. + Vacant(VacantEntry<'a, K, V>), + /// A mutable reference to a position within the map where there is currently a value. + Occupied(OccupiedEntry<'a, K, V>), +} + +impl<'a, K: Ord, V> VacantEntry<'a, K, V> { + /// Insert a value into the position described by this entry. + pub fn insert(self, value: V) -> &'a mut V { + self.underlying_entry.insert(value) + } +} + +impl<'a, K: Ord, V> OccupiedEntry<'a, K, V> { + /// Remove the value at the position described by this entry. + pub fn remove_entry(self) -> (K, V) { + self.underlying_entry.remove_entry() + } + + /// Get a reference to the value at the position described by this entry. + pub fn get(&self) -> &V { + self.underlying_entry.get() + } + + /// Get a mutable reference to the value at the position described by this entry. + pub fn get_mut(&mut self) -> &mut V { + self.underlying_entry.get_mut() + } + + /// Consume this entry, returning a mutable reference to the value at the position described by + /// this entry. + pub fn into_mut(self) -> &'a mut V { + self.underlying_entry.into_mut() + } +} diff --git a/lightning/src/util/mod.rs b/lightning/src/util/mod.rs index 1d46865b601..1673bd07f69 100644 --- a/lightning/src/util/mod.rs +++ b/lightning/src/util/mod.rs @@ -40,6 +40,8 @@ pub(crate) mod transaction_utils; pub(crate) mod scid_utils; pub(crate) mod time; +pub mod indexed_map; + /// Logging macro utilities. #[macro_use] pub(crate) mod macro_logger; From 039fa5255da65cd84c42af93e720e3db57a09b38 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Thu, 19 Jan 2023 17:59:10 +0000 Subject: [PATCH 5/7] Swap `IndexedMap` implementation for a `HashMap`+B-Tree Our network graph has to be iterable in a deterministic order and with the ability to iterate over a specific range. Thus, historically, we've used a `BTreeMap` to do the iteration. This is fine, except our map needs to also provide high performance lookups in order to make route-finding fast. Sadly, `BTreeMap`s are quite slow due to the branching penalty. Here we replace the implementation of our `IndexedMap` with a `HashMap` to store the elements itself and a `BTreeSet` to store the keys set in sorted order for iteration. As of this commit on the same hardware as the above few commits, the benchmark results are: ``` test routing::router::benches::generate_mpp_routes_with_probabilistic_scorer ... bench: 109,544,993 ns/iter (+/- 27,553,574) test routing::router::benches::generate_mpp_routes_with_zero_penalty_scorer ... bench: 81,164,590 ns/iter (+/- 55,422,930) test routing::router::benches::generate_routes_with_probabilistic_scorer ... bench: 34,726,569 ns/iter (+/- 9,646,345) test routing::router::benches::generate_routes_with_zero_penalty_scorer ... bench: 22,772,355 ns/iter (+/- 9,574,418) ``` --- lightning/src/util/indexed_map.rs | 92 +++++++++++++++++++++++-------- 1 file changed, 68 insertions(+), 24 deletions(-) diff --git a/lightning/src/util/indexed_map.rs b/lightning/src/util/indexed_map.rs index 841659714c6..cccbfe7bc7a 100644 --- a/lightning/src/util/indexed_map.rs +++ b/lightning/src/util/indexed_map.rs @@ -1,7 +1,8 @@ //! This module has a map which can be iterated in a deterministic order. See the [`IndexedMap`]. -use crate::prelude::HashMap; -use alloc::collections::{BTreeMap, btree_map}; +use crate::prelude::{HashMap, hash_map}; +use alloc::collections::{BTreeSet, btree_set}; +use core::hash::Hash; use core::cmp::Ord; use core::ops::RangeBounds; @@ -20,16 +21,19 @@ use core::ops::RangeBounds; /// keys in the order defined by [`Ord`]. /// /// [`BTreeMap`]: alloc::collections::BTreeMap -#[derive(Clone, PartialEq, Eq)] -pub struct IndexedMap { - map: BTreeMap, +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct IndexedMap { + map: HashMap, + // TODO: Explore swapping this for a sorted vec (that is only sorted on first range() call) + keys: BTreeSet, } -impl IndexedMap { +impl IndexedMap { /// Constructs a new, empty map pub fn new() -> Self { Self { - map: BTreeMap::new(), + map: HashMap::new(), + keys: BTreeSet::new(), } } @@ -52,26 +56,37 @@ impl IndexedMap { /// Removes the element with the given `key`, returning it, if one exists. pub fn remove(&mut self, key: &K) -> Option { - self.map.remove(key) + let ret = self.map.remove(key); + if let Some(_) = ret { + assert!(self.keys.remove(key), "map and keys must be consistent"); + } + ret } /// Inserts the given `key`/`value` pair into the map, returning the element that was /// previously stored at the given `key`, if one exists. pub fn insert(&mut self, key: K, value: V) -> Option { - self.map.insert(key, value) + let ret = self.map.insert(key.clone(), value); + if ret.is_none() { + assert!(self.keys.insert(key), "map and keys must be consistent"); + } + ret } /// Returns an [`Entry`] for the given `key` in the map, allowing access to the value. pub fn entry(&mut self, key: K) -> Entry<'_, K, V> { - match self.map.entry(key) { - btree_map::Entry::Vacant(entry) => { + match self.map.entry(key.clone()) { + hash_map::Entry::Vacant(entry) => { Entry::Vacant(VacantEntry { - underlying_entry: entry + underlying_entry: entry, + key, + keys: &mut self.keys, }) }, - btree_map::Entry::Occupied(entry) => { + hash_map::Entry::Occupied(entry) => { Entry::Occupied(OccupiedEntry { - underlying_entry: entry + underlying_entry: entry, + keys: &mut self.keys, }) } } @@ -94,8 +109,11 @@ impl IndexedMap { } /// Returns an iterator which iterates over the `key`/`value` pairs in a given range. - pub fn range>(&self, range: R) -> btree_map::Range { - self.map.range(range) + pub fn range>(&self, range: R) -> Range { + Range { + inner_range: self.keys.range(range), + map: &self.map, + } } /// Returns the number of `key`/`value` pairs in the map @@ -109,36 +127,62 @@ impl IndexedMap { } } +/// An iterator over a range of values in an [`IndexedMap`] +pub struct Range<'a, K: Hash + Ord, V> { + inner_range: btree_set::Range<'a, K>, + map: &'a HashMap, +} +impl<'a, K: Hash + Ord, V: 'a> Iterator for Range<'a, K, V> { + type Item = (&'a K, &'a V); + fn next(&mut self) -> Option<(&'a K, &'a V)> { + self.inner_range.next().map(|k| { + (k, self.map.get(k).expect("map and keys must be consistent")) + }) + } +} + /// An [`Entry`] for a key which currently has no value -pub struct VacantEntry<'a, K: Ord, V> { - underlying_entry: btree_map::VacantEntry<'a, K, V>, +pub struct VacantEntry<'a, K: Hash + Ord, V> { + #[cfg(feature = "hashbrown")] + underlying_entry: hash_map::VacantEntry<'a, K, V, hash_map::DefaultHashBuilder>, + #[cfg(not(feature = "hashbrown"))] + underlying_entry: hash_map::VacantEntry<'a, K, V>, + key: K, + keys: &'a mut BTreeSet, } /// An [`Entry`] for an existing key-value pair -pub struct OccupiedEntry<'a, K: Ord, V> { - underlying_entry: btree_map::OccupiedEntry<'a, K, V>, +pub struct OccupiedEntry<'a, K: Hash + Ord, V> { + #[cfg(feature = "hashbrown")] + underlying_entry: hash_map::OccupiedEntry<'a, K, V, hash_map::DefaultHashBuilder>, + #[cfg(not(feature = "hashbrown"))] + underlying_entry: hash_map::OccupiedEntry<'a, K, V>, + keys: &'a mut BTreeSet, } /// A mutable reference to a position in the map. This can be used to reference, add, or update the /// value at a fixed key. -pub enum Entry<'a, K: Ord, V> { +pub enum Entry<'a, K: Hash + Ord, V> { /// A mutable reference to a position within the map where there is no value. Vacant(VacantEntry<'a, K, V>), /// A mutable reference to a position within the map where there is currently a value. Occupied(OccupiedEntry<'a, K, V>), } -impl<'a, K: Ord, V> VacantEntry<'a, K, V> { +impl<'a, K: Hash + Ord, V> VacantEntry<'a, K, V> { /// Insert a value into the position described by this entry. pub fn insert(self, value: V) -> &'a mut V { + assert!(self.keys.insert(self.key), "map and keys must be consistent"); self.underlying_entry.insert(value) } } -impl<'a, K: Ord, V> OccupiedEntry<'a, K, V> { +impl<'a, K: Hash + Ord, V> OccupiedEntry<'a, K, V> { /// Remove the value at the position described by this entry. pub fn remove_entry(self) -> (K, V) { - self.underlying_entry.remove_entry() + let res = self.underlying_entry.remove_entry(); + assert!(self.keys.remove(&res.0), "map and keys must be consistent"); + res } /// Get a reference to the value at the position described by this entry. From e64b5d9d2e6252954dce095989c1d55e7003299a Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Thu, 19 Jan 2023 20:24:22 +0000 Subject: [PATCH 6/7] Add a fuzzer to check that `IndexedMap` is equivalent to `BTreeMap` --- fuzz/src/bin/gen_target.sh | 1 + fuzz/src/bin/indexedmap_target.rs | 113 +++++++++++++++++ fuzz/src/bin/msg_channel_details_target.rs | 113 +++++++++++++++++ fuzz/src/indexedmap.rs | 137 +++++++++++++++++++++ fuzz/src/lib.rs | 1 + fuzz/targets.h | 1 + 6 files changed, 366 insertions(+) create mode 100644 fuzz/src/bin/indexedmap_target.rs create mode 100644 fuzz/src/bin/msg_channel_details_target.rs create mode 100644 fuzz/src/indexedmap.rs diff --git a/fuzz/src/bin/gen_target.sh b/fuzz/src/bin/gen_target.sh index 95e65695eb8..fa29540f96b 100755 --- a/fuzz/src/bin/gen_target.sh +++ b/fuzz/src/bin/gen_target.sh @@ -14,6 +14,7 @@ GEN_TEST peer_crypt GEN_TEST process_network_graph GEN_TEST router GEN_TEST zbase32 +GEN_TEST indexedmap GEN_TEST msg_accept_channel msg_targets:: GEN_TEST msg_announcement_signatures msg_targets:: diff --git a/fuzz/src/bin/indexedmap_target.rs b/fuzz/src/bin/indexedmap_target.rs new file mode 100644 index 00000000000..238566d5465 --- /dev/null +++ b/fuzz/src/bin/indexedmap_target.rs @@ -0,0 +1,113 @@ +// This file is Copyright its original authors, visible in version control +// history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license +// , at your option. +// You may not use this file except in accordance with one or both of these +// licenses. + +// This file is auto-generated by gen_target.sh based on target_template.txt +// To modify it, modify target_template.txt and run gen_target.sh instead. + +#![cfg_attr(feature = "libfuzzer_fuzz", no_main)] + +#[cfg(not(fuzzing))] +compile_error!("Fuzz targets need cfg=fuzzing"); + +extern crate lightning_fuzz; +use lightning_fuzz::indexedmap::*; + +#[cfg(feature = "afl")] +#[macro_use] extern crate afl; +#[cfg(feature = "afl")] +fn main() { + fuzz!(|data| { + indexedmap_run(data.as_ptr(), data.len()); + }); +} + +#[cfg(feature = "honggfuzz")] +#[macro_use] extern crate honggfuzz; +#[cfg(feature = "honggfuzz")] +fn main() { + loop { + fuzz!(|data| { + indexedmap_run(data.as_ptr(), data.len()); + }); + } +} + +#[cfg(feature = "libfuzzer_fuzz")] +#[macro_use] extern crate libfuzzer_sys; +#[cfg(feature = "libfuzzer_fuzz")] +fuzz_target!(|data: &[u8]| { + indexedmap_run(data.as_ptr(), data.len()); +}); + +#[cfg(feature = "stdin_fuzz")] +fn main() { + use std::io::Read; + + let mut data = Vec::with_capacity(8192); + std::io::stdin().read_to_end(&mut data).unwrap(); + indexedmap_run(data.as_ptr(), data.len()); +} + +#[test] +fn run_test_cases() { + use std::fs; + use std::io::Read; + use lightning_fuzz::utils::test_logger::StringBuffer; + + use std::sync::{atomic, Arc}; + { + let data: Vec = vec![0]; + indexedmap_run(data.as_ptr(), data.len()); + } + let mut threads = Vec::new(); + let threads_running = Arc::new(atomic::AtomicUsize::new(0)); + if let Ok(tests) = fs::read_dir("test_cases/indexedmap") { + for test in tests { + let mut data: Vec = Vec::new(); + let path = test.unwrap().path(); + fs::File::open(&path).unwrap().read_to_end(&mut data).unwrap(); + threads_running.fetch_add(1, atomic::Ordering::AcqRel); + + let thread_count_ref = Arc::clone(&threads_running); + let main_thread_ref = std::thread::current(); + threads.push((path.file_name().unwrap().to_str().unwrap().to_string(), + std::thread::spawn(move || { + let string_logger = StringBuffer::new(); + + let panic_logger = string_logger.clone(); + let res = if ::std::panic::catch_unwind(move || { + indexedmap_test(&data, panic_logger); + }).is_err() { + Some(string_logger.into_string()) + } else { None }; + thread_count_ref.fetch_sub(1, atomic::Ordering::AcqRel); + main_thread_ref.unpark(); + res + }) + )); + while threads_running.load(atomic::Ordering::Acquire) > 32 { + std::thread::park(); + } + } + } + let mut failed_outputs = Vec::new(); + for (test, thread) in threads.drain(..) { + if let Some(output) = thread.join().unwrap() { + println!("\nOutput of {}:\n{}\n", test, output); + failed_outputs.push(test); + } + } + if !failed_outputs.is_empty() { + println!("Test cases which failed: "); + for case in failed_outputs { + println!("{}", case); + } + panic!(); + } +} diff --git a/fuzz/src/bin/msg_channel_details_target.rs b/fuzz/src/bin/msg_channel_details_target.rs new file mode 100644 index 00000000000..cb5021aedfa --- /dev/null +++ b/fuzz/src/bin/msg_channel_details_target.rs @@ -0,0 +1,113 @@ +// This file is Copyright its original authors, visible in version control +// history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license +// , at your option. +// You may not use this file except in accordance with one or both of these +// licenses. + +// This file is auto-generated by gen_target.sh based on target_template.txt +// To modify it, modify target_template.txt and run gen_target.sh instead. + +#![cfg_attr(feature = "libfuzzer_fuzz", no_main)] + +#[cfg(not(fuzzing))] +compile_error!("Fuzz targets need cfg=fuzzing"); + +extern crate lightning_fuzz; +use lightning_fuzz::msg_targets::msg_channel_details::*; + +#[cfg(feature = "afl")] +#[macro_use] extern crate afl; +#[cfg(feature = "afl")] +fn main() { + fuzz!(|data| { + msg_channel_details_run(data.as_ptr(), data.len()); + }); +} + +#[cfg(feature = "honggfuzz")] +#[macro_use] extern crate honggfuzz; +#[cfg(feature = "honggfuzz")] +fn main() { + loop { + fuzz!(|data| { + msg_channel_details_run(data.as_ptr(), data.len()); + }); + } +} + +#[cfg(feature = "libfuzzer_fuzz")] +#[macro_use] extern crate libfuzzer_sys; +#[cfg(feature = "libfuzzer_fuzz")] +fuzz_target!(|data: &[u8]| { + msg_channel_details_run(data.as_ptr(), data.len()); +}); + +#[cfg(feature = "stdin_fuzz")] +fn main() { + use std::io::Read; + + let mut data = Vec::with_capacity(8192); + std::io::stdin().read_to_end(&mut data).unwrap(); + msg_channel_details_run(data.as_ptr(), data.len()); +} + +#[test] +fn run_test_cases() { + use std::fs; + use std::io::Read; + use lightning_fuzz::utils::test_logger::StringBuffer; + + use std::sync::{atomic, Arc}; + { + let data: Vec = vec![0]; + msg_channel_details_run(data.as_ptr(), data.len()); + } + let mut threads = Vec::new(); + let threads_running = Arc::new(atomic::AtomicUsize::new(0)); + if let Ok(tests) = fs::read_dir("test_cases/msg_channel_details") { + for test in tests { + let mut data: Vec = Vec::new(); + let path = test.unwrap().path(); + fs::File::open(&path).unwrap().read_to_end(&mut data).unwrap(); + threads_running.fetch_add(1, atomic::Ordering::AcqRel); + + let thread_count_ref = Arc::clone(&threads_running); + let main_thread_ref = std::thread::current(); + threads.push((path.file_name().unwrap().to_str().unwrap().to_string(), + std::thread::spawn(move || { + let string_logger = StringBuffer::new(); + + let panic_logger = string_logger.clone(); + let res = if ::std::panic::catch_unwind(move || { + msg_channel_details_test(&data, panic_logger); + }).is_err() { + Some(string_logger.into_string()) + } else { None }; + thread_count_ref.fetch_sub(1, atomic::Ordering::AcqRel); + main_thread_ref.unpark(); + res + }) + )); + while threads_running.load(atomic::Ordering::Acquire) > 32 { + std::thread::park(); + } + } + } + let mut failed_outputs = Vec::new(); + for (test, thread) in threads.drain(..) { + if let Some(output) = thread.join().unwrap() { + println!("\nOutput of {}:\n{}\n", test, output); + failed_outputs.push(test); + } + } + if !failed_outputs.is_empty() { + println!("Test cases which failed: "); + for case in failed_outputs { + println!("{}", case); + } + panic!(); + } +} diff --git a/fuzz/src/indexedmap.rs b/fuzz/src/indexedmap.rs new file mode 100644 index 00000000000..795d6175bb5 --- /dev/null +++ b/fuzz/src/indexedmap.rs @@ -0,0 +1,137 @@ +// This file is Copyright its original authors, visible in version control +// history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license +// , at your option. +// You may not use this file except in accordance with one or both of these +// licenses. + +use lightning::util::indexed_map::{IndexedMap, self}; +use std::collections::{BTreeMap, btree_map}; +use hashbrown::HashSet; + +use crate::utils::test_logger; + +fn check_eq(btree: &BTreeMap, indexed: &IndexedMap) { + assert_eq!(btree.len(), indexed.len()); + assert_eq!(btree.is_empty(), indexed.is_empty()); + + let mut btree_clone = btree.clone(); + assert!(btree_clone == *btree); + let mut indexed_clone = indexed.clone(); + assert!(indexed_clone == *indexed); + + for k in 0..=255 { + assert_eq!(btree.contains_key(&k), indexed.contains_key(&k)); + assert_eq!(btree.get(&k), indexed.get(&k)); + + let btree_entry = btree_clone.entry(k); + let indexed_entry = indexed_clone.entry(k); + match btree_entry { + btree_map::Entry::Occupied(mut bo) => { + if let indexed_map::Entry::Occupied(mut io) = indexed_entry { + assert_eq!(bo.get(), io.get()); + assert_eq!(bo.get_mut(), io.get_mut()); + } else { panic!(); } + }, + btree_map::Entry::Vacant(_) => { + if let indexed_map::Entry::Vacant(_) = indexed_entry { + } else { panic!(); } + } + } + } + + const STRIDE: u8 = 16; + for k in 0..=255/STRIDE { + let lower_bound = k * STRIDE; + let upper_bound = lower_bound + (STRIDE - 1); + let mut btree_iter = btree.range(lower_bound..=upper_bound); + let mut indexed_iter = indexed.range(lower_bound..=upper_bound); + loop { + let b_v = btree_iter.next(); + let i_v = indexed_iter.next(); + assert_eq!(b_v, i_v); + if b_v.is_none() { break; } + } + } + + let mut key_set = HashSet::with_capacity(256); + for k in indexed.unordered_keys() { + assert!(key_set.insert(*k)); + assert!(btree.contains_key(k)); + } + assert_eq!(key_set.len(), btree.len()); + + key_set.clear(); + for (k, v) in indexed.unordered_iter() { + assert!(key_set.insert(*k)); + assert_eq!(btree.get(k).unwrap(), v); + } + assert_eq!(key_set.len(), btree.len()); + + key_set.clear(); + for (k, v) in indexed_clone.unordered_iter_mut() { + assert!(key_set.insert(*k)); + assert_eq!(btree.get(k).unwrap(), v); + } + assert_eq!(key_set.len(), btree.len()); +} + +#[inline] +pub fn do_test(data: &[u8]) { + if data.len() % 2 != 0 { return; } + let mut btree = BTreeMap::new(); + let mut indexed = IndexedMap::new(); + + // Read in k-v pairs from the input and insert them into the maps then check that the maps are + // equivalent in every way we can read them. + for tuple in data.windows(2) { + let prev_value_b = btree.insert(tuple[0], tuple[1]); + let prev_value_i = indexed.insert(tuple[0], tuple[1]); + assert_eq!(prev_value_b, prev_value_i); + } + check_eq(&btree, &indexed); + + // Now, modify the maps in all the ways we have to do so, checking that the maps remain + // equivalent as we go. + for (k, v) in indexed.unordered_iter_mut() { + *v = *k; + *btree.get_mut(k).unwrap() = *k; + } + check_eq(&btree, &indexed); + + for k in 0..=255 { + match btree.entry(k) { + btree_map::Entry::Occupied(mut bo) => { + if let indexed_map::Entry::Occupied(mut io) = indexed.entry(k) { + if k < 64 { + *io.get_mut() ^= 0xff; + *bo.get_mut() ^= 0xff; + } else if k < 128 { + *io.into_mut() ^= 0xff; + *bo.get_mut() ^= 0xff; + } else { + assert_eq!(bo.remove_entry(), io.remove_entry()); + } + } else { panic!(); } + }, + btree_map::Entry::Vacant(bv) => { + if let indexed_map::Entry::Vacant(iv) = indexed.entry(k) { + bv.insert(k); + iv.insert(k); + } else { panic!(); } + }, + } + } + check_eq(&btree, &indexed); +} + +pub fn indexedmap_test(data: &[u8], _out: Out) { + do_test(data); +} + +#[no_mangle] +pub extern "C" fn indexedmap_run(data: *const u8, datalen: usize) { + do_test(unsafe { std::slice::from_raw_parts(data, datalen) }); +} diff --git a/fuzz/src/lib.rs b/fuzz/src/lib.rs index 2238a9702a9..462307d55b4 100644 --- a/fuzz/src/lib.rs +++ b/fuzz/src/lib.rs @@ -17,6 +17,7 @@ pub mod utils; pub mod chanmon_deser; pub mod chanmon_consistency; pub mod full_stack; +pub mod indexedmap; pub mod onion_message; pub mod peer_crypt; pub mod process_network_graph; diff --git a/fuzz/targets.h b/fuzz/targets.h index cff3f9bdbb5..5bfee07dafb 100644 --- a/fuzz/targets.h +++ b/fuzz/targets.h @@ -7,6 +7,7 @@ void peer_crypt_run(const unsigned char* data, size_t data_len); void process_network_graph_run(const unsigned char* data, size_t data_len); void router_run(const unsigned char* data, size_t data_len); void zbase32_run(const unsigned char* data, size_t data_len); +void indexedmap_run(const unsigned char* data, size_t data_len); void msg_accept_channel_run(const unsigned char* data, size_t data_len); void msg_announcement_signatures_run(const unsigned char* data, size_t data_len); void msg_channel_reestablish_run(const unsigned char* data, size_t data_len); From bde841e928da56354d267cb329eea2ee9359e3e3 Mon Sep 17 00:00:00 2001 From: Matt Corallo Date: Wed, 25 Jan 2023 17:42:20 +0000 Subject: [PATCH 7/7] Clean up `compute_fees` and add a saturating variant Often when we call `compute_fees` we really just want it to saturate and we deal with `u64::max_value` later. In that case, we're much better off doing the saturating in the `compute_fees` as it can use CMOVs rather than branching at each step and then `unwrap_or`ing at the callsite. --- lightning/src/routing/router.rs | 45 +++++++++++++++------------------ 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/lightning/src/routing/router.rs b/lightning/src/routing/router.rs index eb6eede0e82..8543956ac65 100644 --- a/lightning/src/routing/router.rs +++ b/lightning/src/routing/router.rs @@ -885,18 +885,20 @@ impl<'a> PaymentPath<'a> { } } +#[inline(always)] +/// Calculate the fees required to route the given amount over a channel with the given fees. fn compute_fees(amount_msat: u64, channel_fees: RoutingFees) -> Option { - let proportional_fee_millions = - amount_msat.checked_mul(channel_fees.proportional_millionths as u64); - if let Some(new_fee) = proportional_fee_millions.and_then(|part| { - (channel_fees.base_msat as u64).checked_add(part / 1_000_000) }) { + amount_msat.checked_mul(channel_fees.proportional_millionths as u64) + .and_then(|part| (channel_fees.base_msat as u64).checked_add(part / 1_000_000)) +} - Some(new_fee) - } else { - // This function may be (indirectly) called without any verification, - // with channel_fees provided by a caller. We should handle it gracefully. - None - } +#[inline(always)] +/// Calculate the fees required to route the given amount over a channel with the given fees, +/// saturating to [`u64::max_value`]. +fn compute_fees_saturating(amount_msat: u64, channel_fees: RoutingFees) -> u64 { + amount_msat.checked_mul(channel_fees.proportional_millionths as u64) + .map(|prop| prop / 1_000_000).unwrap_or(u64::max_value()) + .saturating_add(channel_fees.base_msat as u64) } /// The default `features` we assume for a node in a route, when no `features` are known about that @@ -1254,10 +1256,10 @@ where L::Target: Logger { // might violate htlc_minimum_msat on the hops which are next along the // payment path (upstream to the payee). To avoid that, we recompute // path fees knowing the final path contribution after constructing it. - let path_htlc_minimum_msat = compute_fees($next_hops_path_htlc_minimum_msat, $candidate.fees()) - .and_then(|fee_msat| fee_msat.checked_add($next_hops_path_htlc_minimum_msat)) - .map(|fee_msat| cmp::max(fee_msat, $candidate.htlc_minimum_msat())) - .unwrap_or_else(|| u64::max_value()); + let path_htlc_minimum_msat = cmp::max( + compute_fees_saturating($next_hops_path_htlc_minimum_msat, $candidate.fees()) + .saturating_add($next_hops_path_htlc_minimum_msat), + $candidate.htlc_minimum_msat()); let hm_entry = dist.entry($src_node_id); let old_entry = hm_entry.or_insert_with(|| { // If there was previously no known way to access the source node @@ -1291,20 +1293,15 @@ where L::Target: Logger { if should_process { let mut hop_use_fee_msat = 0; - let mut total_fee_msat = $next_hops_fee_msat; + let mut total_fee_msat: u64 = $next_hops_fee_msat; // Ignore hop_use_fee_msat for channel-from-us as we assume all channels-from-us // will have the same effective-fee if $src_node_id != our_node_id { - match compute_fees(amount_to_transfer_over_msat, $candidate.fees()) { - // max_value means we'll always fail - // the old_entry.total_fee_msat > total_fee_msat check - None => total_fee_msat = u64::max_value(), - Some(fee_msat) => { - hop_use_fee_msat = fee_msat; - total_fee_msat += hop_use_fee_msat; - } - } + // Note that `u64::max_value` means we'll always fail the + // `old_entry.total_fee_msat > total_fee_msat` check below + hop_use_fee_msat = compute_fees_saturating(amount_to_transfer_over_msat, $candidate.fees()); + total_fee_msat = total_fee_msat.saturating_add(hop_use_fee_msat); } let channel_usage = ChannelUsage {