@@ -18,6 +18,7 @@ use lightning::ln::channelmanager::ChannelManager;
1818use lightning:: ln:: msgs:: { ChannelMessageHandler , RoutingMessageHandler } ;
1919use lightning:: ln:: peer_handler:: { CustomMessageHandler , PeerManager , SocketDescriptor } ;
2020use lightning:: routing:: network_graph:: { NetworkGraph , NetGraphMsgHandler } ;
21+ use lightning:: routing:: scoring:: WriteableScore ;
2122use lightning:: util:: events:: { Event , EventHandler , EventsProvider } ;
2223use lightning:: util:: logger:: Logger ;
2324use std:: sync:: Arc ;
@@ -81,20 +82,24 @@ const FIRST_NETWORK_PRUNE_TIMER: u64 = 60;
8182const FIRST_NETWORK_PRUNE_TIMER : u64 = 1 ;
8283
8384/// Trait that handles persisting a [`ChannelManager`] and [`NetworkGraph`] to disk.
84- pub trait Persister < Signer : Sign , M : Deref , T : Deref , K : Deref , F : Deref , L : Deref >
85+ pub trait Persister < ' a , Signer : Sign , M : Deref , T : Deref , K : Deref , F : Deref , L : Deref , S : Deref >
8586where
8687 M :: Target : ' static + chain:: Watch < Signer > ,
8788 T :: Target : ' static + BroadcasterInterface ,
8889 K :: Target : ' static + KeysInterface < Signer = Signer > ,
8990 F :: Target : ' static + FeeEstimator ,
9091 L :: Target : ' static + Logger ,
92+ S :: Target : ' static + WriteableScore < ' a > ,
9193{
9294 /// Persist the given [`ChannelManager`] to disk, returning an error if persistence failed
9395 /// (which will cause the [`BackgroundProcessor`] which called this method to exit).
9496 fn persist_manager ( & self , channel_manager : & ChannelManager < Signer , M , T , K , F , L > ) -> Result < ( ) , std:: io:: Error > ;
9597
9698 /// Persist the given [`NetworkGraph`] to disk, returning an error if persistence failed.
9799 fn persist_graph ( & self , network_graph : & NetworkGraph ) -> Result < ( ) , std:: io:: Error > ;
100+
101+ /// Persist the given scorer to disk, returning an error if persistence failed.
102+ fn persist_scorer ( & self , scorer : & ' a S ) -> Result < ( ) , std:: io:: Error > ;
98103}
99104
100105/// Decorates an [`EventHandler`] with common functionality provided by standard [`EventHandler`]s.
@@ -180,15 +185,16 @@ impl BackgroundProcessor {
180185 CMH : ' static + Deref + Send + Sync ,
181186 RMH : ' static + Deref + Send + Sync ,
182187 EH : ' static + EventHandler + Send ,
183- PS : ' static + Send + Persister < Signer , CW , T , K , F , L > ,
188+ PS : ' static + Send + for < ' a > Persister < ' a , Signer , CW , T , K , F , L , S > ,
184189 M : ' static + Deref < Target = ChainMonitor < Signer , CF , T , F , L , P > > + Send + Sync ,
185190 CM : ' static + Deref < Target = ChannelManager < Signer , CW , T , K , F , L > > + Send + Sync ,
186191 NG : ' static + Deref < Target = NetGraphMsgHandler < G , CA , L > > + Send + Sync ,
187192 UMH : ' static + Deref + Send + Sync ,
188193 PM : ' static + Deref < Target = PeerManager < Descriptor , CMH , RMH , L , UMH > > + Send + Sync ,
194+ S : ' static + Deref + Send + Sync ,
189195 > (
190196 persister : PS , event_handler : EH , chain_monitor : M , channel_manager : CM ,
191- net_graph_msg_handler : Option < NG > , peer_manager : PM , logger : L
197+ net_graph_msg_handler : Option < NG > , peer_manager : PM , logger : L , scorer : S
192198 ) -> Self
193199 where
194200 CA :: Target : ' static + chain:: Access ,
@@ -202,6 +208,7 @@ impl BackgroundProcessor {
202208 CMH :: Target : ' static + ChannelMessageHandler ,
203209 RMH :: Target : ' static + RoutingMessageHandler ,
204210 UMH :: Target : ' static + CustomMessageHandler ,
211+ S :: Target : for < ' a > WriteableScore < ' a > ,
205212 {
206213 let stop_thread = Arc :: new ( AtomicBool :: new ( false ) ) ;
207214 let stop_thread_clone = stop_thread. clone ( ) ;
@@ -276,6 +283,9 @@ impl BackgroundProcessor {
276283 if let Err ( e) = persister. persist_graph ( handler. network_graph ( ) ) {
277284 log_error ! ( logger, "Error: Failed to persist network graph, check your disk and permissions {}" , e)
278285 }
286+ if let Err ( e) = persister. persist_scorer ( & scorer) {
287+ log_error ! ( logger, "Error: Failed to persist scorer, check your disk and permissions {}" , e)
288+ }
279289 last_prune_call = Instant :: now ( ) ;
280290 have_pruned = true ;
281291 }
@@ -291,6 +301,10 @@ impl BackgroundProcessor {
291301 if let Some ( ref handler) = net_graph_msg_handler {
292302 persister. persist_graph ( handler. network_graph ( ) ) ?;
293303 }
304+
305+ // Persist Scorer on exit
306+ persister. persist_scorer ( & scorer) ?;
307+
294308 Ok ( ( ) )
295309 } ) ;
296310 Self { stop_thread : stop_thread_clone, thread_handle : Some ( handle) }
@@ -360,6 +374,7 @@ mod tests {
360374 use lightning:: ln:: msgs:: { ChannelMessageHandler , Init } ;
361375 use lightning:: ln:: peer_handler:: { PeerManager , MessageHandler , SocketDescriptor , IgnoringMessageHandler } ;
362376 use lightning:: routing:: network_graph:: { NetworkGraph , NetGraphMsgHandler } ;
377+ use lightning:: routing:: scoring:: { WriteableScore } ;
363378 use lightning:: util:: config:: UserConfig ;
364379 use lightning:: util:: events:: { Event , MessageSendEventsProvider , MessageSendEvent } ;
365380 use lightning:: util:: logger:: Logger ;
@@ -414,12 +429,13 @@ mod tests {
414429 struct Persister {
415430 data_dir : String ,
416431 graph_error : Option < ( std:: io:: ErrorKind , & ' static str ) > ,
417- manager_error : Option < ( std:: io:: ErrorKind , & ' static str ) >
432+ manager_error : Option < ( std:: io:: ErrorKind , & ' static str ) > ,
433+ scorer_error : Option < ( std:: io:: ErrorKind , & ' static str ) >
418434 }
419435
420436 impl Persister {
421437 fn new ( data_dir : String ) -> Self {
422- Self { data_dir, graph_error : None , manager_error : None }
438+ Self { data_dir, graph_error : None , manager_error : None , scorer_error : None }
423439 }
424440
425441 fn with_graph_error ( self , error : std:: io:: ErrorKind , message : & ' static str ) -> Self {
@@ -429,14 +445,19 @@ mod tests {
429445 fn with_manager_error ( self , error : std:: io:: ErrorKind , message : & ' static str ) -> Self {
430446 Self { manager_error : Some ( ( error, message) ) , ..self }
431447 }
448+
449+ fn with_scorer_error ( self , error : std:: io:: ErrorKind , message : & ' static str ) -> Self {
450+ Self { scorer_error : Some ( ( error, message) ) , ..self }
451+ }
432452 }
433453
434- impl < Signer : Sign , M : Deref , T : Deref , K : Deref , F : Deref , L : Deref > super :: Persister < Signer , M , T , K , F , L > for Persister where
454+ impl < ' a , Signer : Sign , M : Deref , T : Deref , K : Deref , F : Deref , L : Deref , S : Deref > super :: Persister < ' a , Signer , M , T , K , F , L , S > for Persister where
435455 M :: Target : ' static + chain:: Watch < Signer > ,
436456 T :: Target : ' static + BroadcasterInterface ,
437457 K :: Target : ' static + KeysInterface < Signer = Signer > ,
438458 F :: Target : ' static + FeeEstimator ,
439459 L :: Target : ' static + Logger ,
460+ S :: Target : ' static + WriteableScore < ' a > ,
440461 {
441462 fn persist_manager ( & self , channel_manager : & ChannelManager < Signer , M , T , K , F , L > ) -> Result < ( ) , std:: io:: Error > {
442463 match self . manager_error {
@@ -451,6 +472,13 @@ mod tests {
451472 Some ( ( error, message) ) => Err ( std:: io:: Error :: new ( error, message) ) ,
452473 }
453474 }
475+
476+ fn persist_scorer ( & self , scorer : & ' a S ) -> Result < ( ) , std:: io:: Error > {
477+ match self . scorer_error {
478+ None => FilesystemPersister :: persist_scorer ( self . data_dir . clone ( ) , scorer) ,
479+ Some ( ( error, message) ) => Err ( std:: io:: Error :: new ( error, message) ) ,
480+ }
481+ }
454482 }
455483
456484 fn get_full_filepath ( filepath : String , filename : String ) -> String {
@@ -578,7 +606,8 @@ mod tests {
578606 let data_dir = nodes[ 0 ] . persister . get_data_dir ( ) ;
579607 let persister = Persister :: new ( data_dir) ;
580608 let event_handler = |_: & _ | { } ;
581- let bg_processor = BackgroundProcessor :: start ( persister, event_handler, nodes[ 0 ] . chain_monitor . clone ( ) , nodes[ 0 ] . node . clone ( ) , nodes[ 0 ] . net_graph_msg_handler . clone ( ) , nodes[ 0 ] . peer_manager . clone ( ) , nodes[ 0 ] . logger . clone ( ) ) ;
609+ let scorer = Arc :: new ( Mutex :: new ( test_utils:: TestScorer :: with_penalty ( 0 ) ) ) ;
610+ let bg_processor = BackgroundProcessor :: start ( persister, event_handler, nodes[ 0 ] . chain_monitor . clone ( ) , nodes[ 0 ] . node . clone ( ) , nodes[ 0 ] . net_graph_msg_handler . clone ( ) , nodes[ 0 ] . peer_manager . clone ( ) , nodes[ 0 ] . logger . clone ( ) , scorer. clone ( ) ) ;
582611
583612 macro_rules! check_persisted_data {
584613 ( $node: expr, $filepath: expr) => {
@@ -604,6 +633,35 @@ mod tests {
604633 }
605634 }
606635
636+ macro_rules! check_mutex_persisted_data {
637+ ( $node: expr, $filepath: expr) => {
638+ let mut expected_bytes = Vec :: new( ) ;
639+ loop {
640+ expected_bytes. clear( ) ;
641+ match $node. lock( ) {
642+ Ok ( node) => {
643+ match node. write( & mut expected_bytes) {
644+ Ok ( _) => {
645+ match std:: fs:: read( $filepath) {
646+ Ok ( bytes) => {
647+ if bytes == expected_bytes {
648+ break
649+ } else {
650+ continue
651+ }
652+ } ,
653+ Err ( _) => continue
654+ }
655+ }
656+ Err ( _) => continue
657+ }
658+ } ,
659+ Err ( e) => panic!( "Unexpected error: {}" , e)
660+ }
661+ }
662+ }
663+ }
664+
607665 // Check that the initial channel manager data is persisted as expected.
608666 let filepath = get_full_filepath ( "test_background_processor_persister_0" . to_string ( ) , "manager" . to_string ( ) ) ;
609667 check_persisted_data ! ( nodes[ 0 ] . node, filepath. clone( ) ) ;
@@ -628,6 +686,10 @@ mod tests {
628686 check_persisted_data ! ( network_graph, filepath. clone( ) ) ;
629687 }
630688
689+ // Check scorer is persisted
690+ let filepath = get_full_filepath ( "test_background_processor_persister_0" . to_string ( ) , "scorer" . to_string ( ) ) ;
691+ check_mutex_persisted_data ! ( scorer, filepath. clone( ) ) ;
692+
631693 assert ! ( bg_processor. stop( ) . is_ok( ) ) ;
632694 }
633695
@@ -639,7 +701,8 @@ mod tests {
639701 let data_dir = nodes[ 0 ] . persister . get_data_dir ( ) ;
640702 let persister = Persister :: new ( data_dir) ;
641703 let event_handler = |_: & _ | { } ;
642- let bg_processor = BackgroundProcessor :: start ( persister, event_handler, nodes[ 0 ] . chain_monitor . clone ( ) , nodes[ 0 ] . node . clone ( ) , nodes[ 0 ] . net_graph_msg_handler . clone ( ) , nodes[ 0 ] . peer_manager . clone ( ) , nodes[ 0 ] . logger . clone ( ) ) ;
704+ let scorer = Arc :: new ( Mutex :: new ( test_utils:: TestScorer :: with_penalty ( 0 ) ) ) ;
705+ let bg_processor = BackgroundProcessor :: start ( persister, event_handler, nodes[ 0 ] . chain_monitor . clone ( ) , nodes[ 0 ] . node . clone ( ) , nodes[ 0 ] . net_graph_msg_handler . clone ( ) , nodes[ 0 ] . peer_manager . clone ( ) , nodes[ 0 ] . logger . clone ( ) , scorer) ;
643706 loop {
644707 let log_entries = nodes[ 0 ] . logger . lines . lock ( ) . unwrap ( ) ;
645708 let desired_log = "Calling ChannelManager's timer_tick_occurred" . to_string ( ) ;
@@ -662,7 +725,8 @@ mod tests {
662725 let data_dir = nodes[ 0 ] . persister . get_data_dir ( ) ;
663726 let persister = Persister :: new ( data_dir) . with_manager_error ( std:: io:: ErrorKind :: Other , "test" ) ;
664727 let event_handler = |_: & _ | { } ;
665- let bg_processor = BackgroundProcessor :: start ( persister, event_handler, nodes[ 0 ] . chain_monitor . clone ( ) , nodes[ 0 ] . node . clone ( ) , nodes[ 0 ] . net_graph_msg_handler . clone ( ) , nodes[ 0 ] . peer_manager . clone ( ) , nodes[ 0 ] . logger . clone ( ) ) ;
728+ let scorer = Arc :: new ( Mutex :: new ( test_utils:: TestScorer :: with_penalty ( 0 ) ) ) ;
729+ let bg_processor = BackgroundProcessor :: start ( persister, event_handler, nodes[ 0 ] . chain_monitor . clone ( ) , nodes[ 0 ] . node . clone ( ) , nodes[ 0 ] . net_graph_msg_handler . clone ( ) , nodes[ 0 ] . peer_manager . clone ( ) , nodes[ 0 ] . logger . clone ( ) , scorer) ;
666730 match bg_processor. join ( ) {
667731 Ok ( _) => panic ! ( "Expected error persisting manager" ) ,
668732 Err ( e) => {
@@ -679,7 +743,8 @@ mod tests {
679743 let data_dir = nodes[ 0 ] . persister . get_data_dir ( ) ;
680744 let persister = Persister :: new ( data_dir) . with_graph_error ( std:: io:: ErrorKind :: Other , "test" ) ;
681745 let event_handler = |_: & _ | { } ;
682- let bg_processor = BackgroundProcessor :: start ( persister, event_handler, nodes[ 0 ] . chain_monitor . clone ( ) , nodes[ 0 ] . node . clone ( ) , nodes[ 0 ] . net_graph_msg_handler . clone ( ) , nodes[ 0 ] . peer_manager . clone ( ) , nodes[ 0 ] . logger . clone ( ) ) ;
746+ let scorer = Arc :: new ( Mutex :: new ( test_utils:: TestScorer :: with_penalty ( 0 ) ) ) ;
747+ let bg_processor = BackgroundProcessor :: start ( persister, event_handler, nodes[ 0 ] . chain_monitor . clone ( ) , nodes[ 0 ] . node . clone ( ) , nodes[ 0 ] . net_graph_msg_handler . clone ( ) , nodes[ 0 ] . peer_manager . clone ( ) , nodes[ 0 ] . logger . clone ( ) , scorer) ;
683748
684749 match bg_processor. stop ( ) {
685750 Ok ( _) => panic ! ( "Expected error persisting network graph" ) ,
@@ -690,6 +755,25 @@ mod tests {
690755 }
691756 }
692757
758+ #[ test]
759+ fn test_scorer_persist_error ( ) {
760+ // Test that if we encounter an error during scorer persistence, an error gets returned.
761+ let nodes = create_nodes ( 2 , "test_persist_scorer_error" . to_string ( ) ) ;
762+ let data_dir = nodes[ 0 ] . persister . get_data_dir ( ) ;
763+ let persister = Persister :: new ( data_dir) . with_scorer_error ( std:: io:: ErrorKind :: Other , "test" ) ;
764+ let event_handler = |_: & _ | { } ;
765+ let scorer = Arc :: new ( Mutex :: new ( test_utils:: TestScorer :: with_penalty ( 0 ) ) ) ;
766+ let bg_processor = BackgroundProcessor :: start ( persister, event_handler, nodes[ 0 ] . chain_monitor . clone ( ) , nodes[ 0 ] . node . clone ( ) , nodes[ 0 ] . net_graph_msg_handler . clone ( ) , nodes[ 0 ] . peer_manager . clone ( ) , nodes[ 0 ] . logger . clone ( ) , scorer) ;
767+
768+ match bg_processor. stop ( ) {
769+ Ok ( _) => panic ! ( "Expected error persisting scorer" ) ,
770+ Err ( e) => {
771+ assert_eq ! ( e. kind( ) , std:: io:: ErrorKind :: Other ) ;
772+ assert_eq ! ( e. get_ref( ) . unwrap( ) . to_string( ) , "test" ) ;
773+ } ,
774+ }
775+ }
776+
693777 #[ test]
694778 fn test_background_event_handling ( ) {
695779 let mut nodes = create_nodes ( 2 , "test_background_event_handling" . to_string ( ) ) ;
@@ -702,7 +786,8 @@ mod tests {
702786 let event_handler = move |event : & Event | {
703787 sender. send ( handle_funding_generation_ready ! ( event, channel_value) ) . unwrap ( ) ;
704788 } ;
705- let bg_processor = BackgroundProcessor :: start ( persister, event_handler, nodes[ 0 ] . chain_monitor . clone ( ) , nodes[ 0 ] . node . clone ( ) , nodes[ 0 ] . net_graph_msg_handler . clone ( ) , nodes[ 0 ] . peer_manager . clone ( ) , nodes[ 0 ] . logger . clone ( ) ) ;
789+ let scorer = Arc :: new ( Mutex :: new ( test_utils:: TestScorer :: with_penalty ( 0 ) ) ) ;
790+ let bg_processor = BackgroundProcessor :: start ( persister, event_handler, nodes[ 0 ] . chain_monitor . clone ( ) , nodes[ 0 ] . node . clone ( ) , nodes[ 0 ] . net_graph_msg_handler . clone ( ) , nodes[ 0 ] . peer_manager . clone ( ) , nodes[ 0 ] . logger . clone ( ) , scorer) ;
706791
707792 // Open a channel and check that the FundingGenerationReady event was handled.
708793 begin_open_channel ! ( nodes[ 0 ] , nodes[ 1 ] , channel_value) ;
@@ -726,7 +811,8 @@ mod tests {
726811 // Set up a background event handler for SpendableOutputs events.
727812 let ( sender, receiver) = std:: sync:: mpsc:: sync_channel ( 1 ) ;
728813 let event_handler = move |event : & Event | sender. send ( event. clone ( ) ) . unwrap ( ) ;
729- let bg_processor = BackgroundProcessor :: start ( Persister :: new ( data_dir) , event_handler, nodes[ 0 ] . chain_monitor . clone ( ) , nodes[ 0 ] . node . clone ( ) , nodes[ 0 ] . net_graph_msg_handler . clone ( ) , nodes[ 0 ] . peer_manager . clone ( ) , nodes[ 0 ] . logger . clone ( ) ) ;
814+ let scorer = Arc :: new ( Mutex :: new ( test_utils:: TestScorer :: with_penalty ( 0 ) ) ) ;
815+ let bg_processor = BackgroundProcessor :: start ( Persister :: new ( data_dir) , event_handler, nodes[ 0 ] . chain_monitor . clone ( ) , nodes[ 0 ] . node . clone ( ) , nodes[ 0 ] . net_graph_msg_handler . clone ( ) , nodes[ 0 ] . peer_manager . clone ( ) , nodes[ 0 ] . logger . clone ( ) , scorer) ;
730816
731817 // Force close the channel and check that the SpendableOutputs event was handled.
732818 nodes[ 0 ] . node . force_close_channel ( & nodes[ 0 ] . node . list_channels ( ) [ 0 ] . channel_id ) . unwrap ( ) ;
@@ -757,7 +843,8 @@ mod tests {
757843 let router = DefaultRouter :: new ( Arc :: clone ( & nodes[ 0 ] . network_graph ) , Arc :: clone ( & nodes[ 0 ] . logger ) , random_seed_bytes) ;
758844 let invoice_payer = Arc :: new ( InvoicePayer :: new ( Arc :: clone ( & nodes[ 0 ] . node ) , router, scorer, Arc :: clone ( & nodes[ 0 ] . logger ) , |_: & _ | { } , RetryAttempts ( 2 ) ) ) ;
759845 let event_handler = Arc :: clone ( & invoice_payer) ;
760- let bg_processor = BackgroundProcessor :: start ( persister, event_handler, nodes[ 0 ] . chain_monitor . clone ( ) , nodes[ 0 ] . node . clone ( ) , nodes[ 0 ] . net_graph_msg_handler . clone ( ) , nodes[ 0 ] . peer_manager . clone ( ) , nodes[ 0 ] . logger . clone ( ) ) ;
846+ let scorer = Arc :: new ( Mutex :: new ( test_utils:: TestScorer :: with_penalty ( 0 ) ) ) ;
847+ let bg_processor = BackgroundProcessor :: start ( persister, event_handler, nodes[ 0 ] . chain_monitor . clone ( ) , nodes[ 0 ] . node . clone ( ) , nodes[ 0 ] . net_graph_msg_handler . clone ( ) , nodes[ 0 ] . peer_manager . clone ( ) , nodes[ 0 ] . logger . clone ( ) , scorer) ;
761848 assert ! ( bg_processor. stop( ) . is_ok( ) ) ;
762849 }
763850}
0 commit comments