Skip to content

Commit cd201ea

Browse files
committed
Add utils to persist scorer in BackgroundProcessor
1 parent af31831 commit cd201ea

File tree

3 files changed

+141
-14
lines changed

3 files changed

+141
-14
lines changed

lightning-background-processor/src/lib.rs

Lines changed: 100 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use lightning::ln::channelmanager::ChannelManager;
1818
use lightning::ln::msgs::{ChannelMessageHandler, RoutingMessageHandler};
1919
use lightning::ln::peer_handler::{CustomMessageHandler, PeerManager, SocketDescriptor};
2020
use lightning::routing::network_graph::{NetworkGraph, NetGraphMsgHandler};
21+
use lightning::routing::scoring::WriteableScore;
2122
use lightning::util::events::{Event, EventHandler, EventsProvider};
2223
use lightning::util::logger::Logger;
2324
use std::sync::Arc;
@@ -81,20 +82,24 @@ const FIRST_NETWORK_PRUNE_TIMER: u64 = 60;
8182
const FIRST_NETWORK_PRUNE_TIMER: u64 = 1;
8283

8384
/// Trait that handles persisting a [`ChannelManager`] and [`NetworkGraph`] to disk.
84-
pub trait Persister<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
85+
pub trait Persister<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref, S: Deref>
8586
where
8687
M::Target: 'static + chain::Watch<Signer>,
8788
T::Target: 'static + BroadcasterInterface,
8889
K::Target: 'static + KeysInterface<Signer = Signer>,
8990
F::Target: 'static + FeeEstimator,
9091
L::Target: 'static + Logger,
92+
S::Target: 'static + WriteableScore<'a>,
9193
{
9294
/// Persist the given [`ChannelManager`] to disk, returning an error if persistence failed
9395
/// (which will cause the [`BackgroundProcessor`] which called this method to exit).
9496
fn persist_manager(&self, channel_manager: &ChannelManager<Signer, M, T, K, F, L>) -> Result<(), std::io::Error>;
9597

9698
/// Persist the given [`NetworkGraph`] to disk, returning an error if persistence failed.
9799
fn persist_graph(&self, network_graph: &NetworkGraph) -> Result<(), std::io::Error>;
100+
101+
/// Persist the given scorer to disk, returning an error if persistence failed.
102+
fn persist_scorer(&self, scorer: &'a S) -> Result<(), std::io::Error>;
98103
}
99104

100105
/// Decorates an [`EventHandler`] with common functionality provided by standard [`EventHandler`]s.
@@ -180,15 +185,16 @@ impl BackgroundProcessor {
180185
CMH: 'static + Deref + Send + Sync,
181186
RMH: 'static + Deref + Send + Sync,
182187
EH: 'static + EventHandler + Send,
183-
PS: 'static + Send + Persister<Signer, CW, T, K, F, L>,
188+
PS: 'static + Send + for<'a> Persister<'a, Signer, CW, T, K, F, L, S>,
184189
M: 'static + Deref<Target = ChainMonitor<Signer, CF, T, F, L, P>> + Send + Sync,
185190
CM: 'static + Deref<Target = ChannelManager<Signer, CW, T, K, F, L>> + Send + Sync,
186191
NG: 'static + Deref<Target = NetGraphMsgHandler<G, CA, L>> + Send + Sync,
187192
UMH: 'static + Deref + Send + Sync,
188193
PM: 'static + Deref<Target = PeerManager<Descriptor, CMH, RMH, L, UMH>> + Send + Sync,
194+
S: 'static + Deref + Send + Sync,
189195
>(
190196
persister: PS, event_handler: EH, chain_monitor: M, channel_manager: CM,
191-
net_graph_msg_handler: Option<NG>, peer_manager: PM, logger: L
197+
net_graph_msg_handler: Option<NG>, peer_manager: PM, logger: L, scorer: S
192198
) -> Self
193199
where
194200
CA::Target: 'static + chain::Access,
@@ -202,6 +208,7 @@ impl BackgroundProcessor {
202208
CMH::Target: 'static + ChannelMessageHandler,
203209
RMH::Target: 'static + RoutingMessageHandler,
204210
UMH::Target: 'static + CustomMessageHandler,
211+
S::Target: for<'a> WriteableScore<'a>,
205212
{
206213
let stop_thread = Arc::new(AtomicBool::new(false));
207214
let stop_thread_clone = stop_thread.clone();
@@ -276,6 +283,9 @@ impl BackgroundProcessor {
276283
if let Err(e) = persister.persist_graph(handler.network_graph()) {
277284
log_error!(logger, "Error: Failed to persist network graph, check your disk and permissions {}", e)
278285
}
286+
if let Err(e) = persister.persist_scorer(&scorer) {
287+
log_error!(logger, "Error: Failed to persist scorer, check your disk and permissions {}", e)
288+
}
279289
last_prune_call = Instant::now();
280290
have_pruned = true;
281291
}
@@ -291,6 +301,10 @@ impl BackgroundProcessor {
291301
if let Some(ref handler) = net_graph_msg_handler {
292302
persister.persist_graph(handler.network_graph())?;
293303
}
304+
305+
// Persist Scorer on exit
306+
persister.persist_scorer(&scorer)?;
307+
294308
Ok(())
295309
});
296310
Self { stop_thread: stop_thread_clone, thread_handle: Some(handle) }
@@ -360,6 +374,7 @@ mod tests {
360374
use lightning::ln::msgs::{ChannelMessageHandler, Init};
361375
use lightning::ln::peer_handler::{PeerManager, MessageHandler, SocketDescriptor, IgnoringMessageHandler};
362376
use lightning::routing::network_graph::{NetworkGraph, NetGraphMsgHandler};
377+
use lightning::routing::scoring::{WriteableScore};
363378
use lightning::util::config::UserConfig;
364379
use lightning::util::events::{Event, MessageSendEventsProvider, MessageSendEvent};
365380
use lightning::util::logger::Logger;
@@ -414,12 +429,13 @@ mod tests {
414429
struct Persister {
415430
data_dir: String,
416431
graph_error: Option<(std::io::ErrorKind, &'static str)>,
417-
manager_error: Option<(std::io::ErrorKind, &'static str)>
432+
manager_error: Option<(std::io::ErrorKind, &'static str)>,
433+
scorer_error: Option<(std::io::ErrorKind, &'static str)>
418434
}
419435

420436
impl Persister {
421437
fn new(data_dir: String) -> Self {
422-
Self { data_dir, graph_error: None, manager_error: None }
438+
Self { data_dir, graph_error: None, manager_error: None, scorer_error: None }
423439
}
424440

425441
fn with_graph_error(self, error: std::io::ErrorKind, message: &'static str) -> Self {
@@ -429,14 +445,19 @@ mod tests {
429445
fn with_manager_error(self, error: std::io::ErrorKind, message: &'static str) -> Self {
430446
Self { manager_error: Some((error, message)), ..self }
431447
}
448+
449+
fn with_scorer_error(self, error: std::io::ErrorKind, message: &'static str) -> Self {
450+
Self { scorer_error: Some((error, message)), ..self }
451+
}
432452
}
433453

434-
impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L:Deref> super::Persister<Signer, M, T, K, F, L> for Persister where
454+
impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L:Deref, S:Deref> super::Persister<'a, Signer, M, T, K, F, L, S> for Persister where
435455
M::Target: 'static + chain::Watch<Signer>,
436456
T::Target: 'static + BroadcasterInterface,
437457
K::Target: 'static + KeysInterface<Signer = Signer>,
438458
F::Target: 'static + FeeEstimator,
439459
L::Target: 'static + Logger,
460+
S::Target: 'static + WriteableScore<'a>,
440461
{
441462
fn persist_manager(&self, channel_manager: &ChannelManager<Signer, M, T, K, F, L>) -> Result<(), std::io::Error> {
442463
match self.manager_error {
@@ -451,6 +472,13 @@ mod tests {
451472
Some((error, message)) => Err(std::io::Error::new(error, message)),
452473
}
453474
}
475+
476+
fn persist_scorer(&self, scorer: &'a S) -> Result<(), std::io::Error> {
477+
match self.scorer_error {
478+
None => FilesystemPersister::persist_scorer(self.data_dir.clone(), scorer),
479+
Some((error, message)) => Err(std::io::Error::new(error, message)),
480+
}
481+
}
454482
}
455483

456484
fn get_full_filepath(filepath: String, filename: String) -> String {
@@ -578,7 +606,8 @@ mod tests {
578606
let data_dir = nodes[0].persister.get_data_dir();
579607
let persister = Persister::new(data_dir);
580608
let event_handler = |_: &_| {};
581-
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone());
609+
let scorer = Arc::new(Mutex::new(test_utils::TestScorer::with_penalty(0)));
610+
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), scorer.clone());
582611

583612
macro_rules! check_persisted_data {
584613
($node: expr, $filepath: expr) => {
@@ -604,6 +633,35 @@ mod tests {
604633
}
605634
}
606635

636+
macro_rules! check_mutex_persisted_data {
637+
($node: expr, $filepath: expr) => {
638+
let mut expected_bytes = Vec::new();
639+
loop {
640+
expected_bytes.clear();
641+
match $node.lock() {
642+
Ok(node) => {
643+
match node.write(&mut expected_bytes) {
644+
Ok(_) => {
645+
match std::fs::read($filepath) {
646+
Ok(bytes) => {
647+
if bytes == expected_bytes {
648+
break
649+
} else {
650+
continue
651+
}
652+
},
653+
Err(_) => continue
654+
}
655+
}
656+
Err(_) => continue
657+
}
658+
},
659+
Err(e) => panic!("Unexpected error: {}", e)
660+
}
661+
}
662+
}
663+
}
664+
607665
// Check that the initial channel manager data is persisted as expected.
608666
let filepath = get_full_filepath("test_background_processor_persister_0".to_string(), "manager".to_string());
609667
check_persisted_data!(nodes[0].node, filepath.clone());
@@ -628,6 +686,10 @@ mod tests {
628686
check_persisted_data!(network_graph, filepath.clone());
629687
}
630688

689+
// Check scorer is persisted
690+
let filepath = get_full_filepath("test_background_processor_persister_0".to_string(), "scorer".to_string());
691+
check_mutex_persisted_data!(scorer, filepath.clone());
692+
631693
assert!(bg_processor.stop().is_ok());
632694
}
633695

@@ -639,7 +701,8 @@ mod tests {
639701
let data_dir = nodes[0].persister.get_data_dir();
640702
let persister = Persister::new(data_dir);
641703
let event_handler = |_: &_| {};
642-
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone());
704+
let scorer = Arc::new(Mutex::new(test_utils::TestScorer::with_penalty(0)));
705+
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), scorer);
643706
loop {
644707
let log_entries = nodes[0].logger.lines.lock().unwrap();
645708
let desired_log = "Calling ChannelManager's timer_tick_occurred".to_string();
@@ -662,7 +725,8 @@ mod tests {
662725
let data_dir = nodes[0].persister.get_data_dir();
663726
let persister = Persister::new(data_dir).with_manager_error(std::io::ErrorKind::Other, "test");
664727
let event_handler = |_: &_| {};
665-
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone());
728+
let scorer = Arc::new(Mutex::new(test_utils::TestScorer::with_penalty(0)));
729+
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), scorer);
666730
match bg_processor.join() {
667731
Ok(_) => panic!("Expected error persisting manager"),
668732
Err(e) => {
@@ -679,7 +743,8 @@ mod tests {
679743
let data_dir = nodes[0].persister.get_data_dir();
680744
let persister = Persister::new(data_dir).with_graph_error(std::io::ErrorKind::Other, "test");
681745
let event_handler = |_: &_| {};
682-
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone());
746+
let scorer = Arc::new(Mutex::new(test_utils::TestScorer::with_penalty(0)));
747+
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), scorer);
683748

684749
match bg_processor.stop() {
685750
Ok(_) => panic!("Expected error persisting network graph"),
@@ -690,6 +755,25 @@ mod tests {
690755
}
691756
}
692757

758+
#[test]
759+
fn test_scorer_persist_error() {
760+
// Test that if we encounter an error during scorer persistence, an error gets returned.
761+
let nodes = create_nodes(2, "test_persist_scorer_error".to_string());
762+
let data_dir = nodes[0].persister.get_data_dir();
763+
let persister = Persister::new(data_dir).with_scorer_error(std::io::ErrorKind::Other, "test");
764+
let event_handler = |_: &_| {};
765+
let scorer = Arc::new(Mutex::new(test_utils::TestScorer::with_penalty(0)));
766+
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), scorer);
767+
768+
match bg_processor.stop() {
769+
Ok(_) => panic!("Expected error persisting scorer"),
770+
Err(e) => {
771+
assert_eq!(e.kind(), std::io::ErrorKind::Other);
772+
assert_eq!(e.get_ref().unwrap().to_string(), "test");
773+
},
774+
}
775+
}
776+
693777
#[test]
694778
fn test_background_event_handling() {
695779
let mut nodes = create_nodes(2, "test_background_event_handling".to_string());
@@ -702,7 +786,8 @@ mod tests {
702786
let event_handler = move |event: &Event| {
703787
sender.send(handle_funding_generation_ready!(event, channel_value)).unwrap();
704788
};
705-
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone());
789+
let scorer = Arc::new(Mutex::new(test_utils::TestScorer::with_penalty(0)));
790+
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), scorer);
706791

707792
// Open a channel and check that the FundingGenerationReady event was handled.
708793
begin_open_channel!(nodes[0], nodes[1], channel_value);
@@ -726,7 +811,8 @@ mod tests {
726811
// Set up a background event handler for SpendableOutputs events.
727812
let (sender, receiver) = std::sync::mpsc::sync_channel(1);
728813
let event_handler = move |event: &Event| sender.send(event.clone()).unwrap();
729-
let bg_processor = BackgroundProcessor::start(Persister::new(data_dir), event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone());
814+
let scorer = Arc::new(Mutex::new(test_utils::TestScorer::with_penalty(0)));
815+
let bg_processor = BackgroundProcessor::start(Persister::new(data_dir), event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), scorer);
730816

731817
// Force close the channel and check that the SpendableOutputs event was handled.
732818
nodes[0].node.force_close_channel(&nodes[0].node.list_channels()[0].channel_id).unwrap();
@@ -757,7 +843,8 @@ mod tests {
757843
let router = DefaultRouter::new(Arc::clone(&nodes[0].network_graph), Arc::clone(&nodes[0].logger), random_seed_bytes);
758844
let invoice_payer = Arc::new(InvoicePayer::new(Arc::clone(&nodes[0].node), router, scorer, Arc::clone(&nodes[0].logger), |_: &_| {}, RetryAttempts(2)));
759845
let event_handler = Arc::clone(&invoice_payer);
760-
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone());
846+
let scorer = Arc::new(Mutex::new(test_utils::TestScorer::with_penalty(0)));
847+
let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone(), scorer);
761848
assert!(bg_processor.stop().is_ok());
762849
}
763850
}

lightning-persister/src/lib.rs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ extern crate libc;
1717
use bitcoin::hash_types::{BlockHash, Txid};
1818
use bitcoin::hashes::hex::{FromHex, ToHex};
1919
use lightning::routing::network_graph::NetworkGraph;
20+
use lightning::routing::scoring::{WriteableScore};
21+
use util::get_full_filepath;
2022
use crate::util::DiskWriteable;
2123
use lightning::chain;
2224
use lightning::chain::chaininterface::{BroadcasterInterface, FeeEstimator};
@@ -28,7 +30,7 @@ use lightning::ln::channelmanager::ChannelManager;
2830
use lightning::util::logger::Logger;
2931
use lightning::util::ser::{ReadableArgs, Writeable};
3032
use std::fs;
31-
use std::io::{Cursor, Error};
33+
use std::io::{Cursor, Error, BufWriter};
3234
use std::ops::Deref;
3335
use std::path::{Path, PathBuf};
3436

@@ -117,6 +119,27 @@ impl FilesystemPersister {
117119
util::write_to_file(path, "network_graph".to_string(), network_graph)
118120
}
119121

122+
/// Write the provided scorer to the path provided at `FilesystemPersister`
123+
/// initialization, within a file called "scorer"
124+
pub fn persist_scorer<'a, S: Deref>(data_dir: String, scorer: &'a S) -> Result<(), std::io::Error>
125+
where
126+
S::Target: WriteableScore<'a>
127+
{
128+
let path = PathBuf::from(data_dir);
129+
fs::create_dir_all(path.clone())?;
130+
let filename_with_path = get_full_filepath(path.clone(), "scorer".to_string());
131+
let tmp_filename = format!("{}.tmp", filename_with_path.clone());
132+
133+
let file = fs::File::create(&tmp_filename)?;
134+
let write_res = scorer.write(&mut BufWriter::new(file));
135+
if let Err(e) = write_res.and_then(|_| fs::rename(&tmp_filename, filename_with_path)) {
136+
let _ = fs::remove_file(&tmp_filename);
137+
Err(e)
138+
} else {
139+
Ok(())
140+
}
141+
}
142+
120143
/// Read `ChannelMonitor`s from disk.
121144
pub fn read_channelmonitors<Signer: Sign, K: Deref> (
122145
&self, keys_manager: K

lightning/src/routing/scoring.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,23 @@ pub trait LockableScore<'a> {
135135
fn lock(&'a self) -> Self::Locked;
136136
}
137137

138+
/// Refers to a scorer that is accessible under lock and also writeable to disk
139+
///
140+
/// We need this trait to be able to pass in a scorer to `lightning-background-processor` that will enable us to
141+
/// use the Persister to persist it.
142+
pub trait WriteableScore<'a>: LockableScore<'a> {
143+
/// Locks the LockableScore and writes it to disk
144+
fn write<W: Writer>(&'a self, writer: &mut W) -> Result<(), io::Error>;
145+
}
146+
147+
impl<'a, U: Writeable, T: LockableScore<'a>> WriteableScore<'a> for T
148+
where T::Locked: DerefMut<Target=U>
149+
{
150+
fn write<W: Writer>(&'a self, writer: &mut W) -> Result<(), io::Error> {
151+
self.lock().write(writer)
152+
}
153+
}
154+
138155
/// (C-not exported)
139156
impl<'a, T: 'a + Score> LockableScore<'a> for Mutex<T> {
140157
type Locked = MutexGuard<'a, T>;

0 commit comments

Comments
 (0)