
add combined scorer
joostjager committed Jan 30, 2025
1 parent b3e7ab0 commit d6caa86
Showing 1 changed file with 232 additions and 2 deletions.
234 changes: 232 additions & 2 deletions lightning/src/routing/scoring.rs
@@ -478,6 +478,7 @@ where L::Target: Logger {
channel_liquidities: ChannelLiquidities,
}
/// Container for live and historical liquidity bounds for each channel.
#[derive(Clone)]
pub struct ChannelLiquidities(HashMap<u64, ChannelLiquidity>);

impl ChannelLiquidities {
@@ -886,6 +887,7 @@ impl ProbabilisticScoringDecayParameters {
/// first node in the ordering of the channel's counterparties. Thus, swapping the two liquidity
/// offset fields gives the opposite direction.
#[repr(C)] // Force the fields in memory to be in the order we specify
#[derive(Clone)]
pub struct ChannelLiquidity {
/// Lower channel liquidity bound in terms of an offset from zero.
min_liquidity_offset_msat: u64,
@@ -1156,6 +1158,15 @@ impl ChannelLiquidity {
}
}

fn merge(&mut self, other: &Self) {
// Take average for min/max liquidity offsets.
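// Note that this is an equal-weight average: local and external observations count the same,
// regardless of how much data backs each side.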
self.min_liquidity_offset_msat = (self.min_liquidity_offset_msat + other.min_liquidity_offset_msat) / 2;
self.max_liquidity_offset_msat = (self.max_liquidity_offset_msat + other.max_liquidity_offset_msat) / 2;

// Merge historical liquidity data.
self.liquidity_history.merge(&other.liquidity_history);
}

/// Returns a view of the channel liquidity directed from `source` to `target` assuming
/// `capacity_msat`.
fn as_directed(
@@ -1689,6 +1700,91 @@ impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> ScoreUpdate for Probabilistic
}
}

/// A probabilistic scorer that combines local and external information to score channels. This
/// scorer shadow-tracks local-only scores, so that external scores can be merged in cleanly
/// whenever they become available.
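///
/// A minimal usage sketch (hypothetical setup: `local_scorer` is an existing
/// `ProbabilisticScorer` and `external_scores` is a `ChannelLiquidities` snapshot obtained out
/// of band, e.g. from a scoring server):
///
/// ```ignore
/// // Wrap the local scorer; route scoring should go through `combined` from here on.
/// let mut combined = CombinedScorer::new(local_scorer);
///
/// // Whenever a fresh external snapshot arrives, merge it in. Local observations keep being
/// // shadow-tracked, so the merge can be repeated with later snapshots.
/// combined.merge(external_scores, duration_since_epoch);
/// ```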
pub struct CombinedScorer<G: Deref<Target = NetworkGraph<L>>, L: Deref> where L::Target: Logger {
local_only_scorer: ProbabilisticScorer<G, L>,
scorer: ProbabilisticScorer<G, L>,
}

impl<G: Deref<Target = NetworkGraph<L>> + Clone, L: Deref + Clone> CombinedScorer<G, L> where L::Target: Logger {
/// Create a new combined scorer with the given local scorer.
pub fn new(local_scorer: ProbabilisticScorer<G, L>) -> Self {
let decay_params = local_scorer.decay_params;
let network_graph = local_scorer.network_graph.clone();
let logger = local_scorer.logger.clone();
let mut scorer = ProbabilisticScorer::new(decay_params, network_graph, logger);

scorer.channel_liquidities = local_scorer.channel_liquidities.clone();

Self {
local_only_scorer: local_scorer,
scorer,
}
}

/// Merge external channel liquidity information into the scorer.
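/// Both data sets are first decayed to `duration_since_epoch` to make them comparable. For each
/// channel present in `external_scores`, the live liquidity bounds are averaged with the
/// local-only bounds and the historical bucket data is merged.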
pub fn merge(&mut self, mut external_scores: ChannelLiquidities, duration_since_epoch: Duration) {
// Decay both sets of scores to make them comparable and mergeable.
self.local_only_scorer.time_passed(duration_since_epoch);
external_scores.time_passed(duration_since_epoch, self.local_only_scorer.decay_params);

let local_scores = &self.local_only_scorer.channel_liquidities;

// For each channel, merge the external liquidity information with the isolated local liquidity information.
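// Channels for which there is only local data keep their existing state in `self.scorer`,
// which already reflects all local updates.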
for (scid, mut liquidity) in external_scores.0 {
if let Some(local_liquidity) = local_scores.get(&scid) {
liquidity.merge(local_liquidity);
}
self.scorer.channel_liquidities.insert(scid, liquidity);
}
}
}

impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> ScoreLookUp for CombinedScorer<G, L> where L::Target: Logger {
type ScoreParams = ProbabilisticScoringFeeParameters;

fn channel_penalty_msat(
&self, candidate: &CandidateRouteHop, usage: ChannelUsage, score_params: &ProbabilisticScoringFeeParameters
) -> u64 {
self.scorer.channel_penalty_msat(candidate, usage, score_params)
}
}

impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> ScoreUpdate for CombinedScorer<G, L> where L::Target: Logger {
fn payment_path_failed(&mut self, path: &Path, short_channel_id: u64, duration_since_epoch: Duration) {
self.local_only_scorer.payment_path_failed(path, short_channel_id, duration_since_epoch);
self.scorer.payment_path_failed(path, short_channel_id, duration_since_epoch);
}

fn payment_path_successful(&mut self, path: &Path, duration_since_epoch: Duration) {
self.local_only_scorer.payment_path_successful(path, duration_since_epoch);
self.scorer.payment_path_successful(path, duration_since_epoch);
}

fn probe_failed(&mut self, path: &Path, short_channel_id: u64, duration_since_epoch: Duration) {
self.local_only_scorer.probe_failed(path, short_channel_id, duration_since_epoch);
self.scorer.probe_failed(path, short_channel_id, duration_since_epoch);
}

fn probe_successful(&mut self, path: &Path, duration_since_epoch: Duration) {
self.local_only_scorer.probe_successful(path, duration_since_epoch);
self.scorer.probe_successful(path, duration_since_epoch);
}

fn time_passed(&mut self, duration_since_epoch: Duration) {
self.local_only_scorer.time_passed(duration_since_epoch);
self.scorer.time_passed(duration_since_epoch);
}
}

impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> Writeable for CombinedScorer<G, L> where L::Target: Logger {
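// Persist only the local-only scores. The merged state in `scorer` is not written out;
// presumably external scores are re-fetched and merged again after deserialization.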
fn write<W: crate::util::ser::Writer>(&self, writer: &mut W) -> Result<(), crate::io::Error> {
self.local_only_scorer.write(writer)
}
}

#[cfg(c_bindings)]
impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> Score for ProbabilisticScorer<G, L>
where L::Target: Logger {}
@@ -1868,6 +1964,13 @@ mod bucketed_history {
self.buckets[bucket] = self.buckets[bucket].saturating_add(BUCKET_FIXED_POINT_ONE);
}
}

/// Merges `other` into this tracker by taking, for each bucket, the average of the two
/// corresponding bucket values.
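///
/// For example, merging buckets `[32, 0, 0, ...]` with `[0, ..., 32, ...]` yields
/// `[16, 0, ..., 16, ...]`, as exercised in `historical_liquidity_bucket_merge` below.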
pub(crate) fn merge(&mut self, other: &Self) {
for (index, bucket) in self.buckets.iter_mut().enumerate() {
*bucket = (*bucket + other.buckets[index]) / 2;
}
}
}

impl_writeable_tlv_based!(HistoricalBucketRangeTracker, { (0, buckets, required) });
@@ -1964,6 +2067,13 @@ mod bucketed_history {
-> DirectedHistoricalLiquidityTracker<&'a mut HistoricalLiquidityTracker> {
DirectedHistoricalLiquidityTracker { source_less_than_target, tracker: self }
}

/// Merges the historical liquidity data from another tracker into this one.
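///
/// Both the min- and max-offset bucket histories are averaged, after which the cached valid
/// point count is recalculated.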
pub fn merge(&mut self, other: &Self) {
self.min_liquidity_offset_history.merge(&other.min_liquidity_offset_history);
self.max_liquidity_offset_history.merge(&other.max_liquidity_offset_history);
self.recalculate_valid_point_count();
}
}

/// A set of buckets representing the history of where we've seen the minimum- and maximum-
@@ -2122,6 +2232,72 @@ mod bucketed_history {
Some((cumulative_success_prob * (1024.0 * 1024.0 * 1024.0)) as u64)
}
}

#[cfg(test)]
mod tests {
use crate::routing::scoring::ProbabilisticScoringFeeParameters;

use super::{HistoricalBucketRangeTracker, HistoricalLiquidityTracker};

#[test]
fn historical_liquidity_bucket_merge() {
let mut bucket1 = HistoricalBucketRangeTracker::new();
bucket1.track_datapoint(100, 1000);
assert_eq!(
bucket1.buckets,
[
0u16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0
]
);

let mut bucket2 = HistoricalBucketRangeTracker::new();
bucket2.track_datapoint(0, 1000);
assert_eq!(
bucket2.buckets,
[
32u16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0
]
);

bucket1.merge(&bucket2);
assert_eq!(
bucket1.buckets,
[
16u16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0
]
);
}

#[test]
fn historical_liquidity_tracker_merge() {
let params = ProbabilisticScoringFeeParameters::default();

let probability1: Option<u64>;
let mut tracker1 = HistoricalLiquidityTracker::new();
{
let mut directed_tracker1 = tracker1.as_directed_mut(true);
directed_tracker1.track_datapoint(100, 200, 1000);
probability1 = directed_tracker1
.calculate_success_probability_times_billion(&params, 500, 1000);
}

let mut tracker2 = HistoricalLiquidityTracker::new();
{
let mut directed_tracker2 = tracker2.as_directed_mut(true);
directed_tracker2.track_datapoint(200, 300, 1000);
}

tracker1.merge(&tracker2);

let directed_tracker1 = tracker1.as_directed(true);
let probability =
directed_tracker1.calculate_success_probability_times_billion(&params, 500, 1000);

assert_ne!(probability1, probability);
}
}
}

impl<G: Deref<Target = NetworkGraph<L>>, L: Deref> Writeable for ProbabilisticScorer<G, L> where L::Target: Logger {
@@ -2215,15 +2391,15 @@ impl Readable for ChannelLiquidity {

#[cfg(test)]
mod tests {
- use super::{ChannelLiquidity, HistoricalLiquidityTracker, ProbabilisticScoringFeeParameters, ProbabilisticScoringDecayParameters, ProbabilisticScorer};
+ use super::{ChannelLiquidity, HistoricalLiquidityTracker, ProbabilisticScorer, ProbabilisticScoringDecayParameters, ProbabilisticScoringFeeParameters};
use crate::blinded_path::BlindedHop;
use crate::util::config::UserConfig;

use crate::ln::channelmanager;
use crate::ln::msgs::{ChannelAnnouncement, ChannelUpdate, UnsignedChannelAnnouncement, UnsignedChannelUpdate};
use crate::routing::gossip::{EffectiveCapacity, NetworkGraph, NodeId};
use crate::routing::router::{BlindedTail, Path, RouteHop, CandidateRouteHop, PublicHopCandidate};
- use crate::routing::scoring::{ChannelUsage, ScoreLookUp, ScoreUpdate};
+ use crate::routing::scoring::{ChannelLiquidities, ChannelUsage, CombinedScorer, ScoreLookUp, ScoreUpdate};
use crate::util::ser::{ReadableArgs, Writeable};
use crate::util::test_utils::{self, TestLogger};

@@ -2233,6 +2409,7 @@ mod tests {
use bitcoin::network::Network;
use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey};
use core::time::Duration;
use std::rc::Rc;
use crate::io;

fn source_privkey() -> SecretKey {
@@ -3724,6 +3901,59 @@ mod tests {
assert_eq!(scorer.historical_estimated_payment_success_probability(42, &target, amount_msat, &params, false),
Some(0.0));
}

#[test]
fn combined_scorer() {
let logger = TestLogger::new();
let network_graph = network_graph(&logger);
let params = ProbabilisticScoringFeeParameters::default();
let mut scorer = ProbabilisticScorer::new(ProbabilisticScoringDecayParameters::default(), &network_graph, &logger);
scorer.payment_path_failed(&payment_path_for_amount(600), 42, Duration::ZERO);

let mut combined_scorer = CombinedScorer::new(scorer);

// Verify that the combined_scorer has the correct liquidity range after a failed 600 msat payment.
let liquidity_range = combined_scorer.scorer.estimated_channel_liquidity_range(42, &target_node_id());
assert_eq!(liquidity_range.unwrap(), (0, 600));

let source = source_node_id();
let usage = ChannelUsage {
amount_msat: 750,
inflight_htlc_msat: 0,
effective_capacity: EffectiveCapacity::Total { capacity_msat: 1_000, htlc_maximum_msat: 1_000 },
};

{
let network_graph = network_graph.read_only();
let channel = network_graph.channel(42).unwrap();
let (info, _) = channel.as_directed_from(&source).unwrap();
let candidate = CandidateRouteHop::PublicHop(PublicHopCandidate {
info,
short_channel_id: 42,
});

let penalty = combined_scorer.channel_penalty_msat(&candidate, usage, &params);

let mut external_liquidity = ChannelLiquidity::new(Duration::ZERO);
// Wrap the logger in an `Rc` so it can be passed where a `Deref<Target = Logger>` is expected.
let logger_rc = Rc::new(&logger);
external_liquidity.as_directed_mut(&source_node_id(), &target_node_id(), 1_000)
.successful(1000, Duration::ZERO, format_args!("test channel"), logger_rc.as_ref());

let mut external_scores = ChannelLiquidities::new();

external_scores.insert(42, external_liquidity);
combined_scorer.merge(external_scores, Duration::ZERO);

let penalty_after_merge = combined_scorer.channel_penalty_msat(&candidate, usage, &params);

// Since the external source observed a successful payment, the penalty should be lower after the merge.
assert!(penalty_after_merge < penalty);
}

// Verify that the merge updated the liquidity bounds: averaging in the external observation
// moves the upper bound down from 600 to 300.
let liquidity_range = combined_scorer.scorer.estimated_channel_liquidity_range(42, &target_node_id());
assert_eq!(liquidity_range.unwrap(), (0, 300));
}
}

#[cfg(ldk_bench)]
