@@ -18,7 +18,7 @@ use chain::channelmonitor::MonitorEvent;
1818use chain:: transaction:: OutPoint ;
1919use chain:: keysinterface;
2020use ln:: features:: { ChannelFeatures , InitFeatures } ;
21- use ln:: msgs;
21+ use ln:: { msgs, wire } ;
2222use ln:: msgs:: OptionalField ;
2323use ln:: script:: ShutdownScript ;
2424use routing:: scoring:: FixedPenaltyScorer ;
@@ -249,37 +249,106 @@ impl chaininterface::BroadcasterInterface for TestBroadcaster {
249249
250250pub struct TestChannelMessageHandler {
251251 pub pending_events : Mutex < Vec < events:: MessageSendEvent > > ,
252+ expected_recv_msgs : Mutex < Option < Vec < wire:: Message < ( ) > > > > ,
252253}
253254
254255impl TestChannelMessageHandler {
255256 pub fn new ( ) -> Self {
256257 TestChannelMessageHandler {
257258 pending_events : Mutex :: new ( Vec :: new ( ) ) ,
259+ expected_recv_msgs : Mutex :: new ( None ) ,
260+ }
261+ }
262+
263+ #[ cfg( test) ]
264+ pub ( crate ) fn expect_receive_msg ( & self , ev : wire:: Message < ( ) > ) {
265+ let mut expected_msgs = self . expected_recv_msgs . lock ( ) . unwrap ( ) ;
266+ if expected_msgs. is_none ( ) { * expected_msgs = Some ( Vec :: new ( ) ) ; }
267+ expected_msgs. as_mut ( ) . unwrap ( ) . push ( ev) ;
268+ }
269+
270+ fn received_msg ( & self , ev : wire:: Message < ( ) > ) {
271+ let mut msgs = self . expected_recv_msgs . lock ( ) . unwrap ( ) ;
272+ if msgs. is_none ( ) { return ; }
273+ assert ! ( !msgs. as_ref( ) . unwrap( ) . is_empty( ) , "Received message when we weren't expecting one" ) ;
274+ #[ cfg( test) ]
275+ assert_eq ! ( msgs. as_ref( ) . unwrap( ) [ 0 ] , ev) ;
276+ msgs. as_mut ( ) . unwrap ( ) . remove ( 0 ) ;
277+ }
278+ }
279+
280+ impl Drop for TestChannelMessageHandler {
281+ fn drop ( & mut self ) {
282+ let l = self . expected_recv_msgs . lock ( ) . unwrap ( ) ;
283+ #[ cfg( feature = "std" ) ]
284+ {
285+ if !std:: thread:: panicking ( ) {
286+ assert ! ( l. is_none( ) || l. as_ref( ) . unwrap( ) . is_empty( ) ) ;
287+ }
258288 }
259289 }
260290}
261291
262292impl msgs:: ChannelMessageHandler for TestChannelMessageHandler {
263- fn handle_open_channel ( & self , _their_node_id : & PublicKey , _their_features : InitFeatures , _msg : & msgs:: OpenChannel ) { }
264- fn handle_accept_channel ( & self , _their_node_id : & PublicKey , _their_features : InitFeatures , _msg : & msgs:: AcceptChannel ) { }
265- fn handle_funding_created ( & self , _their_node_id : & PublicKey , _msg : & msgs:: FundingCreated ) { }
266- fn handle_funding_signed ( & self , _their_node_id : & PublicKey , _msg : & msgs:: FundingSigned ) { }
267- fn handle_funding_locked ( & self , _their_node_id : & PublicKey , _msg : & msgs:: FundingLocked ) { }
268- fn handle_shutdown ( & self , _their_node_id : & PublicKey , _their_features : & InitFeatures , _msg : & msgs:: Shutdown ) { }
269- fn handle_closing_signed ( & self , _their_node_id : & PublicKey , _msg : & msgs:: ClosingSigned ) { }
270- fn handle_update_add_htlc ( & self , _their_node_id : & PublicKey , _msg : & msgs:: UpdateAddHTLC ) { }
271- fn handle_update_fulfill_htlc ( & self , _their_node_id : & PublicKey , _msg : & msgs:: UpdateFulfillHTLC ) { }
272- fn handle_update_fail_htlc ( & self , _their_node_id : & PublicKey , _msg : & msgs:: UpdateFailHTLC ) { }
273- fn handle_update_fail_malformed_htlc ( & self , _their_node_id : & PublicKey , _msg : & msgs:: UpdateFailMalformedHTLC ) { }
274- fn handle_commitment_signed ( & self , _their_node_id : & PublicKey , _msg : & msgs:: CommitmentSigned ) { }
275- fn handle_revoke_and_ack ( & self , _their_node_id : & PublicKey , _msg : & msgs:: RevokeAndACK ) { }
276- fn handle_update_fee ( & self , _their_node_id : & PublicKey , _msg : & msgs:: UpdateFee ) { }
277- fn handle_channel_update ( & self , _their_node_id : & PublicKey , _msg : & msgs:: ChannelUpdate ) { }
278- fn handle_announcement_signatures ( & self , _their_node_id : & PublicKey , _msg : & msgs:: AnnouncementSignatures ) { }
279- fn handle_channel_reestablish ( & self , _their_node_id : & PublicKey , _msg : & msgs:: ChannelReestablish ) { }
293+ fn handle_open_channel ( & self , _their_node_id : & PublicKey , _their_features : InitFeatures , msg : & msgs:: OpenChannel ) {
294+ self . received_msg ( wire:: Message :: OpenChannel ( msg. clone ( ) ) ) ;
295+ }
296+ fn handle_accept_channel ( & self , _their_node_id : & PublicKey , _their_features : InitFeatures , msg : & msgs:: AcceptChannel ) {
297+ self . received_msg ( wire:: Message :: AcceptChannel ( msg. clone ( ) ) ) ;
298+ }
299+ fn handle_funding_created ( & self , _their_node_id : & PublicKey , msg : & msgs:: FundingCreated ) {
300+ self . received_msg ( wire:: Message :: FundingCreated ( msg. clone ( ) ) ) ;
301+ }
302+ fn handle_funding_signed ( & self , _their_node_id : & PublicKey , msg : & msgs:: FundingSigned ) {
303+ self . received_msg ( wire:: Message :: FundingSigned ( msg. clone ( ) ) ) ;
304+ }
305+ fn handle_funding_locked ( & self , _their_node_id : & PublicKey , msg : & msgs:: FundingLocked ) {
306+ self . received_msg ( wire:: Message :: FundingLocked ( msg. clone ( ) ) ) ;
307+ }
308+ fn handle_shutdown ( & self , _their_node_id : & PublicKey , _their_features : & InitFeatures , msg : & msgs:: Shutdown ) {
309+ self . received_msg ( wire:: Message :: Shutdown ( msg. clone ( ) ) ) ;
310+ }
311+ fn handle_closing_signed ( & self , _their_node_id : & PublicKey , msg : & msgs:: ClosingSigned ) {
312+ self . received_msg ( wire:: Message :: ClosingSigned ( msg. clone ( ) ) ) ;
313+ }
314+ fn handle_update_add_htlc ( & self , _their_node_id : & PublicKey , msg : & msgs:: UpdateAddHTLC ) {
315+ self . received_msg ( wire:: Message :: UpdateAddHTLC ( msg. clone ( ) ) ) ;
316+ }
317+ fn handle_update_fulfill_htlc ( & self , _their_node_id : & PublicKey , msg : & msgs:: UpdateFulfillHTLC ) {
318+ self . received_msg ( wire:: Message :: UpdateFulfillHTLC ( msg. clone ( ) ) ) ;
319+ }
320+ fn handle_update_fail_htlc ( & self , _their_node_id : & PublicKey , msg : & msgs:: UpdateFailHTLC ) {
321+ self . received_msg ( wire:: Message :: UpdateFailHTLC ( msg. clone ( ) ) ) ;
322+ }
323+ fn handle_update_fail_malformed_htlc ( & self , _their_node_id : & PublicKey , msg : & msgs:: UpdateFailMalformedHTLC ) {
324+ self . received_msg ( wire:: Message :: UpdateFailMalformedHTLC ( msg. clone ( ) ) ) ;
325+ }
326+ fn handle_commitment_signed ( & self , _their_node_id : & PublicKey , msg : & msgs:: CommitmentSigned ) {
327+ self . received_msg ( wire:: Message :: CommitmentSigned ( msg. clone ( ) ) ) ;
328+ }
329+ fn handle_revoke_and_ack ( & self , _their_node_id : & PublicKey , msg : & msgs:: RevokeAndACK ) {
330+ self . received_msg ( wire:: Message :: RevokeAndACK ( msg. clone ( ) ) ) ;
331+ }
332+ fn handle_update_fee ( & self , _their_node_id : & PublicKey , msg : & msgs:: UpdateFee ) {
333+ self . received_msg ( wire:: Message :: UpdateFee ( msg. clone ( ) ) ) ;
334+ }
335+ fn handle_channel_update ( & self , _their_node_id : & PublicKey , _msg : & msgs:: ChannelUpdate ) {
336+ // Don't call `received_msg` here as `TestRoutingMessageHandler` generates these sometimes
337+ }
338+ fn handle_announcement_signatures ( & self , _their_node_id : & PublicKey , msg : & msgs:: AnnouncementSignatures ) {
339+ self . received_msg ( wire:: Message :: AnnouncementSignatures ( msg. clone ( ) ) ) ;
340+ }
341+ fn handle_channel_reestablish ( & self , _their_node_id : & PublicKey , msg : & msgs:: ChannelReestablish ) {
342+ self . received_msg ( wire:: Message :: ChannelReestablish ( msg. clone ( ) ) ) ;
343+ }
280344 fn peer_disconnected ( & self , _their_node_id : & PublicKey , _no_connection_possible : bool ) { }
281- fn peer_connected ( & self , _their_node_id : & PublicKey , _msg : & msgs:: Init ) { }
282- fn handle_error ( & self , _their_node_id : & PublicKey , _msg : & msgs:: ErrorMessage ) { }
345+ fn peer_connected ( & self , _their_node_id : & PublicKey , _msg : & msgs:: Init ) {
346+ // Don't bother with `received_msg` for Init as its auto-generated and we don't want to
347+ // bother re-generating the expected Init message in all tests.
348+ }
349+ fn handle_error ( & self , _their_node_id : & PublicKey , msg : & msgs:: ErrorMessage ) {
350+ self . received_msg ( wire:: Message :: Error ( msg. clone ( ) ) ) ;
351+ }
283352}
284353
285354impl events:: MessageSendEventsProvider for TestChannelMessageHandler {
0 commit comments