diff --git a/.github/workflows/semver.yml b/.github/workflows/semver.yml new file mode 100644 index 00000000000..73723fe34c3 --- /dev/null +++ b/.github/workflows/semver.yml @@ -0,0 +1,57 @@ +name: SemVer checks +on: + push: + branches-ignore: + - master + pull_request: + branches-ignore: + - master + +jobs: + semver-checks: + runs-on: ubuntu-latest + steps: + - name: Checkout source code + uses: actions/checkout@v4 + - name: Install Rust stable toolchain + run: | + curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --profile=minimal --default-toolchain stable + rustup override set stable + - name: Check SemVer with default features + uses: obi1kenobi/cargo-semver-checks-action@v2 + with: + feature-group: default-features + - name: Check SemVer *without* default features + uses: obi1kenobi/cargo-semver-checks-action@v2 + with: + feature-group: only-explicit-features + - name: Check lightning-background-processor SemVer + uses: obi1kenobi/cargo-semver-checks-action@v2 + with: + package: lightning-background-processor + feature-group: only-explicit-features + features: futures + - name: Check lightning-block-sync SemVer + uses: obi1kenobi/cargo-semver-checks-action@v2 + with: + package: lightning-block-sync + feature-group: only-explicit-features + features: rpc-client,rest-client + - name: Check lightning-transaction-sync electrum SemVer + uses: obi1kenobi/cargo-semver-checks-action@v2 + with: + manifest-path: lightning-transaction-sync/Cargo.toml + feature-group: only-explicit-features + features: electrum + - name: Check lightning-transaction-sync esplora-blocking SemVer + uses: obi1kenobi/cargo-semver-checks-action@v2 + with: + manifest-path: lightning-transaction-sync/Cargo.toml + feature-group: only-explicit-features + features: esplora-blocking + - name: Check lightning-transaction-sync esplora-async SemVer + uses: obi1kenobi/cargo-semver-checks-action@v2 + with: + manifest-path: lightning-transaction-sync/Cargo.toml + feature-group: only-explicit-features + features: esplora-async diff --git a/lightning-background-processor/src/lib.rs b/lightning-background-processor/src/lib.rs index adcae564f56..ef889d4e80f 100644 --- a/lightning-background-processor/src/lib.rs +++ b/lightning-background-processor/src/lib.rs @@ -2379,8 +2379,8 @@ mod tests { 42, 53, features, - $nodes[0].node.get_our_node_id(), - $nodes[1].node.get_our_node_id(), + $nodes[0].node.get_our_node_id().into(), + $nodes[1].node.get_our_node_id().into(), ) .expect("Failed to update channel from partial announcement"); let original_graph_description = $nodes[0].network_graph.to_string(); diff --git a/lightning/src/chain/channelmonitor.rs b/lightning/src/chain/channelmonitor.rs index 8788fbb894d..4c195b20a78 100644 --- a/lightning/src/chain/channelmonitor.rs +++ b/lightning/src/chain/channelmonitor.rs @@ -1509,8 +1509,8 @@ impl ChannelMonitor { fn provide_latest_holder_commitment_tx( &self, holder_commitment_tx: HolderCommitmentTransaction, htlc_outputs: Vec<(HTLCOutputInCommitment, Option, Option)>, - ) -> Result<(), ()> { - self.inner.lock().unwrap().provide_latest_holder_commitment_tx(holder_commitment_tx, htlc_outputs, &Vec::new(), Vec::new()).map_err(|_| ()) + ) { + self.inner.lock().unwrap().provide_latest_holder_commitment_tx(holder_commitment_tx, htlc_outputs, &Vec::new(), Vec::new()) } /// This is used to provide payment preimage(s) out-of-band during startup without updating the @@ -1737,10 +1737,14 @@ impl ChannelMonitor { self.inner.lock().unwrap().get_cur_holder_commitment_number() } - /// 
Gets whether we've been notified that this channel is closed by the `ChannelManager` (i.e. - /// via a [`ChannelMonitorUpdateStep::ChannelForceClosed`]). - pub(crate) fn offchain_closed(&self) -> bool { - self.inner.lock().unwrap().lockdown_from_offchain + /// Fetches whether this monitor has marked the channel as closed and will refuse any further + /// updates to the commitment transactions. + /// + /// It can be marked closed in a few different ways, including via a + /// [`ChannelMonitorUpdateStep::ChannelForceClosed`] or if the channel has been closed + /// on-chain. + pub(crate) fn no_further_updates_allowed(&self) -> bool { + self.inner.lock().unwrap().no_further_updates_allowed() } /// Gets the `node_id` of the counterparty for this channel. @@ -2901,7 +2905,7 @@ impl ChannelMonitorImpl { /// is important that any clones of this channel monitor (including remote clones) by kept /// up-to-date as our holder commitment transaction is updated. /// Panics if set_on_holder_tx_csv has never been called. - fn provide_latest_holder_commitment_tx(&mut self, holder_commitment_tx: HolderCommitmentTransaction, mut htlc_outputs: Vec<(HTLCOutputInCommitment, Option, Option)>, claimed_htlcs: &[(SentHTLCId, PaymentPreimage)], nondust_htlc_sources: Vec) -> Result<(), &'static str> { + fn provide_latest_holder_commitment_tx(&mut self, holder_commitment_tx: HolderCommitmentTransaction, mut htlc_outputs: Vec<(HTLCOutputInCommitment, Option, Option)>, claimed_htlcs: &[(SentHTLCId, PaymentPreimage)], nondust_htlc_sources: Vec) { if htlc_outputs.iter().any(|(_, s, _)| s.is_some()) { // If we have non-dust HTLCs in htlc_outputs, ensure they match the HTLCs in the // `holder_commitment_tx`. In the future, we'll no longer provide the redundant data @@ -2978,10 +2982,6 @@ impl ChannelMonitorImpl { } self.counterparty_fulfilled_htlcs.insert(*claimed_htlc_id, *claimed_preimage); } - if self.holder_tx_signed { - return Err("Latest holder commitment signed has already been signed, update is rejected"); - } - Ok(()) } /// Provides a payment_hash->payment_preimage mapping. Will be automatically pruned when all @@ -3202,11 +3202,7 @@ impl ChannelMonitorImpl { ChannelMonitorUpdateStep::LatestHolderCommitmentTXInfo { commitment_tx, htlc_outputs, claimed_htlcs, nondust_htlc_sources } => { log_trace!(logger, "Updating ChannelMonitor with latest holder commitment transaction info"); if self.lockdown_from_offchain { panic!(); } - if let Err(e) = self.provide_latest_holder_commitment_tx(commitment_tx.clone(), htlc_outputs.clone(), &claimed_htlcs, nondust_htlc_sources.clone()) { - log_error!(logger, "Providing latest holder commitment transaction failed/was refused:"); - log_error!(logger, " {}", e); - ret = Err(()); - } + self.provide_latest_holder_commitment_tx(commitment_tx.clone(), htlc_outputs.clone(), &claimed_htlcs, nondust_htlc_sources.clone()); } ChannelMonitorUpdateStep::LatestCounterpartyCommitmentTXInfo { commitment_txid, htlc_outputs, commitment_number, their_per_commitment_point, .. 
} => { log_trace!(logger, "Updating ChannelMonitor with latest counterparty commitment transaction info"); @@ -3286,12 +3282,16 @@ impl ChannelMonitorImpl { } } - if ret.is_ok() && (self.funding_spend_seen || self.lockdown_from_offchain || self.holder_tx_signed) && is_pre_close_update { + if ret.is_ok() && self.no_further_updates_allowed() && is_pre_close_update { log_error!(logger, "Refusing Channel Monitor Update as counterparty attempted to update commitment after funding was spent"); Err(()) } else { ret } } + fn no_further_updates_allowed(&self) -> bool { + self.funding_spend_seen || self.lockdown_from_offchain || self.holder_tx_signed + } + fn get_latest_update_id(&self) -> u64 { self.latest_update_id } @@ -3564,11 +3564,16 @@ impl ChannelMonitorImpl { return (claimable_outpoints, to_counterparty_output_info); } let revk_htlc_outp = RevokedHTLCOutput::build(per_commitment_point, self.counterparty_commitment_params.counterparty_delayed_payment_base_key, self.counterparty_commitment_params.counterparty_htlc_base_key, per_commitment_key, htlc.amount_msat / 1000, htlc.clone(), &self.onchain_tx_handler.channel_transaction_parameters.channel_type_features); + let counterparty_spendable_height = if htlc.offered { + htlc.cltv_expiry + } else { + height + }; let justice_package = PackageTemplate::build_package( commitment_txid, transaction_output_index, PackageSolvingData::RevokedHTLCOutput(revk_htlc_outp), - htlc.cltv_expiry, + counterparty_spendable_height, ); claimable_outpoints.push(justice_package); } @@ -3869,35 +3874,32 @@ impl ChannelMonitorImpl { } } } - if self.holder_tx_signed { - // If we've signed, we may have broadcast either commitment (prev or current), and - // attempted to claim from it immediately without waiting for a confirmation. - if self.current_holder_commitment_tx.txid != *confirmed_commitment_txid { + // Cancel any pending claims for any holder commitments in case they had previously + // confirmed or been signed (in which case we will start attempting to claim without + // waiting for confirmation). 
+ if self.current_holder_commitment_tx.txid != *confirmed_commitment_txid { + log_trace!(logger, "Canceling claims for previously broadcast holder commitment {}", + self.current_holder_commitment_tx.txid); + let mut outpoint = BitcoinOutPoint { txid: self.current_holder_commitment_tx.txid, vout: 0 }; + for (htlc, _, _) in &self.current_holder_commitment_tx.htlc_outputs { + if let Some(vout) = htlc.transaction_output_index { + outpoint.vout = vout; + self.onchain_tx_handler.abandon_claim(&outpoint); + } + } + } + if let Some(prev_holder_commitment_tx) = &self.prev_holder_signed_commitment_tx { + if prev_holder_commitment_tx.txid != *confirmed_commitment_txid { log_trace!(logger, "Canceling claims for previously broadcast holder commitment {}", - self.current_holder_commitment_tx.txid); - let mut outpoint = BitcoinOutPoint { txid: self.current_holder_commitment_tx.txid, vout: 0 }; - for (htlc, _, _) in &self.current_holder_commitment_tx.htlc_outputs { + prev_holder_commitment_tx.txid); + let mut outpoint = BitcoinOutPoint { txid: prev_holder_commitment_tx.txid, vout: 0 }; + for (htlc, _, _) in &prev_holder_commitment_tx.htlc_outputs { if let Some(vout) = htlc.transaction_output_index { outpoint.vout = vout; self.onchain_tx_handler.abandon_claim(&outpoint); } } } - if let Some(prev_holder_commitment_tx) = &self.prev_holder_signed_commitment_tx { - if prev_holder_commitment_tx.txid != *confirmed_commitment_txid { - log_trace!(logger, "Canceling claims for previously broadcast holder commitment {}", - prev_holder_commitment_tx.txid); - let mut outpoint = BitcoinOutPoint { txid: prev_holder_commitment_tx.txid, vout: 0 }; - for (htlc, _, _) in &prev_holder_commitment_tx.htlc_outputs { - if let Some(vout) = htlc.transaction_output_index { - outpoint.vout = vout; - self.onchain_tx_handler.abandon_claim(&outpoint); - } - } - } - } - } else { - // No previous claim. } } @@ -4233,7 +4235,7 @@ impl ChannelMonitorImpl { } } - if self.lockdown_from_offchain || self.funding_spend_seen || self.holder_tx_signed { + if self.no_further_updates_allowed() { // Fail back HTLCs on backwards channels if they expire within // `LATENCY_GRACE_PERIOD_BLOCKS` blocks and the channel is closed (i.e. we're at a // point where no further off-chain updates will be accepted). 
If we haven't seen the @@ -5384,7 +5386,7 @@ mod tests { let dummy_commitment_tx = HolderCommitmentTransaction::dummy(&mut htlcs); monitor.provide_latest_holder_commitment_tx(dummy_commitment_tx.clone(), - htlcs.into_iter().map(|(htlc, _)| (htlc, Some(dummy_sig), None)).collect()).unwrap(); + htlcs.into_iter().map(|(htlc, _)| (htlc, Some(dummy_sig), None)).collect()); monitor.provide_latest_counterparty_commitment_tx(Txid::from_byte_array(Sha256::hash(b"1").to_byte_array()), preimages_slice_to_htlc_outputs!(preimages[5..15]), 281474976710655, dummy_key, &logger); monitor.provide_latest_counterparty_commitment_tx(Txid::from_byte_array(Sha256::hash(b"2").to_byte_array()), @@ -5422,7 +5424,7 @@ mod tests { let mut htlcs = preimages_slice_to_htlcs!(preimages[0..5]); let dummy_commitment_tx = HolderCommitmentTransaction::dummy(&mut htlcs); monitor.provide_latest_holder_commitment_tx(dummy_commitment_tx.clone(), - htlcs.into_iter().map(|(htlc, _)| (htlc, Some(dummy_sig), None)).collect()).unwrap(); + htlcs.into_iter().map(|(htlc, _)| (htlc, Some(dummy_sig), None)).collect()); secret[0..32].clone_from_slice(&>::from_hex("2273e227a5b7449b6e70f1fb4652864038b1cbf9cd7c043a7d6456b7fc275ad8").unwrap()); monitor.provide_secret(281474976710653, secret.clone()).unwrap(); assert_eq!(monitor.inner.lock().unwrap().payment_preimages.len(), 12); @@ -5433,7 +5435,7 @@ mod tests { let mut htlcs = preimages_slice_to_htlcs!(preimages[0..3]); let dummy_commitment_tx = HolderCommitmentTransaction::dummy(&mut htlcs); monitor.provide_latest_holder_commitment_tx(dummy_commitment_tx, - htlcs.into_iter().map(|(htlc, _)| (htlc, Some(dummy_sig), None)).collect()).unwrap(); + htlcs.into_iter().map(|(htlc, _)| (htlc, Some(dummy_sig), None)).collect()); secret[0..32].clone_from_slice(&>::from_hex("27cddaa5624534cb6cb9d7da077cf2b22ab21e9b506fd4998a51d54502e99116").unwrap()); monitor.provide_secret(281474976710652, secret.clone()).unwrap(); assert_eq!(monitor.inner.lock().unwrap().payment_preimages.len(), 5); diff --git a/lightning/src/chain/package.rs b/lightning/src/chain/package.rs index 55214006d4c..bd6912c21f8 100644 --- a/lightning/src/chain/package.rs +++ b/lightning/src/chain/package.rs @@ -699,8 +699,13 @@ impl PackageSolvingData { match self { PackageSolvingData::RevokedOutput(RevokedOutput { .. }) => PackageMalleability::Malleable(AggregationCluster::Unpinnable), - PackageSolvingData::RevokedHTLCOutput(..) => - PackageMalleability::Malleable(AggregationCluster::Pinnable), + PackageSolvingData::RevokedHTLCOutput(RevokedHTLCOutput { htlc, .. }) => { + if htlc.offered { + PackageMalleability::Malleable(AggregationCluster::Unpinnable) + } else { + PackageMalleability::Malleable(AggregationCluster::Pinnable) + } + }, PackageSolvingData::CounterpartyOfferedHTLCOutput(..) => PackageMalleability::Malleable(AggregationCluster::Unpinnable), PackageSolvingData::CounterpartyReceivedHTLCOutput(..) => @@ -771,10 +776,12 @@ pub struct PackageTemplate { /// Block height at which our counterparty can potentially claim this output as well (assuming /// they have the keys or information required to do so). /// - /// This is used primarily by external consumers to decide when an output becomes "pinnable" - /// because the counterparty can potentially spend it. It is also used internally by - /// [`Self::get_height_timer`] to identify when an output must be claimed by, depending on the - /// type of output. + /// This is used primarily to decide when an output becomes "pinnable" because the counterparty + /// can potentially spend it. 
It is also used internally by [`Self::get_height_timer`] to + /// identify when an output must be claimed by, depending on the type of output. + /// + /// Note that for revoked counterparty HTLC outputs the value may be zero in some cases where + /// we upgraded from LDK 0.1 or prior. counterparty_spendable_height: u32, // Cache of package feerate committed at previous (re)broadcast. If bumping resources // (either claimed output value or external utxo), it will keep increasing until holder @@ -834,17 +841,17 @@ impl PackageTemplate { // Now check that we only merge packages if they are both unpinnable or both // pinnable. let self_pinnable = self_cluster == AggregationCluster::Pinnable || - self.counterparty_spendable_height() <= cur_height + COUNTERPARTY_CLAIMABLE_WITHIN_BLOCKS_PINNABLE; + self.counterparty_spendable_height <= cur_height + COUNTERPARTY_CLAIMABLE_WITHIN_BLOCKS_PINNABLE; let other_pinnable = other_cluster == AggregationCluster::Pinnable || - other.counterparty_spendable_height() <= cur_height + COUNTERPARTY_CLAIMABLE_WITHIN_BLOCKS_PINNABLE; + other.counterparty_spendable_height <= cur_height + COUNTERPARTY_CLAIMABLE_WITHIN_BLOCKS_PINNABLE; if self_pinnable && other_pinnable { return true; } let self_unpinnable = self_cluster == AggregationCluster::Unpinnable && - self.counterparty_spendable_height() > cur_height + COUNTERPARTY_CLAIMABLE_WITHIN_BLOCKS_PINNABLE; + self.counterparty_spendable_height > cur_height + COUNTERPARTY_CLAIMABLE_WITHIN_BLOCKS_PINNABLE; let other_unpinnable = other_cluster == AggregationCluster::Unpinnable && - other.counterparty_spendable_height() > cur_height + COUNTERPARTY_CLAIMABLE_WITHIN_BLOCKS_PINNABLE; + other.counterparty_spendable_height > cur_height + COUNTERPARTY_CLAIMABLE_WITHIN_BLOCKS_PINNABLE; if self_unpinnable && other_unpinnable { return true; } @@ -855,13 +862,6 @@ impl PackageTemplate { pub(crate) fn is_malleable(&self) -> bool { matches!(self.malleability, PackageMalleability::Malleable(..)) } - /// The height at which our counterparty may be able to spend this output. - /// - /// This is an important limit for aggregation as after this height our counterparty may be - /// able to pin transactions spending this output in the mempool. - pub(crate) fn counterparty_spendable_height(&self) -> u32 { - self.counterparty_spendable_height - } pub(crate) fn previous_feerate(&self) -> u64 { self.feerate_previous } @@ -1225,6 +1225,18 @@ impl Readable for PackageTemplate { (4, _height_original, option), // Written with a dummy value since 0.1 (6, height_timer, option), }); + for (_, input) in &inputs { + if let PackageSolvingData::RevokedHTLCOutput(RevokedHTLCOutput { htlc, .. }) = input { + // LDK versions through 0.1 set the wrong counterparty_spendable_height for + // non-offered revoked HTLCs (ie HTLCs we sent to our counterparty which they can + // claim with a preimage immediately). Here we detect this and reset the value to + // zero, as the value is unused except for merging decisions which doesn't care + // about any values below the current height. + if !htlc.offered && htlc.cltv_expiry == counterparty_spendable_height { + counterparty_spendable_height = 0; + } + } + } Ok(PackageTemplate { inputs, malleability, diff --git a/lightning/src/ln/channelmanager.rs b/lightning/src/ln/channelmanager.rs index 6c8f3139c6a..9ede93c93d1 100644 --- a/lightning/src/ln/channelmanager.rs +++ b/lightning/src/ln/channelmanager.rs @@ -13364,8 +13364,8 @@ where // claim. 
// Note that a `ChannelMonitor` is created with `update_id` 0 and after we // provide it with a closure update its `update_id` will be at 1. - if !monitor.offchain_closed() || monitor.get_latest_update_id() > 1 { - should_queue_fc_update = !monitor.offchain_closed(); + if !monitor.no_further_updates_allowed() || monitor.get_latest_update_id() > 1 { + should_queue_fc_update = !monitor.no_further_updates_allowed(); let mut latest_update_id = monitor.get_latest_update_id(); if should_queue_fc_update { latest_update_id += 1; diff --git a/lightning/src/ln/functional_tests.rs b/lightning/src/ln/functional_tests.rs index b29ee99e077..bdb1621771f 100644 --- a/lightning/src/ln/functional_tests.rs +++ b/lightning/src/ln/functional_tests.rs @@ -15,7 +15,7 @@ use crate::chain; use crate::chain::{ChannelMonitorUpdateStatus, Confirm, Listen, Watch}; use crate::chain::chaininterface::LowerBoundedFeeEstimator; use crate::chain::channelmonitor; -use crate::chain::channelmonitor::{Balance, ChannelMonitorUpdateStep, CLTV_CLAIM_BUFFER, LATENCY_GRACE_PERIOD_BLOCKS, ANTI_REORG_DELAY}; +use crate::chain::channelmonitor::{Balance, ChannelMonitorUpdateStep, CLTV_CLAIM_BUFFER, LATENCY_GRACE_PERIOD_BLOCKS, ANTI_REORG_DELAY, COUNTERPARTY_CLAIMABLE_WITHIN_BLOCKS_PINNABLE}; use crate::chain::transaction::OutPoint; use crate::sign::{ecdsa::EcdsaChannelSigner, EntropySource, OutputSpender, SignerProvider}; use crate::events::bump_transaction::WalletSource; @@ -2645,14 +2645,12 @@ fn test_justice_tx_htlc_timeout() { mine_transaction(&nodes[1], &revoked_local_txn[0]); { let mut node_txn = nodes[1].tx_broadcaster.txn_broadcasted.lock().unwrap(); - // The unpinnable, revoked to_self output, and the pinnable, revoked htlc output will - // be claimed in separate transactions. - assert_eq!(node_txn.len(), 2); - for tx in node_txn.iter() { - assert_eq!(tx.input.len(), 1); - check_spends!(tx, revoked_local_txn[0]); - } - assert_ne!(node_txn[0].input[0].previous_output, node_txn[1].input[0].previous_output); + // The revoked HTLC output is not pinnable for another `TEST_FINAL_CLTV` blocks, and is + // thus claimed in the same transaction with the revoked to_self output. + assert_eq!(node_txn.len(), 1); + assert_eq!(node_txn[0].input.len(), 2); + check_spends!(node_txn[0], revoked_local_txn[0]); + assert_ne!(node_txn[0].input[0].previous_output, node_txn[0].input[1].previous_output); node_txn.clear(); } check_added_monitors!(nodes[1], 1); @@ -2872,28 +2870,26 @@ fn claim_htlc_outputs() { assert!(nodes[1].node.get_and_clear_pending_events().is_empty()); let node_txn = nodes[1].tx_broadcaster.txn_broadcasted.lock().unwrap().split_off(0); - assert_eq!(node_txn.len(), 2); // Two penalty transactions: - assert_eq!(node_txn[0].input.len(), 1); // Claims the unpinnable, revoked output. - assert_eq!(node_txn[1].input.len(), 2); // Claims both pinnable, revoked HTLC outputs separately. - check_spends!(node_txn[0], revoked_local_txn[0]); - check_spends!(node_txn[1], revoked_local_txn[0]); - assert_ne!(node_txn[0].input[0].previous_output, node_txn[1].input[0].previous_output); - assert_ne!(node_txn[0].input[0].previous_output, node_txn[1].input[1].previous_output); - assert_ne!(node_txn[1].input[0].previous_output, node_txn[1].input[1].previous_output); + assert_eq!(node_txn.len(), 2); // ChannelMonitor: penalty txn + + // The ChannelMonitor should claim the accepted HTLC output separately from the offered + // HTLC and to_self outputs. 
+ let accepted_claim = node_txn.iter().filter(|tx| tx.input.len() == 1).next().unwrap(); + let offered_to_self_claim = node_txn.iter().filter(|tx| tx.input.len() == 2).next().unwrap(); + check_spends!(accepted_claim, revoked_local_txn[0]); + check_spends!(offered_to_self_claim, revoked_local_txn[0]); + assert_eq!(accepted_claim.input[0].witness.last().unwrap().len(), ACCEPTED_HTLC_SCRIPT_WEIGHT); let mut witness_lens = BTreeSet::new(); - witness_lens.insert(node_txn[0].input[0].witness.last().unwrap().len()); - witness_lens.insert(node_txn[1].input[0].witness.last().unwrap().len()); - witness_lens.insert(node_txn[1].input[1].witness.last().unwrap().len()); - assert_eq!(witness_lens.len(), 3); + witness_lens.insert(offered_to_self_claim.input[0].witness.last().unwrap().len()); + witness_lens.insert(offered_to_self_claim.input[1].witness.last().unwrap().len()); + assert_eq!(witness_lens.len(), 2); assert_eq!(*witness_lens.iter().skip(0).next().unwrap(), 77); // revoked to_local - assert_eq!(*witness_lens.iter().skip(1).next().unwrap(), OFFERED_HTLC_SCRIPT_WEIGHT); // revoked offered HTLC - assert_eq!(*witness_lens.iter().skip(2).next().unwrap(), ACCEPTED_HTLC_SCRIPT_WEIGHT); // revoked received HTLC + assert_eq!(*witness_lens.iter().skip(1).next().unwrap(), OFFERED_HTLC_SCRIPT_WEIGHT); - // Finally, mine the penalty transactions and check that we get an HTLC failure after + // Finally, mine the penalty transaction and check that we get an HTLC failure after // ANTI_REORG_DELAY confirmations. - mine_transaction(&nodes[1], &node_txn[0]); - mine_transaction(&nodes[1], &node_txn[1]); + mine_transaction(&nodes[1], accepted_claim); connect_blocks(&nodes[1], ANTI_REORG_DELAY - 1); expect_payment_failed!(nodes[1], payment_hash_2, false); } @@ -5056,8 +5052,7 @@ fn test_static_spendable_outputs_timeout_tx() { check_spends!(spend_txn[2], node_txn[0], commitment_tx[0]); // All outputs } -#[test] -fn test_static_spendable_outputs_justice_tx_revoked_commitment_tx() { +fn do_test_static_spendable_outputs_justice_tx_revoked_commitment_tx(split_tx: bool) { let chanmon_cfgs = create_chanmon_cfgs(2); let node_cfgs = create_node_cfgs(2, &chanmon_cfgs); let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]); @@ -5073,20 +5068,28 @@ fn test_static_spendable_outputs_justice_tx_revoked_commitment_tx() { claim_payment(&nodes[0], &vec!(&nodes[1])[..], payment_preimage); + if split_tx { + connect_blocks(&nodes[1], TEST_FINAL_CLTV - COUNTERPARTY_CLAIMABLE_WITHIN_BLOCKS_PINNABLE + 1); + } + mine_transaction(&nodes[1], &revoked_local_txn[0]); check_closed_broadcast!(nodes[1], true); check_added_monitors!(nodes[1], 1); check_closed_event!(nodes[1], 1, ClosureReason::CommitmentTxConfirmed, [nodes[0].node.get_our_node_id()], 100000); - // The unpinnable, revoked to_self output and the pinnable, revoked HTLC output will be claimed - // in separate transactions. + // If the HTLC expires in more than COUNTERPARTY_CLAIMABLE_WITHIN_BLOCKS_PINNABLE blocks, we'll + // claim both the revoked and HTLC outputs in one transaction, otherwise we'll split them as we + // consider the HTLC output as pinnable and want to claim pinnable and unpinnable outputs + // separately. 
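The split-vs-merge behaviour asserted below follows from the pinnability rule this diff introduces in `PackageTemplate`. A simplified, self-contained sketch of that rule (editor's illustration only; names and the `pinnable_window` parameter are stand-ins for LDK's internal types and the `COUNTERPARTY_CLAIMABLE_WITHIN_BLOCKS_PINNABLE` constant, whose real value is not restated here):

    // Simplified sketch: two claims are aggregated into one penalty transaction only
    // when both are pinnable or both are unpinnable at the current height.
    #[derive(Clone, Copy, PartialEq)]
    enum Cluster { Pinnable, Unpinnable }

    struct Claim { cluster: Cluster, counterparty_spendable_height: u32 }

    // A claim is treated as pinnable if it was classified that way, or if the
    // counterparty will be able to spend it within `pinnable_window` blocks.
    fn is_pinnable(c: &Claim, cur_height: u32, pinnable_window: u32) -> bool {
        c.cluster == Cluster::Pinnable
            || c.counterparty_spendable_height <= cur_height + pinnable_window
    }

    fn can_aggregate(a: &Claim, b: &Claim, cur_height: u32, pinnable_window: u32) -> bool {
        is_pinnable(a, cur_height, pinnable_window) == is_pinnable(b, cur_height, pinnable_window)
    }

    fn main() {
        // An accepted (preimage-claimable) HTLC is pinnable immediately, while an
        // offered HTLC expiring far in the future is not, so the two claims land in
        // separate transactions; once the expiry is close enough they merge.
        let accepted = Claim { cluster: Cluster::Pinnable, counterparty_spendable_height: 0 };
        let offered = Claim { cluster: Cluster::Unpinnable, counterparty_spendable_height: 1_000 };
        assert!(!can_aggregate(&accepted, &offered, 100, 144));
        assert!(can_aggregate(&accepted, &offered, 900, 144));
    }
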
let node_txn = nodes[1].tx_broadcaster.txn_broadcasted.lock().unwrap().clone(); - assert_eq!(node_txn.len(), 2); + assert_eq!(node_txn.len(), if split_tx { 2 } else { 1 }); for tx in node_txn.iter() { - assert_eq!(tx.input.len(), 1); + assert_eq!(tx.input.len(), if split_tx { 1 } else { 2 }); check_spends!(tx, revoked_local_txn[0]); } - assert_ne!(node_txn[0].input[0].previous_output, node_txn[1].input[0].previous_output); + if split_tx { + assert_ne!(node_txn[0].input[0].previous_output, node_txn[1].input[0].previous_output); + } mine_transaction(&nodes[1], &node_txn[0]); connect_blocks(&nodes[1], ANTI_REORG_DELAY - 1); @@ -5096,6 +5099,12 @@ fn test_static_spendable_outputs_justice_tx_revoked_commitment_tx() { check_spends!(spend_txn[0], node_txn[0]); } +#[test] +fn test_static_spendable_outputs_justice_tx_revoked_commitment_tx() { + do_test_static_spendable_outputs_justice_tx_revoked_commitment_tx(true); + do_test_static_spendable_outputs_justice_tx_revoked_commitment_tx(false); +} + #[test] fn test_static_spendable_outputs_justice_tx_revoked_htlc_timeout_tx() { let mut chanmon_cfgs = create_chanmon_cfgs(2); @@ -5128,6 +5137,10 @@ fn test_static_spendable_outputs_justice_tx_revoked_htlc_timeout_tx() { check_spends!(revoked_htlc_txn[0], revoked_local_txn[0]); assert_ne!(revoked_htlc_txn[0].lock_time, LockTime::ZERO); // HTLC-Timeout + // In order to connect `revoked_htlc_txn[0]` we must first advance the chain by + // `TEST_FINAL_CLTV` blocks as otherwise the transaction is consensus-invalid due to its + // locktime. + connect_blocks(&nodes[1], TEST_FINAL_CLTV); // B will generate justice tx from A's revoked commitment/HTLC tx connect_block(&nodes[1], &create_dummy_block(nodes[1].best_block_hash(), 42, vec![revoked_local_txn[0].clone(), revoked_htlc_txn[0].clone()])); check_closed_broadcast!(nodes[1], true); diff --git a/lightning/src/ln/monitor_tests.rs b/lightning/src/ln/monitor_tests.rs index e2c76643348..7a20a79159a 100644 --- a/lightning/src/ln/monitor_tests.rs +++ b/lightning/src/ln/monitor_tests.rs @@ -10,7 +10,7 @@ //! Further functional tests which test blockchain reorganizations. use crate::sign::{ecdsa::EcdsaChannelSigner, OutputSpender, SpendableOutputDescriptor}; -use crate::chain::channelmonitor::{ANTI_REORG_DELAY, ARCHIVAL_DELAY_BLOCKS,LATENCY_GRACE_PERIOD_BLOCKS, Balance, BalanceSource, ChannelMonitorUpdateStep}; +use crate::chain::channelmonitor::{ANTI_REORG_DELAY, ARCHIVAL_DELAY_BLOCKS,LATENCY_GRACE_PERIOD_BLOCKS, COUNTERPARTY_CLAIMABLE_WITHIN_BLOCKS_PINNABLE, Balance, BalanceSource, ChannelMonitorUpdateStep}; use crate::chain::transaction::OutPoint; use crate::chain::chaininterface::{ConfirmationTarget, LowerBoundedFeeEstimator, compute_feerate_sat_per_1000_weight}; use crate::events::bump_transaction::{BumpTransactionEvent, WalletSource}; @@ -1734,6 +1734,12 @@ fn do_test_revoked_counterparty_htlc_tx_balances(anchors: bool) { assert_eq!(revoked_htlc_success.lock_time, LockTime::ZERO); assert_ne!(revoked_htlc_timeout.lock_time, LockTime::ZERO); + // First connect blocks until the HTLC expires with + // `COUNTERPARTY_CLAIMABLE_WITHIN_BLOCKS_PINNABLE` blocks, making us consider all the HTLCs + // pinnable claims, which the remainder of the test assumes. 
+ connect_blocks(&nodes[0], TEST_FINAL_CLTV - COUNTERPARTY_CLAIMABLE_WITHIN_BLOCKS_PINNABLE); + expect_pending_htlcs_forwardable_and_htlc_handling_failed_ignore!(&nodes[0], + [HTLCDestination::FailedPayment { payment_hash: failed_payment_hash }]); // A will generate justice tx from B's revoked commitment/HTLC tx mine_transaction(&nodes[0], &revoked_local_txn[0]); check_closed_broadcast!(nodes[0], true); @@ -1846,8 +1852,6 @@ fn do_test_revoked_counterparty_htlc_tx_balances(anchors: bool) { sorted_vec(nodes[0].chain_monitor.chain_monitor.get_monitor(funding_outpoint).unwrap().get_claimable_balances())); connect_blocks(&nodes[0], revoked_htlc_timeout.lock_time.to_consensus_u32() - nodes[0].best_block_info().1); - expect_pending_htlcs_forwardable_and_htlc_handling_failed_ignore!(&nodes[0], - [HTLCDestination::FailedPayment { payment_hash: failed_payment_hash }]); // As time goes on A may split its revocation claim transaction into multiple. let as_fewer_input_rbf = nodes[0].tx_broadcaster.txn_broadcasted.lock().unwrap().split_off(0); for tx in as_fewer_input_rbf.iter() { diff --git a/lightning/src/ln/msgs.rs b/lightning/src/ln/msgs.rs index 659ec65f6cf..0c400fc36e0 100644 --- a/lightning/src/ln/msgs.rs +++ b/lightning/src/ln/msgs.rs @@ -1578,6 +1578,8 @@ pub trait ChannelMessageHandler : MessageSendEventsProvider { /// May return an `Err(())` if the features the peer supports are not sufficient to communicate /// with us. Implementors should be somewhat conservative about doing so, however, as other /// message handlers may still wish to communicate with this peer. + /// + /// [`Self::peer_disconnected`] will not be called if `Err(())` is returned. fn peer_connected(&self, their_node_id: PublicKey, msg: &Init, inbound: bool) -> Result<(), ()>; /// Handle an incoming `channel_reestablish` message from the given peer. fn handle_channel_reestablish(&self, their_node_id: PublicKey, msg: &ChannelReestablish); @@ -1707,6 +1709,8 @@ pub trait OnionMessageHandler { /// May return an `Err(())` if the features the peer supports are not sufficient to communicate /// with us. Implementors should be somewhat conservative about doing so, however, as other /// message handlers may still wish to communicate with this peer. + /// + /// [`Self::peer_disconnected`] will not be called if `Err(())` is returned. fn peer_connected(&self, their_node_id: PublicKey, init: &Init, inbound: bool) -> Result<(), ()>; /// Indicates a connection to the peer failed/an existing connection was lost. Allows handlers to diff --git a/lightning/src/ln/peer_handler.rs b/lightning/src/ln/peer_handler.rs index dbce9ca0498..9b649408423 100644 --- a/lightning/src/ln/peer_handler.rs +++ b/lightning/src/ln/peer_handler.rs @@ -88,6 +88,8 @@ pub trait CustomMessageHandler: wire::CustomMessageReader { /// May return an `Err(())` if the features the peer supports are not sufficient to communicate /// with us. Implementors should be somewhat conservative about doing so, however, as other /// message handlers may still wish to communicate with this peer. + /// + /// [`Self::peer_disconnected`] will not be called if `Err(())` is returned. fn peer_connected(&self, their_node_id: PublicKey, msg: &Init, inbound: bool) -> Result<(), ()>; /// Gets the node feature flags which this handler itself supports. All available handlers are @@ -1717,10 +1719,13 @@ impl= 253. 
- $self.bytes.len() - + $self.experimental_bytes.len() - + if $self.contents.is_for_offer() { 0 } else { 2 }, - $self.bytes.capacity(), - ); $self.bytes.extend_from_slice(&$self.experimental_bytes); Ok(Bolt12Invoice { @@ -965,13 +939,6 @@ impl Hash for Bolt12Invoice { } impl InvoiceContents { - fn is_for_offer(&self) -> bool { - match self { - InvoiceContents::ForOffer { .. } => true, - InvoiceContents::ForRefund { .. } => false, - } - } - /// Whether the original offer or refund has expired. #[cfg(feature = "std")] fn is_offer_or_refund_expired(&self) -> bool { @@ -1362,7 +1329,11 @@ pub(super) const EXPERIMENTAL_INVOICE_TYPES: core::ops::RangeFrom = 3_000_0 #[cfg(not(test))] tlv_stream!( - ExperimentalInvoiceTlvStream, ExperimentalInvoiceTlvStreamRef, EXPERIMENTAL_INVOICE_TYPES, {} + ExperimentalInvoiceTlvStream, ExperimentalInvoiceTlvStreamRef, EXPERIMENTAL_INVOICE_TYPES, { + // When adding experimental TLVs, update EXPERIMENTAL_TLV_ALLOCATION_SIZE accordingly in + // both UnsignedBolt12Invoice:new and UnsignedStaticInvoice::new to avoid unnecessary + // allocations. + } ); #[cfg(test)] @@ -2880,9 +2851,6 @@ mod tests { BigSize(32).write(&mut unknown_bytes).unwrap(); [42u8; 32].write(&mut unknown_bytes).unwrap(); - unsigned_invoice.bytes.reserve_exact( - unsigned_invoice.bytes.capacity() - unsigned_invoice.bytes.len() + unknown_bytes.len(), - ); unsigned_invoice.bytes.extend_from_slice(&unknown_bytes); unsigned_invoice.tagged_hash = TaggedHash::from_valid_tlv_stream_bytes(SIGNATURE_TAG, &unsigned_invoice.bytes); @@ -2917,9 +2885,6 @@ mod tests { BigSize(32).write(&mut unknown_bytes).unwrap(); [42u8; 32].write(&mut unknown_bytes).unwrap(); - unsigned_invoice.bytes.reserve_exact( - unsigned_invoice.bytes.capacity() - unsigned_invoice.bytes.len() + unknown_bytes.len(), - ); unsigned_invoice.bytes.extend_from_slice(&unknown_bytes); unsigned_invoice.tagged_hash = TaggedHash::from_valid_tlv_stream_bytes(SIGNATURE_TAG, &unsigned_invoice.bytes); @@ -2982,9 +2947,6 @@ mod tests { BigSize(32).write(&mut unknown_bytes).unwrap(); [42u8; 32].write(&mut unknown_bytes).unwrap(); - unsigned_invoice.bytes.reserve_exact( - unsigned_invoice.bytes.capacity() - unsigned_invoice.bytes.len() + unknown_bytes.len(), - ); unsigned_invoice.experimental_bytes.extend_from_slice(&unknown_bytes); let tlv_stream = TlvStream::new(&unsigned_invoice.bytes) @@ -3021,9 +2983,6 @@ mod tests { BigSize(32).write(&mut unknown_bytes).unwrap(); [42u8; 32].write(&mut unknown_bytes).unwrap(); - unsigned_invoice.bytes.reserve_exact( - unsigned_invoice.bytes.capacity() - unsigned_invoice.bytes.len() + unknown_bytes.len(), - ); unsigned_invoice.experimental_bytes.extend_from_slice(&unknown_bytes); let tlv_stream = TlvStream::new(&unsigned_invoice.bytes) diff --git a/lightning/src/offers/invoice_request.rs b/lightning/src/offers/invoice_request.rs index 957884f69d0..8d7d25cf2b5 100644 --- a/lightning/src/offers/invoice_request.rs +++ b/lightning/src/offers/invoice_request.rs @@ -77,7 +77,7 @@ use crate::ln::channelmanager::PaymentId; use crate::types::features::InvoiceRequestFeatures; use crate::ln::inbound_payment::{ExpandedKey, IV_LEN}; use crate::ln::msgs::DecodeError; -use crate::offers::merkle::{SignError, SignFn, SignatureTlvStream, SignatureTlvStreamRef, TaggedHash, TlvStream, self, SIGNATURE_TLV_RECORD_SIZE}; +use crate::offers::merkle::{SignError, SignFn, SignatureTlvStream, SignatureTlvStreamRef, TaggedHash, TlvStream, self}; use crate::offers::nonce::Nonce; use crate::offers::offer::{Amount, EXPERIMENTAL_OFFER_TYPES, 
ExperimentalOfferTlvStream, ExperimentalOfferTlvStreamRef, OFFER_TYPES, Offer, OfferContents, OfferId, OfferTlvStream, OfferTlvStreamRef}; use crate::offers::parse::{Bolt12ParseError, ParsedMessage, Bolt12SemanticError}; @@ -473,17 +473,8 @@ impl UnsignedInvoiceRequest { _experimental_offer_tlv_stream, experimental_invoice_request_tlv_stream, ) = contents.as_tlv_stream(); - // Allocate enough space for the invoice_request, which will include: - // - all TLV records from `offer.bytes`, - // - all invoice_request-specific TLV records, and - // - a signature TLV record once the invoice_request is signed. - let mut bytes = Vec::with_capacity( - offer.bytes.len() - + payer_tlv_stream.serialized_length() - + invoice_request_tlv_stream.serialized_length() - + SIGNATURE_TLV_RECORD_SIZE - + experimental_invoice_request_tlv_stream.serialized_length(), - ); + const INVOICE_REQUEST_ALLOCATION_SIZE: usize = 512; + let mut bytes = Vec::with_capacity(INVOICE_REQUEST_ALLOCATION_SIZE); payer_tlv_stream.write(&mut bytes).unwrap(); @@ -495,23 +486,16 @@ impl UnsignedInvoiceRequest { invoice_request_tlv_stream.write(&mut bytes).unwrap(); - let mut experimental_tlv_stream = TlvStream::new(remaining_bytes) - .range(EXPERIMENTAL_OFFER_TYPES) - .peekable(); - let mut experimental_bytes = Vec::with_capacity( - remaining_bytes.len() - - experimental_tlv_stream - .peek() - .map_or(remaining_bytes.len(), |first_record| first_record.start) - + experimental_invoice_request_tlv_stream.serialized_length(), - ); + const EXPERIMENTAL_TLV_ALLOCATION_SIZE: usize = 0; + let mut experimental_bytes = Vec::with_capacity(EXPERIMENTAL_TLV_ALLOCATION_SIZE); + let experimental_tlv_stream = TlvStream::new(remaining_bytes) + .range(EXPERIMENTAL_OFFER_TYPES); for record in experimental_tlv_stream { record.write(&mut experimental_bytes).unwrap(); } experimental_invoice_request_tlv_stream.write(&mut experimental_bytes).unwrap(); - debug_assert_eq!(experimental_bytes.len(), experimental_bytes.capacity()); let tlv_stream = TlvStream::new(&bytes).chain(TlvStream::new(&experimental_bytes)); let tagged_hash = TaggedHash::from_tlv_stream(SIGNATURE_TAG, tlv_stream); @@ -544,12 +528,6 @@ macro_rules! unsigned_invoice_request_sign_method { ( signature_tlv_stream.write(&mut $self.bytes).unwrap(); // Append the experimental bytes after the signature. - debug_assert_eq!( - // The two-byte overallocation results from SIGNATURE_TLV_RECORD_SIZE accommodating TLV - // records with types >= 253. - $self.bytes.len() + $self.experimental_bytes.len() + 2, - $self.bytes.capacity(), - ); $self.bytes.extend_from_slice(&$self.experimental_bytes); Ok(InvoiceRequest { @@ -1127,7 +1105,10 @@ pub(super) const EXPERIMENTAL_INVOICE_REQUEST_TYPES: core::ops::Range = #[cfg(not(test))] tlv_stream!( ExperimentalInvoiceRequestTlvStream, ExperimentalInvoiceRequestTlvStreamRef, - EXPERIMENTAL_INVOICE_REQUEST_TYPES, {} + EXPERIMENTAL_INVOICE_REQUEST_TYPES, { + // When adding experimental TLVs, update EXPERIMENTAL_TLV_ALLOCATION_SIZE accordingly in + // UnsignedInvoiceRequest::new to avoid unnecessary allocations. 
+ } ); #[cfg(test)] @@ -2422,11 +2403,6 @@ mod tests { BigSize(32).write(&mut unknown_bytes).unwrap(); [42u8; 32].write(&mut unknown_bytes).unwrap(); - unsigned_invoice_request.bytes.reserve_exact( - unsigned_invoice_request.bytes.capacity() - - unsigned_invoice_request.bytes.len() - + unknown_bytes.len(), - ); unsigned_invoice_request.bytes.extend_from_slice(&unknown_bytes); unsigned_invoice_request.tagged_hash = TaggedHash::from_valid_tlv_stream_bytes(SIGNATURE_TAG, &unsigned_invoice_request.bytes); @@ -2460,11 +2436,6 @@ mod tests { BigSize(32).write(&mut unknown_bytes).unwrap(); [42u8; 32].write(&mut unknown_bytes).unwrap(); - unsigned_invoice_request.bytes.reserve_exact( - unsigned_invoice_request.bytes.capacity() - - unsigned_invoice_request.bytes.len() - + unknown_bytes.len(), - ); unsigned_invoice_request.bytes.extend_from_slice(&unknown_bytes); unsigned_invoice_request.tagged_hash = TaggedHash::from_valid_tlv_stream_bytes(SIGNATURE_TAG, &unsigned_invoice_request.bytes); @@ -2508,11 +2479,6 @@ mod tests { BigSize(32).write(&mut unknown_bytes).unwrap(); [42u8; 32].write(&mut unknown_bytes).unwrap(); - unsigned_invoice_request.bytes.reserve_exact( - unsigned_invoice_request.bytes.capacity() - - unsigned_invoice_request.bytes.len() - + unknown_bytes.len(), - ); unsigned_invoice_request.experimental_bytes.extend_from_slice(&unknown_bytes); let tlv_stream = TlvStream::new(&unsigned_invoice_request.bytes) @@ -2549,11 +2515,6 @@ mod tests { BigSize(32).write(&mut unknown_bytes).unwrap(); [42u8; 32].write(&mut unknown_bytes).unwrap(); - unsigned_invoice_request.bytes.reserve_exact( - unsigned_invoice_request.bytes.capacity() - - unsigned_invoice_request.bytes.len() - + unknown_bytes.len(), - ); unsigned_invoice_request.experimental_bytes.extend_from_slice(&unknown_bytes); let tlv_stream = TlvStream::new(&unsigned_invoice_request.bytes) diff --git a/lightning/src/offers/merkle.rs b/lightning/src/offers/merkle.rs index 8c3eaaed24d..db4f4a41094 100644 --- a/lightning/src/offers/merkle.rs +++ b/lightning/src/offers/merkle.rs @@ -11,7 +11,6 @@ use bitcoin::hashes::{Hash, HashEngine, sha256}; use bitcoin::secp256k1::{Message, PublicKey, Secp256k1, self}; -use bitcoin::secp256k1::constants::SCHNORR_SIGNATURE_SIZE; use bitcoin::secp256k1::schnorr::Signature; use crate::io; use crate::util::ser::{BigSize, Readable, Writeable, Writer}; @@ -26,10 +25,6 @@ tlv_stream!(SignatureTlvStream, SignatureTlvStreamRef<'a>, SIGNATURE_TYPES, { (240, signature: Signature), }); -/// Size of a TLV record in `SIGNATURE_TYPES` when the type is 1000. TLV types are encoded using -/// BigSize, so a TLV record with type 240 will use two less bytes. -pub(super) const SIGNATURE_TLV_RECORD_SIZE: usize = 3 + 1 + SCHNORR_SIGNATURE_SIZE; - /// A hash for use in a specific context by tweaking with a context-dependent tag as per [BIP 340] /// and computed over the merkle root of a TLV stream to sign as defined in [BOLT 12]. /// @@ -253,7 +248,6 @@ pub(super) struct TlvRecord<'a> { type_bytes: &'a [u8], // The entire TLV record. 
pub(super) record_bytes: &'a [u8], - pub(super) start: usize, pub(super) end: usize, } @@ -278,7 +272,7 @@ impl<'a> Iterator for TlvStream<'a> { self.data.set_position(end); Some(TlvRecord { - r#type, type_bytes, record_bytes, start: start as usize, end: end as usize, + r#type, type_bytes, record_bytes, end: end as usize, }) } else { None diff --git a/lightning/src/offers/offer.rs b/lightning/src/offers/offer.rs index 613f9accd47..0cd1cece6b7 100644 --- a/lightning/src/offers/offer.rs +++ b/lightning/src/offers/offer.rs @@ -438,7 +438,8 @@ macro_rules! offer_builder_methods { ( } } - let mut bytes = Vec::new(); + const OFFER_ALLOCATION_SIZE: usize = 512; + let mut bytes = Vec::with_capacity(OFFER_ALLOCATION_SIZE); $self.offer.write(&mut bytes).unwrap(); let id = OfferId::from_valid_offer_tlv_stream(&bytes); diff --git a/lightning/src/offers/refund.rs b/lightning/src/offers/refund.rs index a68d0eb658e..e562fb3f901 100644 --- a/lightning/src/offers/refund.rs +++ b/lightning/src/offers/refund.rs @@ -338,7 +338,8 @@ macro_rules! refund_builder_methods { ( $self.refund.payer.0 = metadata; } - let mut bytes = Vec::new(); + const REFUND_ALLOCATION_SIZE: usize = 512; + let mut bytes = Vec::with_capacity(REFUND_ALLOCATION_SIZE); $self.refund.write(&mut bytes).unwrap(); Ok(Refund { diff --git a/lightning/src/offers/static_invoice.rs b/lightning/src/offers/static_invoice.rs index 411ba3ff272..4360582a14c 100644 --- a/lightning/src/offers/static_invoice.rs +++ b/lightning/src/offers/static_invoice.rs @@ -25,7 +25,6 @@ use crate::offers::invoice_macros::{invoice_accessors_common, invoice_builder_me use crate::offers::invoice_request::InvoiceRequest; use crate::offers::merkle::{ self, SignError, SignFn, SignatureTlvStream, SignatureTlvStreamRef, TaggedHash, TlvStream, - SIGNATURE_TLV_RECORD_SIZE, }; use crate::offers::nonce::Nonce; use crate::offers::offer::{ @@ -288,16 +287,8 @@ impl UnsignedStaticInvoice { fn new(offer_bytes: &Vec, contents: InvoiceContents) -> Self { let (_, invoice_tlv_stream, _, experimental_invoice_tlv_stream) = contents.as_tlv_stream(); - // Allocate enough space for the invoice, which will include: - // - all TLV records from `offer_bytes`, - // - all invoice-specific TLV records, and - // - a signature TLV record once the invoice is signed. - let mut bytes = Vec::with_capacity( - offer_bytes.len() - + invoice_tlv_stream.serialized_length() - + SIGNATURE_TLV_RECORD_SIZE - + experimental_invoice_tlv_stream.serialized_length(), - ); + const INVOICE_ALLOCATION_SIZE: usize = 1024; + let mut bytes = Vec::with_capacity(INVOICE_ALLOCATION_SIZE); // Use the offer bytes instead of the offer TLV stream as the latter may have contained // unknown TLV records, which are not stored in `InvoiceContents`. 
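The fixed allocation estimates above (such as `INVOICE_ALLOCATION_SIZE`) are safe to use in place of the previously computed exact capacities because `Vec::with_capacity` is only a reservation hint. A minimal sketch of that property (the constant value simply mirrors the one introduced above):

    fn main() {
        const INVOICE_ALLOCATION_SIZE: usize = 1024; // estimate, not a hard limit
        let mut bytes: Vec<u8> = Vec::with_capacity(INVOICE_ALLOCATION_SIZE);
        assert!(bytes.capacity() >= INVOICE_ALLOCATION_SIZE);

        // Writing more than the estimate simply triggers a reallocation; nothing is
        // truncated, so over- or under-estimating only affects allocator traffic.
        bytes.extend_from_slice(&[0u8; 2048]);
        assert_eq!(bytes.len(), 2048);
        assert!(bytes.capacity() >= 2048);
    }
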
@@ -309,22 +300,16 @@ impl UnsignedStaticInvoice { invoice_tlv_stream.write(&mut bytes).unwrap(); - let mut experimental_tlv_stream = - TlvStream::new(remaining_bytes).range(EXPERIMENTAL_OFFER_TYPES).peekable(); - let mut experimental_bytes = Vec::with_capacity( - remaining_bytes.len() - - experimental_tlv_stream - .peek() - .map_or(remaining_bytes.len(), |first_record| first_record.start) - + experimental_invoice_tlv_stream.serialized_length(), - ); + const EXPERIMENTAL_TLV_ALLOCATION_SIZE: usize = 0; + let mut experimental_bytes = Vec::with_capacity(EXPERIMENTAL_TLV_ALLOCATION_SIZE); + let experimental_tlv_stream = + TlvStream::new(remaining_bytes).range(EXPERIMENTAL_OFFER_TYPES); for record in experimental_tlv_stream { record.write(&mut experimental_bytes).unwrap(); } experimental_invoice_tlv_stream.write(&mut experimental_bytes).unwrap(); - debug_assert_eq!(experimental_bytes.len(), experimental_bytes.capacity()); let tlv_stream = TlvStream::new(&bytes).chain(TlvStream::new(&experimental_bytes)); let tagged_hash = TaggedHash::from_tlv_stream(SIGNATURE_TAG, tlv_stream); @@ -344,12 +329,6 @@ impl UnsignedStaticInvoice { signature_tlv_stream.write(&mut self.bytes).unwrap(); // Append the experimental bytes after the signature. - debug_assert_eq!( - // The two-byte overallocation results from SIGNATURE_TLV_RECORD_SIZE accommodating TLV - // records with types >= 253. - self.bytes.len() + self.experimental_bytes.len() + 2, - self.bytes.capacity(), - ); self.bytes.extend_from_slice(&self.experimental_bytes); Ok(StaticInvoice { bytes: self.bytes, contents: self.contents, signature }) @@ -1392,9 +1371,6 @@ mod tests { BigSize(32).write(&mut unknown_bytes).unwrap(); [42u8; 32].write(&mut unknown_bytes).unwrap(); - unsigned_invoice.bytes.reserve_exact( - unsigned_invoice.bytes.capacity() - unsigned_invoice.bytes.len() + unknown_bytes.len(), - ); unsigned_invoice.bytes.extend_from_slice(&unknown_bytes); unsigned_invoice.tagged_hash = TaggedHash::from_valid_tlv_stream_bytes(SIGNATURE_TAG, &unsigned_invoice.bytes); @@ -1434,9 +1410,6 @@ mod tests { BigSize(32).write(&mut unknown_bytes).unwrap(); [42u8; 32].write(&mut unknown_bytes).unwrap(); - unsigned_invoice.bytes.reserve_exact( - unsigned_invoice.bytes.capacity() - unsigned_invoice.bytes.len() + unknown_bytes.len(), - ); unsigned_invoice.bytes.extend_from_slice(&unknown_bytes); unsigned_invoice.tagged_hash = TaggedHash::from_valid_tlv_stream_bytes(SIGNATURE_TAG, &unsigned_invoice.bytes); @@ -1511,9 +1484,6 @@ mod tests { BigSize(32).write(&mut unknown_bytes).unwrap(); [42u8; 32].write(&mut unknown_bytes).unwrap(); - unsigned_invoice.bytes.reserve_exact( - unsigned_invoice.bytes.capacity() - unsigned_invoice.bytes.len() + unknown_bytes.len(), - ); unsigned_invoice.experimental_bytes.extend_from_slice(&unknown_bytes); let tlv_stream = TlvStream::new(&unsigned_invoice.bytes) @@ -1555,9 +1525,6 @@ mod tests { BigSize(32).write(&mut unknown_bytes).unwrap(); [42u8; 32].write(&mut unknown_bytes).unwrap(); - unsigned_invoice.bytes.reserve_exact( - unsigned_invoice.bytes.capacity() - unsigned_invoice.bytes.len() + unknown_bytes.len(), - ); unsigned_invoice.experimental_bytes.extend_from_slice(&unknown_bytes); let tlv_stream = TlvStream::new(&unsigned_invoice.bytes) diff --git a/lightning/src/routing/gossip.rs b/lightning/src/routing/gossip.rs index c552005d9ca..fc097d5f915 100644 --- a/lightning/src/routing/gossip.rs +++ b/lightning/src/routing/gossip.rs @@ -1660,8 +1660,7 @@ where let chain_hash: ChainHash = Readable::read(reader)?; let channels_count: 
u64 = Readable::read(reader)?; - // In Nov, 2023 there were about 15,000 nodes; we cap allocations to 1.5x that. - let mut channels = IndexedMap::with_capacity(cmp::min(channels_count as usize, 22500)); + let mut channels = IndexedMap::with_capacity(CHAN_COUNT_ESTIMATE); for _ in 0..channels_count { let chan_id: u64 = Readable::read(reader)?; let chan_info: ChannelInfo = Readable::read(reader)?; @@ -1673,8 +1672,7 @@ where if nodes_count > u32::max_value() as u64 / 2 { return Err(DecodeError::InvalidValue); } - // In Nov, 2023 there were about 69K channels; we cap allocations to 1.5x that. - let mut nodes = IndexedMap::with_capacity(cmp::min(nodes_count as usize, 103500)); + let mut nodes = IndexedMap::with_capacity(NODE_COUNT_ESTIMATE); for i in 0..nodes_count { let node_id = Readable::read(reader)?; let mut node_info: NodeInfo = Readable::read(reader)?; @@ -1750,6 +1748,15 @@ where } } +// In Jan, 2025 there were about 49K channels. +// We over-allocate by a bit because 20% more is better than the double we get if we're slightly +// too low +const CHAN_COUNT_ESTIMATE: usize = 60_000; +// In Jan, 2025 there were about 15K nodes +// We over-allocate by a bit because 33% more is better than the double we get if we're slightly +// too low +const NODE_COUNT_ESTIMATE: usize = 20_000; + impl NetworkGraph where L::Target: Logger, @@ -1760,8 +1767,8 @@ where secp_ctx: Secp256k1::verification_only(), chain_hash: ChainHash::using_genesis_block(network), logger, - channels: RwLock::new(IndexedMap::new()), - nodes: RwLock::new(IndexedMap::new()), + channels: RwLock::new(IndexedMap::with_capacity(CHAN_COUNT_ESTIMATE)), + nodes: RwLock::new(IndexedMap::with_capacity(NODE_COUNT_ESTIMATE)), next_node_counter: AtomicUsize::new(0), removed_node_counters: Mutex::new(Vec::new()), last_rapid_gossip_sync_timestamp: Mutex::new(None), @@ -2537,7 +2544,7 @@ where } }; - let node_pubkey; + let mut node_pubkey = None; { let channels = self.channels.read().unwrap(); match channels.get(&msg.short_channel_id) { @@ -2556,16 +2563,31 @@ where } else { channel.node_one.as_slice() }; - node_pubkey = PublicKey::from_slice(node_id).map_err(|_| LightningError { - err: "Couldn't parse source node pubkey".to_owned(), - action: ErrorAction::IgnoreAndLog(Level::Debug), - })?; + if sig.is_some() { + // PublicKey parsing isn't entirely trivial as it requires that we check + // that the provided point is on the curve. Thus, if we don't have a + // signature to verify, we want to skip the parsing step entirely. + // This represents a substantial speedup in applying RGS snapshots. 
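A minimal sketch of the deferral pattern the comment above describes (the helper name is illustrative; the real change keeps this inline where `channel_update` messages are verified):

    use bitcoin::secp256k1::PublicKey;

    /// Parse the source node's pubkey only when there is a signature to verify,
    /// since `PublicKey::from_slice` must check that the point is on the curve.
    fn pubkey_if_needed(node_id: &[u8], have_sig: bool) -> Result<Option<PublicKey>, ()> {
        if !have_sig {
            return Ok(None);
        }
        PublicKey::from_slice(node_id).map(Some).map_err(|_| ())
    }
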
+ node_pubkey = + Some(PublicKey::from_slice(node_id).map_err(|_| LightningError { + err: "Couldn't parse source node pubkey".to_owned(), + action: ErrorAction::IgnoreAndLog(Level::Debug), + })?); + } }, } } - let msg_hash = hash_to_message!(&message_sha256d_hash(&msg)[..]); if let Some(sig) = sig { + let msg_hash = hash_to_message!(&message_sha256d_hash(&msg)[..]); + let node_pubkey = if let Some(pubkey) = node_pubkey { + pubkey + } else { + debug_assert!(false, "node_pubkey should have been decoded above"); + let err = "node_pubkey wasn't decoded but we need it to check a sig".to_owned(); + let action = ErrorAction::IgnoreAndLog(Level::Error); + return Err(LightningError { err, action }); + }; secp_verify_sig!(self.secp_ctx, &msg_hash, &sig, &node_pubkey, "channel_update"); } diff --git a/lightning/src/routing/router.rs b/lightning/src/routing/router.rs index 04f55837267..079b83563c3 100644 --- a/lightning/src/routing/router.rs +++ b/lightning/src/routing/router.rs @@ -1164,6 +1164,7 @@ impl_writeable_tlv_based!(RouteHintHop, { #[repr(align(64))] // Force the size to 64 bytes struct RouteGraphNode { node_id: NodeId, + node_counter: u32, score: u64, // The maximum value a yet-to-be-constructed payment path might flow through this node. // This value is upper-bounded by us by: @@ -1178,7 +1179,10 @@ struct RouteGraphNode { impl cmp::Ord for RouteGraphNode { fn cmp(&self, other: &RouteGraphNode) -> cmp::Ordering { - other.score.cmp(&self.score).then_with(|| other.node_id.cmp(&self.node_id)) + other.score.cmp(&self.score) + .then_with(|| self.value_contribution_msat.cmp(&other.value_contribution_msat)) + .then_with(|| other.path_length_to_node.cmp(&self.path_length_to_node)) + .then_with(|| other.node_counter.cmp(&self.node_counter)) } } @@ -1780,6 +1784,12 @@ struct PathBuildingHop<'a> { /// decrease as well. Thus, we have to explicitly track which nodes have been processed and /// avoid processing them again. was_processed: bool, + /// If we've already processed a channel backwards from a target node, we shouldn't update our + /// selected best path from that node to the destination. This should never happen, but with + /// multiple codepaths processing channels we've had issues here in the past, so in debug-mode + /// we track it and assert on it when processing a node. + #[cfg(all(not(ldk_bench), any(test, fuzzing)))] + best_path_from_hop_selected: bool, /// When processing a node as the next best-score candidate, we want to quickly check if it is /// a direct counterparty of ours, using our local channel information immediately if we can. /// @@ -1788,6 +1798,8 @@ struct PathBuildingHop<'a> { /// updated after being initialized - it is set at the start of a route-finding pass and only /// read thereafter. is_first_hop_target: bool, + /// Identical to the above, but for handling unblinded last-hops rather than first-hops. + is_last_hop_target: bool, /// Used to compare channels when choosing the for routing. /// Includes paying for the use of a hop and the following hops, as well as /// an estimated cost of reaching this hop. @@ -1808,11 +1820,7 @@ struct PathBuildingHop<'a> { /// The value will be actually deducted from the counterparty balance on the previous link. 
hop_use_fee_msat: u64, - #[cfg(all(not(ldk_bench), any(test, fuzzing)))] - // In tests, we apply further sanity checks on cases where we skip nodes we already processed - // to ensure it is specifically in cases where the fee has gone down because of a decrease in - // value_contribution_msat, which requires tracking it here. See comments below where it is - // used for more info. + /// The quantity of funds we're willing to route over this channel value_contribution_msat: u64, } @@ -1827,15 +1835,14 @@ impl<'a> core::fmt::Debug for PathBuildingHop<'a> { .field("target_node_id", &self.candidate.target()) .field("short_channel_id", &self.candidate.short_channel_id()) .field("is_first_hop_target", &self.is_first_hop_target) + .field("is_last_hop_target", &self.is_last_hop_target) .field("total_fee_msat", &self.total_fee_msat) .field("next_hops_fee_msat", &self.next_hops_fee_msat) .field("hop_use_fee_msat", &self.hop_use_fee_msat) .field("total_fee_msat - (next_hops_fee_msat + hop_use_fee_msat)", &(&self.total_fee_msat.saturating_sub(self.next_hops_fee_msat).saturating_sub(self.hop_use_fee_msat))) .field("path_penalty_msat", &self.path_penalty_msat) .field("path_htlc_minimum_msat", &self.path_htlc_minimum_msat) - .field("cltv_expiry_delta", &self.candidate.cltv_expiry_delta()); - #[cfg(all(not(ldk_bench), any(test, fuzzing)))] - let debug_struct = debug_struct + .field("cltv_expiry_delta", &self.candidate.cltv_expiry_delta()) .field("value_contribution_msat", &self.value_contribution_msat); debug_struct.finish() } @@ -2266,8 +2273,10 @@ where L::Target: Logger { // Step (1). Prepare first and last hop targets. // - // First cache all our direct channels so that we can insert them in the heap at startup. - // Then process any blinded routes, resolving their introduction node and caching it. + // For unblinded first- and last-hop channels, cache them in maps so that we can detect them as + // we walk the graph and incorporate them into our candidate set. + // For blinded last-hop paths, look up their introduction point and cache the node counters + // identifying them. 
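A simplified sketch of the last-hop cache described above (type and function names are stand-ins; the real code below additionally resolves hints to public channels where possible and skips hints that duplicate our own first hops):

    use std::collections::HashMap;

    // Simplified stand-ins for a route-hint hop and the candidate cache.
    struct HintHop { short_channel_id: u64, target_node_counter: u32 }

    // Group unblinded last-hop hints by the counter of the node they pay into, so the
    // graph walk can pull them in as candidates when it reaches that node.
    fn last_hop_candidates_by_counter(hints: &[HintHop]) -> HashMap<u32, Vec<&HintHop>> {
        let mut map: HashMap<u32, Vec<&HintHop>> = HashMap::new();
        for hop in hints {
            map.entry(hop.target_node_counter).or_insert_with(Vec::new).push(hop);
        }
        map
    }
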
let mut first_hop_targets: HashMap<_, (Vec<&ChannelDetails>, u32)> = hash_map_with_capacity(if first_hops.is_some() { first_hops.as_ref().unwrap().len() } else { 0 }); if let Some(hops) = first_hops { @@ -2299,6 +2308,56 @@ where L::Target: Logger { &payment_params, &node_counters, network_graph, &logger, our_node_id, &first_hop_targets, )?; + let mut last_hop_candidates = + hash_map_with_capacity(payment_params.payee.unblinded_route_hints().len()); + for route in payment_params.payee.unblinded_route_hints().iter() + .filter(|route| !route.0.is_empty()) + { + let hop_iter = route.0.iter().rev(); + let prev_hop_iter = core::iter::once(&maybe_dummy_payee_pk).chain( + route.0.iter().skip(1).rev().map(|hop| &hop.src_node_id)); + + for (hop, prev_hop_id) in hop_iter.zip(prev_hop_iter) { + let (target, private_target_node_counter) = + node_counters.private_node_counter_from_pubkey(&prev_hop_id) + .ok_or_else(|| { + debug_assert!(false); + LightningError { err: "We should always have private target node counters available".to_owned(), action: ErrorAction::IgnoreError } + })?; + let (_src_id, private_source_node_counter) = + node_counters.private_node_counter_from_pubkey(&hop.src_node_id) + .ok_or_else(|| { + debug_assert!(false); + LightningError { err: "We should always have private source node counters available".to_owned(), action: ErrorAction::IgnoreError } + })?; + + if let Some((first_channels, _)) = first_hop_targets.get(target) { + let matches_an_scid = |d: &&ChannelDetails| + d.outbound_scid_alias == Some(hop.short_channel_id) || d.short_channel_id == Some(hop.short_channel_id); + if first_channels.iter().any(matches_an_scid) { + log_trace!(logger, "Ignoring route hint with SCID {} (and any previous) due to it being a direct channel of ours.", + hop.short_channel_id); + break; + } + } + + let candidate = network_channels + .get(&hop.short_channel_id) + .and_then(|channel| channel.as_directed_to(target)) + .map(|(info, _)| CandidateRouteHop::PublicHop(PublicHopCandidate { + info, + short_channel_id: hop.short_channel_id, + })) + .unwrap_or_else(|| CandidateRouteHop::PrivateHop(PrivateHopCandidate { + hint: hop, target_node_id: target, + source_node_counter: *private_source_node_counter, + target_node_counter: *private_target_node_counter, + })); + + last_hop_candidates.entry(private_target_node_counter).or_insert_with(Vec::new).push(candidate); + } + } + // The main heap containing all candidate next-hops sorted by their score (max(fee, // htlc_minimum)). Ideally this would be a heap which allowed cheap score reduction instead of // adding duplicate entries when we find a better path to a given node. @@ -2382,6 +2441,19 @@ where L::Target: Logger { // We "return" whether we updated the path at the end, and how much we can route via // this channel, via this: let mut hop_contribution_amt_msat = None; + + #[cfg(all(not(ldk_bench), any(test, fuzzing)))] + if let Some(counter) = $candidate.target_node_counter() { + // Once we are adding paths backwards from a given target, we've selected the best + // path from that target to the destination and it should no longer change. We thus + // set the best-path selected flag and check that it doesn't change below. + if let Some(node) = &mut dist[counter as usize] { + node.best_path_from_hop_selected = true; + } else if counter != payee_node_counter { + panic!("No dist entry for target node counter {}", counter); + } + } + // Channels to self should not be used. 
This is more of belt-and-suspenders, because in // practice these cases should be caught earlier: // - for regular channels at channel announcement (TODO) @@ -2514,8 +2586,15 @@ where L::Target: Logger { let curr_min = cmp::max( $next_hops_path_htlc_minimum_msat, htlc_minimum_msat ); - let candidate_fees = $candidate.fees(); let src_node_counter = $candidate.src_node_counter(); + let mut candidate_fees = $candidate.fees(); + if src_node_counter == payer_node_counter { + // We do not charge ourselves a fee to use our own channels. + candidate_fees = RoutingFees { + proportional_millionths: 0, + base_msat: 0, + }; + } let path_htlc_minimum_msat = compute_fees_saturating(curr_min, candidate_fees) .saturating_add(curr_min); @@ -2538,7 +2617,9 @@ where L::Target: Logger { path_penalty_msat: u64::max_value(), was_processed: false, is_first_hop_target: false, + is_last_hop_target: false, #[cfg(all(not(ldk_bench), any(test, fuzzing)))] + best_path_from_hop_selected: false, value_contribution_msat, }); dist_entry.as_mut().unwrap() @@ -2614,10 +2695,19 @@ where L::Target: Logger { .saturating_add(old_entry.path_penalty_msat); let new_cost = cmp::max(total_fee_msat, path_htlc_minimum_msat) .saturating_add(path_penalty_msat); + let should_replace = + new_cost < old_cost + || (new_cost == old_cost && old_entry.value_contribution_msat < value_contribution_msat); + + if !old_entry.was_processed && should_replace { + #[cfg(all(not(ldk_bench), any(test, fuzzing)))] + { + assert!(!old_entry.best_path_from_hop_selected); + } - if !old_entry.was_processed && new_cost < old_cost { let new_graph_node = RouteGraphNode { node_id: src_node_id, + node_counter: src_node_counter, score: cmp::max(total_fee_msat, path_htlc_minimum_msat).saturating_add(path_penalty_msat), total_cltv_delta: hop_total_cltv_delta, value_contribution_msat, @@ -2631,10 +2721,7 @@ where L::Target: Logger { old_entry.fee_msat = 0; // This value will be later filled with hop_use_fee_msat of the following channel old_entry.path_htlc_minimum_msat = path_htlc_minimum_msat; old_entry.path_penalty_msat = path_penalty_msat; - #[cfg(all(not(ldk_bench), any(test, fuzzing)))] - { - old_entry.value_contribution_msat = value_contribution_msat; - } + old_entry.value_contribution_msat = value_contribution_msat; hop_contribution_amt_msat = Some(value_contribution_msat); } else if old_entry.was_processed && new_cost < old_cost { #[cfg(all(not(ldk_bench), any(test, fuzzing)))] @@ -2696,19 +2783,20 @@ where L::Target: Logger { // meaning how much will be paid in fees after this node (to the best of our knowledge). // This data can later be helpful to optimize routing (pay lower fees). macro_rules! 
add_entries_to_cheapest_to_target_node { - ( $node: expr, $node_id: expr, $next_hops_value_contribution: expr, + ( $node: expr, $node_counter: expr, $node_id: expr, $next_hops_value_contribution: expr, $next_hops_cltv_delta: expr, $next_hops_path_length: expr ) => { let fee_to_target_msat; let next_hops_path_htlc_minimum_msat; let next_hops_path_penalty_msat; - let is_first_hop_target; - let skip_node = if let Some(elem) = &mut dist[$node.node_counter as usize] { + let (is_first_hop_target, is_last_hop_target); + let skip_node = if let Some(elem) = &mut dist[$node_counter as usize] { let was_processed = elem.was_processed; elem.was_processed = true; fee_to_target_msat = elem.total_fee_msat; next_hops_path_htlc_minimum_msat = elem.path_htlc_minimum_msat; next_hops_path_penalty_msat = elem.path_penalty_msat; is_first_hop_target = elem.is_first_hop_target; + is_last_hop_target = elem.is_last_hop_target; was_processed } else { // Entries are added to dist in add_entry!() when there is a channel from a node. @@ -2720,17 +2808,28 @@ where L::Target: Logger { next_hops_path_htlc_minimum_msat = 0; next_hops_path_penalty_msat = 0; is_first_hop_target = false; + is_last_hop_target = false; false }; if !skip_node { + if is_last_hop_target { + if let Some(candidates) = last_hop_candidates.get(&$node_counter) { + for candidate in candidates { + add_entry!(candidate, fee_to_target_msat, + $next_hops_value_contribution, + next_hops_path_htlc_minimum_msat, next_hops_path_penalty_msat, + $next_hops_cltv_delta, $next_hops_path_length); + } + } + } if is_first_hop_target { if let Some((first_channels, peer_node_counter)) = first_hop_targets.get(&$node_id) { for details in first_channels { - debug_assert_eq!(*peer_node_counter, $node.node_counter); + debug_assert_eq!(*peer_node_counter, $node_counter); let candidate = CandidateRouteHop::FirstHop(FirstHopCandidate { details, payer_node_id: &our_node_id, payer_node_counter, - target_node_counter: $node.node_counter, + target_node_counter: $node_counter, }); add_entry!(&candidate, fee_to_target_msat, $next_hops_value_contribution, @@ -2740,29 +2839,31 @@ where L::Target: Logger { } } - let features = if let Some(node_info) = $node.announcement_info.as_ref() { - &node_info.features() - } else { - &default_node_features - }; + if let Some(node) = $node { + let features = if let Some(node_info) = node.announcement_info.as_ref() { + &node_info.features() + } else { + &default_node_features + }; - if !features.requires_unknown_bits() { - for chan_id in $node.channels.iter() { - let chan = network_channels.get(chan_id).unwrap(); - if !chan.features.requires_unknown_bits() { - if let Some((directed_channel, source)) = chan.as_directed_to(&$node_id) { - if first_hops.is_none() || *source != our_node_id { - if directed_channel.direction().enabled { - let candidate = CandidateRouteHop::PublicHop(PublicHopCandidate { - info: directed_channel, - short_channel_id: *chan_id, - }); - add_entry!(&candidate, - fee_to_target_msat, - $next_hops_value_contribution, - next_hops_path_htlc_minimum_msat, - next_hops_path_penalty_msat, - $next_hops_cltv_delta, $next_hops_path_length); + if !features.requires_unknown_bits() { + for chan_id in node.channels.iter() { + let chan = network_channels.get(chan_id).unwrap(); + if !chan.features.requires_unknown_bits() { + if let Some((directed_channel, source)) = chan.as_directed_to(&$node_id) { + if first_hops.is_none() || *source != our_node_id { + if directed_channel.direction().enabled { + let candidate = 
CandidateRouteHop::PublicHop(PublicHopCandidate { + info: directed_channel, + short_channel_id: *chan_id, + }); + add_entry!(&candidate, + fee_to_target_msat, + $next_hops_value_contribution, + next_hops_path_htlc_minimum_msat, + next_hops_path_penalty_msat, + $next_hops_cltv_delta, $next_hops_path_length); + } } } } @@ -2783,13 +2884,23 @@ where L::Target: Logger { for e in dist.iter_mut() { *e = None; } + + // Step (2). + // Add entries for first-hop and last-hop channel hints to `dist` and add the payee node as + // the best entry via `add_entry`. + // For first- and last-hop hints we need only add dummy entries in `dist` with the relevant + // flags set. As we walk the graph in `add_entries_to_cheapest_to_target_node` we'll check + // those flags and add the channels described by the hints. + // We then either add the payee using `add_entries_to_cheapest_to_target_node` or add the + // blinded paths to the payee using `add_entry`, filling `targets` and setting us up for + // our graph walk. for (_, (chans, peer_node_counter)) in first_hop_targets.iter() { // In order to avoid looking up whether each node is a first-hop target, we store a // dummy entry in dist for each first-hop target, allowing us to do this lookup for // free since we're already looking at the `was_processed` flag. // - // Note that all the fields (except `is_first_hop_target`) will be overwritten whenever - // we find a path to the target, so are left as dummies here. + // Note that all the fields (except `is_{first,last}_hop_target`) will be overwritten + // whenever we find a path to the target, so are left as dummies here. dist[*peer_node_counter as usize] = Some(PathBuildingHop { candidate: CandidateRouteHop::FirstHop(FirstHopCandidate { details: &chans[0], @@ -2805,50 +2916,66 @@ where L::Target: Logger { path_penalty_msat: u64::max_value(), was_processed: false, is_first_hop_target: true, - #[cfg(all(not(ldk_bench), any(test, fuzzing)))] + is_last_hop_target: false, value_contribution_msat: 0, + #[cfg(all(not(ldk_bench), any(test, fuzzing)))] + best_path_from_hop_selected: false, }); } - hit_minimum_limit = false; - - // If first hop is a private channel and the only way to reach the payee, this is the only - // place where it could be added. - payee_node_id_opt.map(|payee| first_hop_targets.get(&payee).map(|(first_channels, peer_node_counter)| { - debug_assert_eq!(*peer_node_counter, payee_node_counter); - for details in first_channels { - let candidate = CandidateRouteHop::FirstHop(FirstHopCandidate { - details, payer_node_id: &our_node_id, payer_node_counter, - target_node_counter: payee_node_counter, + for (target_node_counter, candidates) in last_hop_candidates.iter() { + // In order to avoid looking up whether each node is a last-hop target, we store a + // dummy entry in dist for each last-hop target, allowing us to do this lookup for + // free since we're already looking at the `was_processed` flag. + // + // Note that all the fields (except `is_{first,last}_hop_target`) will be overwritten + // whenever we find a path to the target, so are left as dummies here. 
+ debug_assert!(!candidates.is_empty()); + if candidates.is_empty() { continue } + let entry = &mut dist[**target_node_counter as usize]; + if let Some(hop) = entry { + hop.is_last_hop_target = true; + } else { + *entry = Some(PathBuildingHop { + candidate: candidates[0].clone(), + fee_msat: 0, + next_hops_fee_msat: u64::max_value(), + hop_use_fee_msat: u64::max_value(), + total_fee_msat: u64::max_value(), + path_htlc_minimum_msat: u64::max_value(), + path_penalty_msat: u64::max_value(), + was_processed: false, + is_first_hop_target: false, + is_last_hop_target: true, + value_contribution_msat: 0, + #[cfg(all(not(ldk_bench), any(test, fuzzing)))] + best_path_from_hop_selected: false, }); - let added = add_entry!(&candidate, 0, path_value_msat, - 0, 0u64, 0, 0).is_some(); - log_trace!(logger, "{} direct route to payee via {}", - if added { "Added" } else { "Skipped" }, LoggedCandidateHop(&candidate)); } - })); + } + hit_minimum_limit = false; - // Add the payee as a target, so that the payee-to-payer - // search algorithm knows what to start with. - payee_node_id_opt.map(|payee| match network_nodes.get(&payee) { - // The payee is not in our network graph, so nothing to add here. - // There is still a chance of reaching them via last_hops though, - // so don't yet fail the payment here. - // If not, targets.pop() will not even let us enter the loop in step 2. - None => {}, - Some(node) => { - add_entries_to_cheapest_to_target_node!(node, payee, path_value_msat, 0, 0); - }, - }); + if let Some(payee) = payee_node_id_opt { + if let Some(entry) = &mut dist[payee_node_counter as usize] { + // If we built a dummy entry above we need to reset the values to represent 0 fee + // from the target "to the target". + entry.next_hops_fee_msat = 0; + entry.hop_use_fee_msat = 0; + entry.total_fee_msat = 0; + entry.path_htlc_minimum_msat = 0; + entry.path_penalty_msat = 0; + entry.value_contribution_msat = path_value_msat; + } + add_entries_to_cheapest_to_target_node!( + network_nodes.get(&payee), payee_node_counter, payee, path_value_msat, 0, 0 + ); + } - // Step (2). - // If a caller provided us with last hops, add them to routing targets. Since this happens - // earlier than general path finding, they will be somewhat prioritized, although currently - // it matters only if the fees are exactly the same. 
debug_assert_eq!( payment_params.payee.blinded_route_hints().len(), introduction_node_id_cache.len(), "introduction_node_id_cache was built by iterating the blinded_route_hints, so they should be the same len" ); + let mut blind_intros_added = hash_map_with_capacity(payment_params.payee.blinded_route_hints().len()); for (hint_idx, hint) in payment_params.payee.blinded_route_hints().iter().enumerate() { // Only add the hops in this route to our candidate set if either // we have a direct channel to the first hop or the first hop is @@ -2863,12 +2990,21 @@ where L::Target: Logger { } else { CandidateRouteHop::Blinded(BlindedPathCandidate { source_node_counter, source_node_id, hint, hint_idx }) }; - let mut path_contribution_msat = path_value_msat; if let Some(hop_used_msat) = add_entry!(&candidate, - 0, path_contribution_msat, 0, 0_u64, 0, 0) + 0, path_value_msat, 0, 0_u64, 0, 0) { - path_contribution_msat = hop_used_msat; + blind_intros_added.insert(source_node_id, (hop_used_msat, candidate)); } else { continue } + } + // If we added a blinded path from an introduction node to the destination, where the + // introduction node is one of our direct peers, we need to scan our `first_channels` + // to detect this. However, doing so immediately after calling `add_entry`, above, could + // result in incorrect behavior if we, in a later loop iteration, update the fee from the + // same introduction point to the destination (due to a different blinded path with the + // same introduction point having a lower score). + // Thus, we track the nodes that we added paths from in `blind_intros_added` and scan for + // introduction points we have a channel with after processing all blinded paths. + for (source_node_id, (path_contribution_msat, candidate)) in blind_intros_added { if let Some((first_channels, peer_node_counter)) = first_hop_targets.get_mut(source_node_id) { sort_first_hop_channels( first_channels, &used_liquidities, recommended_value_msat, our_node_pubkey @@ -2889,165 +3025,6 @@ where L::Target: Logger { } } } - for route in payment_params.payee.unblinded_route_hints().iter() - .filter(|route| !route.0.is_empty()) - { - let first_hop_src_id = NodeId::from_pubkey(&route.0.first().unwrap().src_node_id); - let first_hop_src_is_reachable = - // Only add the hops in this route to our candidate set if either we are part of - // the first hop, we have a direct channel to the first hop, or the first hop is in - // the regular network graph. - our_node_id == first_hop_src_id || - first_hop_targets.get(&first_hop_src_id).is_some() || - network_nodes.get(&first_hop_src_id).is_some(); - if first_hop_src_is_reachable { - // We start building the path from reverse, i.e., from payee - // to the first RouteHintHop in the path. 
- let hop_iter = route.0.iter().rev(); - let prev_hop_iter = core::iter::once(&maybe_dummy_payee_pk).chain( - route.0.iter().skip(1).rev().map(|hop| &hop.src_node_id)); - let mut hop_used = true; - let mut aggregate_next_hops_fee_msat: u64 = 0; - let mut aggregate_next_hops_path_htlc_minimum_msat: u64 = 0; - let mut aggregate_next_hops_path_penalty_msat: u64 = 0; - let mut aggregate_next_hops_cltv_delta: u32 = 0; - let mut aggregate_next_hops_path_length: u8 = 0; - let mut aggregate_path_contribution_msat = path_value_msat; - - for (idx, (hop, prev_hop_id)) in hop_iter.zip(prev_hop_iter).enumerate() { - let (target, private_target_node_counter) = - node_counters.private_node_counter_from_pubkey(&prev_hop_id) - .expect("node_counter_from_pubkey is called on all unblinded_route_hints keys during setup, so is always Some here"); - let (_src_id, private_source_node_counter) = - node_counters.private_node_counter_from_pubkey(&hop.src_node_id) - .expect("node_counter_from_pubkey is called on all unblinded_route_hints keys during setup, so is always Some here"); - - if let Some((first_channels, _)) = first_hop_targets.get(target) { - if first_channels.iter().any(|d| d.outbound_scid_alias == Some(hop.short_channel_id)) { - log_trace!(logger, "Ignoring route hint with SCID {} (and any previous) due to it being a direct channel of ours.", - hop.short_channel_id); - break; - } - } - - let candidate = network_channels - .get(&hop.short_channel_id) - .and_then(|channel| channel.as_directed_to(target)) - .map(|(info, _)| CandidateRouteHop::PublicHop(PublicHopCandidate { - info, - short_channel_id: hop.short_channel_id, - })) - .unwrap_or_else(|| CandidateRouteHop::PrivateHop(PrivateHopCandidate { - hint: hop, target_node_id: target, - source_node_counter: *private_source_node_counter, - target_node_counter: *private_target_node_counter, - })); - - if let Some(hop_used_msat) = add_entry!(&candidate, - aggregate_next_hops_fee_msat, aggregate_path_contribution_msat, - aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat, - aggregate_next_hops_cltv_delta, aggregate_next_hops_path_length) - { - aggregate_path_contribution_msat = hop_used_msat; - } else { - // If this hop was not used then there is no use checking the preceding - // hops in the RouteHint. We can break by just searching for a direct - // channel between last checked hop and first_hop_targets. 
- hop_used = false; - } - - let used_liquidity_msat = used_liquidities - .get(&candidate.id()).copied() - .unwrap_or(0); - let channel_usage = ChannelUsage { - amount_msat: final_value_msat + aggregate_next_hops_fee_msat, - inflight_htlc_msat: used_liquidity_msat, - effective_capacity: candidate.effective_capacity(), - }; - let channel_penalty_msat = scorer.channel_penalty_msat( - &candidate, channel_usage, score_params - ); - aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat - .saturating_add(channel_penalty_msat); - - aggregate_next_hops_cltv_delta = aggregate_next_hops_cltv_delta - .saturating_add(hop.cltv_expiry_delta as u32); - - aggregate_next_hops_path_length = aggregate_next_hops_path_length - .saturating_add(1); - - // Searching for a direct channel between last checked hop and first_hop_targets - if let Some((first_channels, peer_node_counter)) = first_hop_targets.get_mut(target) { - sort_first_hop_channels( - first_channels, &used_liquidities, recommended_value_msat, our_node_pubkey - ); - for details in first_channels { - let first_hop_candidate = CandidateRouteHop::FirstHop(FirstHopCandidate { - details, payer_node_id: &our_node_id, payer_node_counter, - target_node_counter: *peer_node_counter, - }); - add_entry!(&first_hop_candidate, - aggregate_next_hops_fee_msat, aggregate_path_contribution_msat, - aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat, - aggregate_next_hops_cltv_delta, aggregate_next_hops_path_length); - } - } - - if !hop_used { - break; - } - - // In the next values of the iterator, the aggregate fees already reflects - // the sum of value sent from payer (final_value_msat) and routing fees - // for the last node in the RouteHint. We need to just add the fees to - // route through the current node so that the preceding node (next iteration) - // can use it. - let hops_fee = compute_fees(aggregate_next_hops_fee_msat + final_value_msat, hop.fees) - .map_or(None, |inc| inc.checked_add(aggregate_next_hops_fee_msat)); - aggregate_next_hops_fee_msat = if let Some(val) = hops_fee { val } else { break; }; - - // The next channel will need to relay this channel's min_htlc *plus* the fees taken by - // this route hint's source node to forward said min over this channel. - aggregate_next_hops_path_htlc_minimum_msat = { - let curr_htlc_min = cmp::max( - candidate.htlc_minimum_msat(), aggregate_next_hops_path_htlc_minimum_msat - ); - let curr_htlc_min_fee = if let Some(val) = compute_fees(curr_htlc_min, hop.fees) { val } else { break }; - if let Some(min) = curr_htlc_min.checked_add(curr_htlc_min_fee) { min } else { break } - }; - - if idx == route.0.len() - 1 { - // The last hop in this iterator is the first hop in - // overall RouteHint. - // If this hop connects to a node with which we have a direct channel, - // ignore the network graph and, if the last hop was added, add our - // direct channel to the candidate set. - // - // Note that we *must* check if the last hop was added as `add_entry` - // always assumes that the third argument is a node to which we have a - // path. 
- if let Some((first_channels, peer_node_counter)) = first_hop_targets.get_mut(&NodeId::from_pubkey(&hop.src_node_id)) { - sort_first_hop_channels( - first_channels, &used_liquidities, recommended_value_msat, our_node_pubkey - ); - for details in first_channels { - let first_hop_candidate = CandidateRouteHop::FirstHop(FirstHopCandidate { - details, payer_node_id: &our_node_id, payer_node_counter, - target_node_counter: *peer_node_counter, - }); - add_entry!(&first_hop_candidate, - aggregate_next_hops_fee_msat, - aggregate_path_contribution_msat, - aggregate_next_hops_path_htlc_minimum_msat, - aggregate_next_hops_path_penalty_msat, - aggregate_next_hops_cltv_delta, - aggregate_next_hops_path_length); - } - } - } - } - } - } log_trace!(logger, "Starting main path collection loop with {} nodes pre-filled from first/last hops.", targets.len()); @@ -3064,7 +3041,7 @@ where L::Target: Logger { // Both these cases (and other cases except reaching recommended_value_msat) mean that // paths_collection will be stopped because found_new_path==false. // This is not necessarily a routing failure. - 'path_construction: while let Some(RouteGraphNode { node_id, total_cltv_delta, mut value_contribution_msat, path_length_to_node, .. }) = targets.pop() { + 'path_construction: while let Some(RouteGraphNode { node_id, node_counter, total_cltv_delta, mut value_contribution_msat, path_length_to_node, .. }) = targets.pop() { // Since we're going payee-to-payer, hitting our node as a target means we should stop // traversing the graph and arrange the path out of what we found. @@ -3199,14 +3176,11 @@ where L::Target: Logger { // Otherwise, since the current target node is not us, // keep "unrolling" the payment graph from payee to payer by // finding a way to reach the current target from the payer side. - match network_nodes.get(&node_id) { - None => {}, - Some(node) => { - add_entries_to_cheapest_to_target_node!(node, node_id, - value_contribution_msat, - total_cltv_delta, path_length_to_node); - }, - } + add_entries_to_cheapest_to_target_node!( + network_nodes.get(&node_id), node_counter, node_id, + value_contribution_msat, + total_cltv_delta, path_length_to_node + ); } if !allow_mpp { @@ -5733,187 +5707,33 @@ mod tests { } #[test] - fn long_mpp_route_test() { - let (secp_ctx, network_graph, gossip_sync, _, logger) = build_graph(); - let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx); - let scorer = ln_test_utils::TestScorer::new(); - let random_seed_bytes = [42; 32]; - let config = UserConfig::default(); - let payment_params = PaymentParameters::from_node_id(nodes[3], 42) - .with_bolt11_features(channelmanager::provided_bolt11_invoice_features(&config)) - .unwrap(); - - // We need a route consisting of 3 paths: - // From our node to node3 via {node0, node2}, {node7, node2, node4} and {node7, node2}. - // Note that these paths overlap (channels 5, 12, 13). - // We will route 300 sats. - // Each path will have 100 sats capacity, those channels which - // are used twice will have 200 sats capacity. - - // Disable other potential paths. 
- update_channel(&gossip_sync, &secp_ctx, &our_privkey, UnsignedChannelUpdate { - chain_hash: ChainHash::using_genesis_block(Network::Testnet), - short_channel_id: 2, - timestamp: 2, - message_flags: 1, // Only must_be_one - channel_flags: 2, - cltv_expiry_delta: 0, - htlc_minimum_msat: 0, - htlc_maximum_msat: 100_000, - fee_base_msat: 0, - fee_proportional_millionths: 0, - excess_data: Vec::new() - }); - update_channel(&gossip_sync, &secp_ctx, &privkeys[2], UnsignedChannelUpdate { - chain_hash: ChainHash::using_genesis_block(Network::Testnet), - short_channel_id: 7, - timestamp: 2, - message_flags: 1, // Only must_be_one - channel_flags: 2, - cltv_expiry_delta: 0, - htlc_minimum_msat: 0, - htlc_maximum_msat: 100_000, - fee_base_msat: 0, - fee_proportional_millionths: 0, - excess_data: Vec::new() - }); - - // Path via {node0, node2} is channels {1, 3, 5}. - update_channel(&gossip_sync, &secp_ctx, &our_privkey, UnsignedChannelUpdate { - chain_hash: ChainHash::using_genesis_block(Network::Testnet), - short_channel_id: 1, - timestamp: 2, - message_flags: 1, // Only must_be_one - channel_flags: 0, - cltv_expiry_delta: 0, - htlc_minimum_msat: 0, - htlc_maximum_msat: 100_000, - fee_base_msat: 0, - fee_proportional_millionths: 0, - excess_data: Vec::new() - }); - update_channel(&gossip_sync, &secp_ctx, &privkeys[0], UnsignedChannelUpdate { - chain_hash: ChainHash::using_genesis_block(Network::Testnet), - short_channel_id: 3, - timestamp: 2, - message_flags: 1, // Only must_be_one - channel_flags: 0, - cltv_expiry_delta: 0, - htlc_minimum_msat: 0, - htlc_maximum_msat: 100_000, - fee_base_msat: 0, - fee_proportional_millionths: 0, - excess_data: Vec::new() - }); - - // Capacity of 200 sats because this channel will be used by 3rd path as well. - add_channel(&gossip_sync, &secp_ctx, &privkeys[2], &privkeys[3], ChannelFeatures::from_le_bytes(id_to_feature_flags(5)), 5); - update_channel(&gossip_sync, &secp_ctx, &privkeys[2], UnsignedChannelUpdate { - chain_hash: ChainHash::using_genesis_block(Network::Testnet), - short_channel_id: 5, - timestamp: 2, - message_flags: 1, // Only must_be_one - channel_flags: 0, - cltv_expiry_delta: 0, - htlc_minimum_msat: 0, - htlc_maximum_msat: 200_000, - fee_base_msat: 0, - fee_proportional_millionths: 0, - excess_data: Vec::new() - }); - update_channel(&gossip_sync, &secp_ctx, &privkeys[3], UnsignedChannelUpdate { - chain_hash: ChainHash::using_genesis_block(Network::Testnet), - short_channel_id: 5, - timestamp: 2, - message_flags: 1, // Only must_be_one - channel_flags: 3, // disable direction 1 - cltv_expiry_delta: 0, - htlc_minimum_msat: 0, - htlc_maximum_msat: 200_000, - fee_base_msat: 0, - fee_proportional_millionths: 0, - excess_data: Vec::new() - }); - - // Path via {node7, node2, node4} is channels {12, 13, 6, 11}. - // Add 100 sats to the capacities of {12, 13}, because these channels - // are also used for 3rd path. 100 sats for the rest. Total capacity: 100 sats. 
- update_channel(&gossip_sync, &secp_ctx, &our_privkey, UnsignedChannelUpdate { - chain_hash: ChainHash::using_genesis_block(Network::Testnet), - short_channel_id: 12, - timestamp: 2, - message_flags: 1, // Only must_be_one - channel_flags: 0, - cltv_expiry_delta: 0, - htlc_minimum_msat: 0, - htlc_maximum_msat: 200_000, - fee_base_msat: 0, - fee_proportional_millionths: 0, - excess_data: Vec::new() - }); - update_channel(&gossip_sync, &secp_ctx, &privkeys[7], UnsignedChannelUpdate { - chain_hash: ChainHash::using_genesis_block(Network::Testnet), - short_channel_id: 13, - timestamp: 2, - message_flags: 1, // Only must_be_one - channel_flags: 0, - cltv_expiry_delta: 0, - htlc_minimum_msat: 0, - htlc_maximum_msat: 200_000, - fee_base_msat: 0, - fee_proportional_millionths: 0, - excess_data: Vec::new() - }); - - update_channel(&gossip_sync, &secp_ctx, &privkeys[2], UnsignedChannelUpdate { - chain_hash: ChainHash::using_genesis_block(Network::Testnet), - short_channel_id: 6, - timestamp: 2, - message_flags: 1, // Only must_be_one - channel_flags: 0, - cltv_expiry_delta: 0, - htlc_minimum_msat: 0, - htlc_maximum_msat: 100_000, - fee_base_msat: 0, - fee_proportional_millionths: 0, - excess_data: Vec::new() - }); - update_channel(&gossip_sync, &secp_ctx, &privkeys[4], UnsignedChannelUpdate { - chain_hash: ChainHash::using_genesis_block(Network::Testnet), - short_channel_id: 11, - timestamp: 2, - message_flags: 1, // Only must_be_one - channel_flags: 0, - cltv_expiry_delta: 0, - htlc_minimum_msat: 0, - htlc_maximum_msat: 100_000, - fee_base_msat: 0, - fee_proportional_millionths: 0, - excess_data: Vec::new() - }); - - // Path via {node7, node2} is channels {12, 13, 5}. - // We already limited them to 200 sats (they are used twice for 100 sats). - // Nothing to do here. - + fn mpp_tests() { + let secp_ctx = Secp256k1::new(); + let (_, _, _, nodes) = get_nodes(&secp_ctx); { - // Attempt to route more than available results in a failure. - let route_params = RouteParameters::from_payment_params_and_value( - payment_params.clone(), 350_000); - if let Err(LightningError{err, action: ErrorAction::IgnoreError}) = get_route( - &our_id, &route_params, &network_graph.read_only(), None, Arc::clone(&logger), - &scorer, &Default::default(), &random_seed_bytes) { - assert_eq!(err, "Failed to find a sufficient route to the given destination"); - } else { panic!(); } - } + // Check that if we have two cheaper paths and a more expensive (fewer hops) path, we + // choose the two cheaper paths: + let route = do_mpp_route_tests(180_000).unwrap(); + assert_eq!(route.paths.len(), 2); + let mut total_value_transferred_msat = 0; + let mut total_paid_msat = 0; + for path in &route.paths { + assert_eq!(path.hops.last().unwrap().pubkey, nodes[3]); + total_value_transferred_msat += path.final_value_msat(); + for hop in &path.hops { + total_paid_msat += hop.fee_msat; + } + } + // If we paid fee, this would be higher. + assert_eq!(total_value_transferred_msat, 180_000); + let total_fees_paid = total_paid_msat - total_value_transferred_msat; + assert_eq!(total_fees_paid, 0); + } { - // Now, attempt to route 300 sats (exact amount we can route). - // Our algorithm should provide us with these 3 paths, 100 sats each. 
- let route_params = RouteParameters::from_payment_params_and_value( - payment_params, 300_000); - let route = get_route(&our_id, &route_params, &network_graph.read_only(), None, - Arc::clone(&logger), &scorer, &Default::default(), &random_seed_bytes).unwrap(); + // Check that if we use the same channels but need to send more than we could fit in + // the cheaper paths we select all three paths: + let route = do_mpp_route_tests(300_000).unwrap(); assert_eq!(route.paths.len(), 3); let mut total_amount_paid_msat = 0; @@ -5923,11 +5743,11 @@ mod tests { } assert_eq!(total_amount_paid_msat, 300_000); } - + // Check that trying to pay more than our available liquidity fails. + assert!(do_mpp_route_tests(300_001).is_err()); } - #[test] - fn mpp_cheaper_route_test() { + fn do_mpp_route_tests(amt: u64) -> Result<Route, LightningError> { let (secp_ctx, network_graph, gossip_sync, _, logger) = build_graph(); let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx); let scorer = ln_test_utils::TestScorer::new(); @@ -5937,21 +5757,17 @@ mod tests { .with_bolt11_features(channelmanager::provided_bolt11_invoice_features(&config)) .unwrap(); - // This test checks that if we have two cheaper paths and one more expensive path, - // so that liquidity-wise any 2 of 3 combination is sufficient, - // two cheaper paths will be taken. - // These paths have equal available liquidity. - - // We need a combination of 3 paths: - // From our node to node3 via {node0, node2}, {node7, node2, node4} and {node7, node2}. - // Note that these paths overlap (channels 5, 12, 13). - // Each path will have 100 sats capacity, those channels which - // are used twice will have 200 sats capacity. + // Build a setup where we have three potential paths from us to node3: + // {node0, node2, node4} (channels 1, 3, 6, 11), fee 0 msat, + // {node7, node2, node4} (channels 12, 13, 6, 11), fee 0 msat, and + // {node1} (channel 2, then a new channel 16), fee 1000 msat. + // Note that these paths overlap on channels 6 and 11. + // Each channel will have 100 sats capacity except for 6 and 11, which have 200. // Disable other potential paths. - update_channel(&gossip_sync, &secp_ctx, &our_privkey, UnsignedChannelUpdate { + update_channel(&gossip_sync, &secp_ctx, &privkeys[2], UnsignedChannelUpdate { chain_hash: ChainHash::using_genesis_block(Network::Testnet), - short_channel_id: 2, + short_channel_id: 7, timestamp: 2, message_flags: 1, // Only must_be_one channel_flags: 2, @@ -5962,9 +5778,9 @@ mod tests { fee_proportional_millionths: 0, excess_data: Vec::new() }); - update_channel(&gossip_sync, &secp_ctx, &privkeys[2], UnsignedChannelUpdate { + update_channel(&gossip_sync, &secp_ctx, &privkeys[1], UnsignedChannelUpdate { chain_hash: ChainHash::using_genesis_block(Network::Testnet), - short_channel_id: 7, + short_channel_id: 4, timestamp: 2, message_flags: 1, // Only must_be_one channel_flags: 2, @@ -6004,31 +5820,30 @@ mod tests { excess_data: Vec::new() }); - // Capacity of 200 sats because this channel will be used by 3rd path as well.
- add_channel(&gossip_sync, &secp_ctx, &privkeys[2], &privkeys[3], ChannelFeatures::from_le_bytes(id_to_feature_flags(5)), 5); - update_channel(&gossip_sync, &secp_ctx, &privkeys[2], UnsignedChannelUpdate { + add_channel(&gossip_sync, &secp_ctx, &privkeys[1], &privkeys[3], ChannelFeatures::from_le_bytes(id_to_feature_flags(16)), 16); + update_channel(&gossip_sync, &secp_ctx, &privkeys[1], UnsignedChannelUpdate { chain_hash: ChainHash::using_genesis_block(Network::Testnet), - short_channel_id: 5, + short_channel_id: 16, timestamp: 2, message_flags: 1, // Only must_be_one channel_flags: 0, cltv_expiry_delta: 0, htlc_minimum_msat: 0, - htlc_maximum_msat: 200_000, - fee_base_msat: 0, + htlc_maximum_msat: 100_000, + fee_base_msat: 1_000, fee_proportional_millionths: 0, excess_data: Vec::new() }); update_channel(&gossip_sync, &secp_ctx, &privkeys[3], UnsignedChannelUpdate { chain_hash: ChainHash::using_genesis_block(Network::Testnet), - short_channel_id: 5, + short_channel_id: 16, timestamp: 2, message_flags: 1, // Only must_be_one channel_flags: 3, // disable direction 1 cltv_expiry_delta: 0, htlc_minimum_msat: 0, - htlc_maximum_msat: 200_000, - fee_base_msat: 0, + htlc_maximum_msat: 100_000, + fee_base_msat: 1_000, fee_proportional_millionths: 0, excess_data: Vec::new() }); @@ -6044,7 +5859,7 @@ mod tests { channel_flags: 0, cltv_expiry_delta: 0, htlc_minimum_msat: 0, - htlc_maximum_msat: 200_000, + htlc_maximum_msat: 100_000, fee_base_msat: 0, fee_proportional_millionths: 0, excess_data: Vec::new() @@ -6057,7 +5872,7 @@ mod tests { channel_flags: 0, cltv_expiry_delta: 0, htlc_minimum_msat: 0, - htlc_maximum_msat: 200_000, + htlc_maximum_msat: 100_000, fee_base_msat: 0, fee_proportional_millionths: 0, excess_data: Vec::new() @@ -6071,8 +5886,8 @@ mod tests { channel_flags: 0, cltv_expiry_delta: 0, htlc_minimum_msat: 0, - htlc_maximum_msat: 100_000, - fee_base_msat: 1_000, + htlc_maximum_msat: 200_000, + fee_base_msat: 0, fee_proportional_millionths: 0, excess_data: Vec::new() }); @@ -6084,7 +5899,7 @@ mod tests { channel_flags: 0, cltv_expiry_delta: 0, htlc_minimum_msat: 0, - htlc_maximum_msat: 100_000, + htlc_maximum_msat: 200_000, fee_base_msat: 0, fee_proportional_millionths: 0, excess_data: Vec::new() @@ -6094,29 +5909,11 @@ mod tests { // We already limited them to 200 sats (they are used twice for 100 sats). // Nothing to do here. - { - // Now, attempt to route 180 sats. - // Our algorithm should provide us with these 2 paths. - let route_params = RouteParameters::from_payment_params_and_value( - payment_params, 180_000); - let route = get_route(&our_id, &route_params, &network_graph.read_only(), None, - Arc::clone(&logger), &scorer, &Default::default(), &random_seed_bytes).unwrap(); - assert_eq!(route.paths.len(), 2); - - let mut total_value_transferred_msat = 0; - let mut total_paid_msat = 0; - for path in &route.paths { - assert_eq!(path.hops.last().unwrap().pubkey, nodes[3]); - total_value_transferred_msat += path.final_value_msat(); - for hop in &path.hops { - total_paid_msat += hop.fee_msat; - } - } - // If we paid fee, this would be higher. 
- assert_eq!(total_value_transferred_msat, 180_000); - let total_fees_paid = total_paid_msat - total_value_transferred_msat; - assert_eq!(total_fees_paid, 0); - } + let route_params = RouteParameters::from_payment_params_and_value( + payment_params, amt); + let res = get_route(&our_id, &route_params, &network_graph.read_only(), None, + Arc::clone(&logger), &scorer, &Default::default(), &random_seed_bytes); + res } #[test] diff --git a/lightning/src/routing/scoring.rs b/lightning/src/routing/scoring.rs index 0755a21e1ae..506f70d05b4 100644 --- a/lightning/src/routing/scoring.rs +++ b/lightning/src/routing/scoring.rs @@ -1795,13 +1795,21 @@ mod bucketed_history { self.buckets[bucket] = self.buckets[bucket].saturating_add(BUCKET_FIXED_POINT_ONE); } } + + /// Applies decay at the given half-life to all buckets. + fn decay(&mut self, half_lives: f64) { + let factor = (1024.0 * powf64(0.5, half_lives)) as u64; + for bucket in self.buckets.iter_mut() { + *bucket = ((*bucket as u64) * factor / 1024) as u16; + } + } } impl_writeable_tlv_based!(HistoricalBucketRangeTracker, { (0, buckets, required) }); impl_writeable_tlv_based!(LegacyHistoricalBucketRangeTracker, { (0, buckets, required) }); #[derive(Clone, Copy)] - #[repr(C)] // Force the fields in memory to be in the order we specify. + #[repr(C)]// Force the fields in memory to be in the order we specify. pub(super) struct HistoricalLiquidityTracker { // This struct sits inside a `(u64, ChannelLiquidity)` in memory, and we first read the // liquidity offsets in `ChannelLiquidity` when calculating the non-historical score. This @@ -1849,26 +1857,23 @@ mod bucketed_history { } pub(super) fn decay_buckets(&mut self, half_lives: f64) { - let divisor = powf64(2048.0, half_lives) as u64; - for bucket in self.min_liquidity_offset_history.buckets.iter_mut() { - *bucket = ((*bucket as u64) * 1024 / divisor) as u16; - } - for bucket in self.max_liquidity_offset_history.buckets.iter_mut() { - *bucket = ((*bucket as u64) * 1024 / divisor) as u16; - } + self.min_liquidity_offset_history.decay(half_lives); + self.max_liquidity_offset_history.decay(half_lives); self.recalculate_valid_point_count(); } fn recalculate_valid_point_count(&mut self) { - let mut total_valid_points_tracked = 0; + let mut total_valid_points_tracked = 0u128; for (min_idx, min_bucket) in self.min_liquidity_offset_history.buckets.iter().enumerate() { for max_bucket in self.max_liquidity_offset_history.buckets.iter().take(32 - min_idx) { // In testing, raising the weights of buckets to a high power led to better // scoring results. Thus, we raise the bucket weights to the 4th power here (by - // squaring the result of multiplying the weights). + // squaring the result of multiplying the weights). This results in + // bucket_weight having at max 64 bits, which means we have to do our summation + // in 128-bit math. 
let mut bucket_weight = (*min_bucket as u64) * (*max_bucket as u64); bucket_weight *= bucket_weight; - total_valid_points_tracked += bucket_weight; + total_valid_points_tracked += bucket_weight as u128; } } self.total_valid_points_tracked = total_valid_points_tracked as f64; @@ -1954,12 +1959,12 @@ mod bucketed_history { let total_valid_points_tracked = self.tracker.total_valid_points_tracked; #[cfg(debug_assertions)] { - let mut actual_valid_points_tracked = 0; + let mut actual_valid_points_tracked = 0u128; for (min_idx, min_bucket) in min_liquidity_offset_history_buckets.iter().enumerate() { for max_bucket in max_liquidity_offset_history_buckets.iter().take(32 - min_idx) { let mut bucket_weight = (*min_bucket as u64) * (*max_bucket as u64); bucket_weight *= bucket_weight; - actual_valid_points_tracked += bucket_weight; + actual_valid_points_tracked += bucket_weight as u128; } } assert_eq!(total_valid_points_tracked, actual_valid_points_tracked as f64); @@ -1986,7 +1991,7 @@ mod bucketed_history { // max-bucket with at least BUCKET_FIXED_POINT_ONE. let mut highest_max_bucket_with_points = 0; let mut highest_max_bucket_with_full_points = None; - let mut total_weight = 0; + let mut total_weight = 0u128; for (max_idx, max_bucket) in max_liquidity_offset_history_buckets.iter().enumerate() { if *max_bucket >= BUCKET_FIXED_POINT_ONE { highest_max_bucket_with_full_points = Some(cmp::max(highest_max_bucket_with_full_points.unwrap_or(0), max_idx)); @@ -1999,7 +2004,7 @@ mod bucketed_history { // squaring the result of multiplying the weights), matching the logic in // `recalculate_valid_point_count`. let bucket_weight = (*max_bucket as u64) * (min_liquidity_offset_history_buckets[0] as u64); - total_weight += bucket_weight * bucket_weight; + total_weight += (bucket_weight * bucket_weight) as u128; } debug_assert!(total_weight as f64 <= total_valid_points_tracked); // Use the highest max-bucket with at least BUCKET_FIXED_POINT_ONE, but if none is @@ -2049,6 +2054,50 @@ mod bucketed_history { Some((cumulative_success_prob * (1024.0 * 1024.0 * 1024.0)) as u64) } } + + #[cfg(test)] + mod tests { + use super::{HistoricalBucketRangeTracker, HistoricalLiquidityTracker, ProbabilisticScoringFeeParameters}; + + #[test] + fn historical_liquidity_bucket_decay() { + let mut bucket = HistoricalBucketRangeTracker::new(); + bucket.track_datapoint(100, 1000); + assert_eq!( + bucket.buckets, + [ + 0u16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0 + ] + ); + + bucket.decay(2.0); + assert_eq!( + bucket.buckets, + [ + 0u16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0 + ] + ); + } + + #[test] + fn historical_heavy_buckets_operations() { + // Checks that we don't hit overflows when working with tons of data (even an + // impossible-to-reach amount of data). + let mut tracker = HistoricalLiquidityTracker::new(); + tracker.min_liquidity_offset_history.buckets = [0xffff; 32]; + tracker.max_liquidity_offset_history.buckets = [0xffff; 32]; + tracker.recalculate_valid_point_count(); + + let mut directed = tracker.as_directed_mut(true); + let default_params = ProbabilisticScoringFeeParameters::default(); + directed.calculate_success_probability_times_billion(&default_params, 42, 1000); + directed.track_datapoint(42, 52, 1000); + + tracker.decay_buckets(1.0); + } + } } use bucketed_history::{LegacyHistoricalBucketRangeTracker, HistoricalBucketRangeTracker, DirectedHistoricalLiquidityTracker, HistoricalLiquidityTracker};
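The new `HistoricalBucketRangeTracker::decay` helper in the hunk above scales every bucket by 0.5^half_lives using 10-bit fixed-point math: the factor `(1024.0 * powf64(0.5, half_lives)) as u64` is computed once, and each u16 bucket is then multiplied by `factor / 1024` in integer arithmetic. The following is a minimal standalone sketch of that idea, not the patch itself; it assumes std's `f64::powf` in place of LDK's no-std `powf64` helper and reproduces the `historical_liquidity_bucket_decay` expectation of 32 -> 8 after two half-lives.

fn decay(buckets: &mut [u16; 32], half_lives: f64) {
    // 1024 is "1.0" in 10-bit fixed point; 0.5^half_lives is the decay multiplier.
    let factor = (1024.0 * 0.5f64.powf(half_lives)) as u64;
    for bucket in buckets.iter_mut() {
        // Scale each bucket down in integer math, truncating toward zero.
        *bucket = ((*bucket as u64) * factor / 1024) as u16;
    }
}

fn main() {
    let mut buckets = [0u16; 32];
    buckets[11] = 32; // the single datapoint tracked in the test above
    decay(&mut buckets, 2.0); // two half-lives => factor = 256, i.e. multiply by 0.25
    assert_eq!(buckets[11], 8);
}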
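The widening of `total_valid_points_tracked` (and the matching debug-assertion accumulator) to `u128` follows from the comment in the hunk above: a single `bucket_weight`, the squared product of two u16 buckets, stays just under 2^64, so one weight fits in a u64 but the sum over all 528 tracked (min_idx, max_idx) pairs does not. A quick standalone check of that bound, with the `main` wrapper added purely for illustration:

fn main() {
    // Worst-case per-pair weight: (0xffff * 0xffff)^2, which still fits in a u64 (barely).
    let max_pair_weight: u64 = (0xffffu64 * 0xffff).pow(2);
    // The nested loops visit 32 + 31 + ... + 1 = 528 (min_idx, max_idx) pairs.
    let pairs: u128 = (1u128..=32).sum();
    // Summing the worst-case weight over every pair overflows u64, hence the u128 accumulator.
    assert!(max_pair_weight as u128 * pairs > u64::MAX as u128);
}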