Skip to content

Commit

Permalink
Add methods to import, export, and reset the scorer
Browse files Browse the repository at this point in the history
  • Loading branch information
amackillop committed Dec 30, 2024
1 parent 76623f7 commit 1b3e9fa
Showing 1 changed file with 50 additions and 6 deletions.
56 changes: 50 additions & 6 deletions src/prober.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,39 @@ use std::{
use crate::{
error::Error,
logger::LdkNodeLogger,
types::{ChannelManager, Router, Scorer},
types::{ChannelManager, Graph, Router, Scorer},
};
use bitcoin::secp256k1::PublicKey;
use lightning::{
io::Cursor,
ln::{channel_state::ChannelDetails, channelmanager::PaymentId, PaymentHash},
log_error,
routing::{
router::{Path, PaymentParameters, Route, RouteParameters, Router as _},
scoring::ScoreUpdate as _,
scoring::{ProbabilisticScorer, ProbabilisticScoringDecayParameters, ScoreUpdate as _},
},
util::{
logger::Logger as _,
ser::{ReadableArgs as _, Writeable},
},
util::logger::Logger as _,
};

/// The Prober can be used to send probes to a destination node outside of regular payment flows.
pub struct Prober {
channel_manager: Arc<ChannelManager>,
router: Arc<Router>,
scorer: Arc<Mutex<Scorer>>,
network_graph: Arc<Graph>,
logger: Arc<LdkNodeLogger>,
node_id: PublicKey,
}

impl Prober {
pub(crate) fn new(
channel_manager: Arc<ChannelManager>, router: Arc<Router>, scorer: Arc<Mutex<Scorer>>,
logger: Arc<LdkNodeLogger>, node_id: PublicKey,
network_graph: Arc<Graph>, logger: Arc<LdkNodeLogger>, node_id: PublicKey,
) -> Self {
Self { channel_manager, router, scorer, logger, node_id }
Self { channel_manager, router, scorer, network_graph, logger, node_id }
}

/// Find a route from the node to a given destination on the network.
Expand All @@ -59,7 +64,10 @@ impl Prober {
let inflight_htlcs = self.channel_manager.compute_inflight_htlcs();
self.router
.find_route(&self.node_id, &route_params, Some(&first_hops[..]), inflight_htlcs)
.map_err(|_| Error::RouteNotFound)
.map_err(|e| {
log_error!(self.logger, "Failed to find route: {e:?}");
Error::RouteNotFound
})
}

/// Send a probe along the given path returning the payment hash and id of the fake payment.
Expand All @@ -81,4 +89,40 @@ impl Prober {
scorer.probe_successful(path, duration_since_epoch);
}
}

/// Export the scorer
pub fn export_scorer(&self) -> Result<Vec<u8>, std::io::Error> {
let scorer = self.scorer.lock().expect("Lock poisoned");
let mut writer = Vec::new();
scorer.write(&mut writer).map_err(|e| {
log_error!(self.logger, "Failed to serialize scorer: {}", e);
std::io::Error::new(std::io::ErrorKind::InvalidData, "Failed to serialize Scorer")
})?;
Ok(writer)
}

/// Import a new scorer
pub fn import_scorer(
&self, scorer_bytes: Vec<u8>, decay_params: Option<ProbabilisticScoringDecayParameters>,
) -> Result<(), std::io::Error> {
let params = decay_params.unwrap_or_default();
let mut reader = Cursor::new(scorer_bytes);
let args = (params, self.network_graph.clone(), self.logger.clone());
let new_scorer = ProbabilisticScorer::read(&mut reader, args).map_err(|e| {
log_error!(self.logger, "Failed to deserialize scorer: {}", e);
std::io::Error::new(std::io::ErrorKind::InvalidData, "Failed to deserialize Scorer")
})?;
let mut scorer = self.scorer.lock().expect("Lock poisoned");
*scorer = new_scorer;
Ok(())
}

/// Reset the Scorer
pub fn reset_scorer(&self) {
let params = ProbabilisticScoringDecayParameters::default();
let new_scorer =
ProbabilisticScorer::new(params, self.network_graph.clone(), self.logger.clone());
let mut scorer = self.scorer.lock().expect("Lock poisoned");
*scorer = new_scorer;
}
}

0 comments on commit 1b3e9fa

Please sign in to comment.