Skip to content

Commit 19817fc

Browse files
authored
feat: add Connection::on_close to get notified on close without keeping the connection alive (#153)
* feat: add Connection::on_closed * fix: make sure to always trigger on_closed senders * docs and tests * refactor: make OnClosed public, improve docs * fixup test * typo
1 parent 817a1b5 commit 19817fc

File tree

3 files changed

+170
-3
lines changed

3 files changed

+170
-3
lines changed

quinn/src/connection.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,6 +518,24 @@ impl Connection {
518518
.clone()
519519
}
520520

521+
/// Wait for the connection to be closed without keeping a strong reference to the connection
522+
///
523+
/// Returns a future that resolves, once the connection is closed, to a tuple of
524+
/// ([`ConnectionError`], [`ConnectionStats`]).
525+
///
526+
/// Calling [`Self::closed`] keeps the connection alive until it is either closed locally via [`Connection::close`]
527+
/// or closed by the remote peer. This function instead does not keep the connection itself alive,
528+
/// so if all *other* clones of the connection are dropped, the connection will be closed implicitly even
529+
/// if there are futures returned from this function still being awaited.
530+
pub fn on_closed(&self) -> OnClosed {
531+
let (tx, rx) = oneshot::channel();
532+
self.0.state.lock("on_closed").on_closed.push(tx);
533+
OnClosed {
534+
conn: self.weak_handle(),
535+
rx,
536+
}
537+
}
538+
521539
/// If the connection is closed, the reason why.
522540
///
523541
/// Returns `None` if the connection is still open.
@@ -1037,6 +1055,43 @@ impl Future for SendDatagram<'_> {
10371055
}
10381056
}
10391057

1058+
/// Future returned by [`Connection::on_closed`]
1059+
///
1060+
/// Resolves to a tuple of ([`ConnectionError`], [`ConnectionStats`]).
1061+
pub struct OnClosed {
1062+
rx: oneshot::Receiver<(ConnectionError, ConnectionStats)>,
1063+
conn: WeakConnectionHandle,
1064+
}
1065+
1066+
impl Drop for OnClosed {
1067+
fn drop(&mut self) {
1068+
if self.rx.is_terminated() {
1069+
return;
1070+
};
1071+
if let Some(conn) = self.conn.upgrade() {
1072+
self.rx.close();
1073+
conn.0
1074+
.state
1075+
.lock("OnClosed::drop")
1076+
.on_closed
1077+
.retain(|tx| !tx.is_closed());
1078+
}
1079+
}
1080+
}
1081+
1082+
impl Future for OnClosed {
1083+
type Output = (ConnectionError, ConnectionStats);
1084+
1085+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
1086+
let this = self.get_mut();
1087+
// The `expect` is safe because `State::drop` ensures that all senders are triggered
1088+
// before being dropped.
1089+
Pin::new(&mut this.rx)
1090+
.poll(cx)
1091+
.map(|x| x.expect("on_close sender is never dropped before sending"))
1092+
}
1093+
}
1094+
10401095
#[derive(Debug)]
10411096
pub(crate) struct ConnectionRef(Arc<ConnectionInner>);
10421097

@@ -1077,6 +1132,7 @@ impl ConnectionRef {
10771132
send_buffer: Vec::new(),
10781133
buffered_transmit: None,
10791134
observed_external_addr: watch::Sender::new(None),
1135+
on_closed: Vec::new(),
10801136
}),
10811137
shared: Shared::default(),
10821138
}))
@@ -1215,6 +1271,7 @@ pub(crate) struct State {
12151271
/// Our last external address reported by the peer. When multipath is enabled, this will be the
12161272
/// last report across all paths.
12171273
pub(crate) observed_external_addr: watch::Sender<Option<SocketAddr>>,
1274+
on_closed: Vec<oneshot::Sender<(ConnectionError, ConnectionStats)>>,
12181275
}
12191276

12201277
impl State {
@@ -1475,6 +1532,12 @@ impl State {
14751532
}
14761533
wake_all_notify(&mut self.stopped);
14771534
shared.closed.notify_waiters();
1535+
1536+
// Send to the registered on_closed futures.
1537+
let stats = self.inner.stats();
1538+
for tx in self.on_closed.drain(..) {
1539+
tx.send((reason.clone(), stats.clone())).ok();
1540+
}
14781541
}
14791542

14801543
fn close(&mut self, error_code: VarInt, reason: Bytes, shared: &Shared) {
@@ -1508,6 +1571,15 @@ impl Drop for State {
15081571
.endpoint_events
15091572
.send((self.handle, proto::EndpointEvent::drained()));
15101573
}
1574+
1575+
if !self.on_closed.is_empty() {
1576+
// Ensure that all on_closed oneshot senders are triggered before dropping.
1577+
let reason = self.error.as_ref().expect("closed without error reason");
1578+
let stats = self.inner.stats();
1579+
for tx in self.on_closed.drain(..) {
1580+
tx.send((reason.clone(), stats.clone())).ok();
1581+
}
1582+
}
15111583
}
15121584
}
15131585

quinn/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ pub use rustls;
7676
pub use udp;
7777

7878
pub use crate::connection::{
79-
AcceptBi, AcceptUni, Connecting, Connection, OpenBi, OpenUni, ReadDatagram, SendDatagram,
80-
SendDatagramError, WeakConnectionHandle, ZeroRttAccepted,
79+
AcceptBi, AcceptUni, Connecting, Connection, OnClosed, OpenBi, OpenUni, ReadDatagram,
80+
SendDatagram, SendDatagramError, WeakConnectionHandle, ZeroRttAccepted,
8181
};
8282
pub use crate::endpoint::{Accept, Endpoint, EndpointStats};
8383
pub use crate::incoming::{Incoming, IncomingFuture, RetryError};

quinn/src/tests.rs

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use std::{
1717
use crate::runtime::TokioRuntime;
1818
use crate::{Duration, Instant};
1919
use bytes::Bytes;
20-
use proto::{RandomConnectionIdGenerator, crypto::rustls::QuicClientConfig};
20+
use proto::{ConnectionError, RandomConnectionIdGenerator, crypto::rustls::QuicClientConfig};
2121
use rand::{RngCore, SeedableRng, rngs::StdRng};
2222
use rustls::{
2323
RootCertStore,
@@ -1023,3 +1023,98 @@ async fn test_multipath_observed_address() {
10231023

10241024
tokio::join!(server_task, client_task);
10251025
}
1026+
1027+
#[tokio::test]
1028+
async fn on_closed() {
1029+
let _guard = subscribe();
1030+
let endpoint = endpoint();
1031+
let endpoint2 = endpoint.clone();
1032+
let server_task = tokio::spawn(async move {
1033+
let conn = endpoint2
1034+
.accept()
1035+
.await
1036+
.expect("endpoint")
1037+
.await
1038+
.expect("connection");
1039+
let on_closed = conn.on_closed();
1040+
let cause = conn.closed().await;
1041+
let (cause1, _stats) = on_closed.await;
1042+
assert!(matches!(cause, ConnectionError::ApplicationClosed(_)));
1043+
assert!(matches!(cause1, ConnectionError::ApplicationClosed(_)));
1044+
});
1045+
let client_task = tokio::spawn(async move {
1046+
let conn = endpoint
1047+
.connect(endpoint.local_addr().unwrap(), "localhost")
1048+
.unwrap()
1049+
.await
1050+
.expect("connect");
1051+
let on_closed1 = conn.on_closed();
1052+
let on_closed2 = conn.on_closed();
1053+
drop(conn);
1054+
1055+
let (cause, _stats) = on_closed1.await;
1056+
assert_eq!(cause, ConnectionError::LocallyClosed);
1057+
let (cause, _stats) = on_closed2.await;
1058+
assert_eq!(cause, ConnectionError::LocallyClosed);
1059+
});
1060+
let (server_res, client_res) = tokio::join!(server_task, client_task);
1061+
server_res.expect("server task panicked");
1062+
client_res.expect("client task panicked");
1063+
}
1064+
1065+
#[tokio::test]
1066+
async fn on_closed_endpoint_drop() {
1067+
let _guard = subscribe();
1068+
let factory = EndpointFactory::new();
1069+
let client = factory.endpoint("client");
1070+
let server = factory.endpoint("server");
1071+
let server_addr = server.local_addr().unwrap();
1072+
let server_task = tokio::time::timeout(
1073+
Duration::from_millis(500),
1074+
tokio::spawn(async move {
1075+
let conn = server
1076+
.accept()
1077+
.await
1078+
.expect("endpoint")
1079+
.await
1080+
.expect("accept");
1081+
println!("accepted");
1082+
let on_closed = conn.on_closed();
1083+
drop(conn);
1084+
drop(server);
1085+
let (cause, _stats) = on_closed.await;
1086+
// Depending on timing we might have received a close frame or not.
1087+
assert!(matches!(
1088+
cause,
1089+
ConnectionError::ApplicationClosed(_) | ConnectionError::LocallyClosed
1090+
));
1091+
}),
1092+
);
1093+
let client_task = tokio::time::timeout(
1094+
Duration::from_millis(500),
1095+
tokio::spawn(async move {
1096+
let conn = client
1097+
.connect(server_addr, "localhost")
1098+
.unwrap()
1099+
.await
1100+
.expect("connect");
1101+
println!("connected");
1102+
let on_closed = conn.on_closed();
1103+
drop(conn);
1104+
drop(client);
1105+
let (cause, _stats) = on_closed.await;
1106+
// Depending on timing we might have received a close frame or not.
1107+
assert!(matches!(
1108+
cause,
1109+
ConnectionError::ApplicationClosed(_) | ConnectionError::LocallyClosed
1110+
));
1111+
}),
1112+
);
1113+
let (server_res, client_res) = tokio::join!(server_task, client_task);
1114+
server_res
1115+
.expect("server timeout")
1116+
.expect("server task panicked");
1117+
client_res
1118+
.expect("client timeout")
1119+
.expect("client task panicked");
1120+
}

0 commit comments

Comments
 (0)