@@ -518,6 +518,24 @@ impl Connection {
518518 . clone ( )
519519 }
520520
521+ /// Wait for the connection to be closed without keeping a strong reference to the connection
522+ ///
523+ /// Returns a future that resolves, once the connection is closed, to a tuple of
524+ /// ([`ConnectionError`], [`ConnectionStats`]).
525+ ///
526+ /// Calling [`Self::closed`] keeps the connection alive until it is either closed locally via [`Connection::close`]
527+ /// or closed by the remote peer. This function instead does not keep the connection itself alive,
528+ /// so if all *other* clones of the connection are dropped, the connection will be closed implicitly even
529+ /// if there are futures returned from this function still being awaited.
530+ pub fn on_closed ( & self ) -> OnClosed {
531+ let ( tx, rx) = oneshot:: channel ( ) ;
532+ self . 0 . state . lock ( "on_closed" ) . on_closed . push ( tx) ;
533+ OnClosed {
534+ conn : self . weak_handle ( ) ,
535+ rx,
536+ }
537+ }
538+
521539 /// If the connection is closed, the reason why.
522540 ///
523541 /// Returns `None` if the connection is still open.
@@ -1037,6 +1055,43 @@ impl Future for SendDatagram<'_> {
10371055 }
10381056}
10391057
1058+ /// Future returned by [`Connection::on_closed`]
1059+ ///
1060+ /// Resolves to a tuple of ([`ConnectionError`], [`ConnectionStats`]).
1061+ pub struct OnClosed {
1062+ rx : oneshot:: Receiver < ( ConnectionError , ConnectionStats ) > ,
1063+ conn : WeakConnectionHandle ,
1064+ }
1065+
1066+ impl Drop for OnClosed {
1067+ fn drop ( & mut self ) {
1068+ if self . rx . is_terminated ( ) {
1069+ return ;
1070+ } ;
1071+ if let Some ( conn) = self . conn . upgrade ( ) {
1072+ self . rx . close ( ) ;
1073+ conn. 0
1074+ . state
1075+ . lock ( "OnClosed::drop" )
1076+ . on_closed
1077+ . retain ( |tx| !tx. is_closed ( ) ) ;
1078+ }
1079+ }
1080+ }
1081+
1082+ impl Future for OnClosed {
1083+ type Output = ( ConnectionError , ConnectionStats ) ;
1084+
1085+ fn poll ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
1086+ let this = self . get_mut ( ) ;
1087+ // The `expect` is safe because `State::drop` ensures that all senders are triggered
1088+ // before being dropped.
1089+ Pin :: new ( & mut this. rx )
1090+ . poll ( cx)
1091+ . map ( |x| x. expect ( "on_close sender is never dropped before sending" ) )
1092+ }
1093+ }
1094+
10401095#[ derive( Debug ) ]
10411096pub ( crate ) struct ConnectionRef ( Arc < ConnectionInner > ) ;
10421097
@@ -1077,6 +1132,7 @@ impl ConnectionRef {
10771132 send_buffer : Vec :: new ( ) ,
10781133 buffered_transmit : None ,
10791134 observed_external_addr : watch:: Sender :: new ( None ) ,
1135+ on_closed : Vec :: new ( ) ,
10801136 } ) ,
10811137 shared : Shared :: default ( ) ,
10821138 } ) )
@@ -1215,6 +1271,7 @@ pub(crate) struct State {
12151271 /// Our last external address reported by the peer. When multipath is enabled, this will be the
12161272 /// last report across all paths.
12171273 pub ( crate ) observed_external_addr : watch:: Sender < Option < SocketAddr > > ,
1274+ on_closed : Vec < oneshot:: Sender < ( ConnectionError , ConnectionStats ) > > ,
12181275}
12191276
12201277impl State {
@@ -1475,6 +1532,12 @@ impl State {
14751532 }
14761533 wake_all_notify ( & mut self . stopped ) ;
14771534 shared. closed . notify_waiters ( ) ;
1535+
1536+ // Send to the registered on_closed futures.
1537+ let stats = self . inner . stats ( ) ;
1538+ for tx in self . on_closed . drain ( ..) {
1539+ tx. send ( ( reason. clone ( ) , stats. clone ( ) ) ) . ok ( ) ;
1540+ }
14781541 }
14791542
14801543 fn close ( & mut self , error_code : VarInt , reason : Bytes , shared : & Shared ) {
@@ -1508,6 +1571,15 @@ impl Drop for State {
15081571 . endpoint_events
15091572 . send ( ( self . handle , proto:: EndpointEvent :: drained ( ) ) ) ;
15101573 }
1574+
1575+ if !self . on_closed . is_empty ( ) {
1576+ // Ensure that all on_closed oneshot senders are triggered before dropping.
1577+ let reason = self . error . as_ref ( ) . expect ( "closed without error reason" ) ;
1578+ let stats = self . inner . stats ( ) ;
1579+ for tx in self . on_closed . drain ( ..) {
1580+ tx. send ( ( reason. clone ( ) , stats. clone ( ) ) ) . ok ( ) ;
1581+ }
1582+ }
15111583 }
15121584}
15131585
0 commit comments