diff --git a/quinn/src/builders.rs b/quinn/src/builders.rs index d6791253d..cae080faa 100644 --- a/quinn/src/builders.rs +++ b/quinn/src/builders.rs @@ -2,7 +2,7 @@ use std::borrow::Cow; use std::io; use std::net::ToSocketAddrs; use std::str; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use err_derive::Error; use quinn_proto as quinn; @@ -11,7 +11,7 @@ use slog::Logger; use quinn_proto::{EndpointConfig, ServerConfig, TransportConfig}; -use crate::endpoint::{Driver, Endpoint, EndpointInner, Incoming}; +use crate::endpoint::{Driver, Endpoint, EndpointRef, Incoming}; use crate::tls::{Certificate, CertificateChain, PrivateKey}; use crate::udp::UdpSocket; @@ -55,7 +55,7 @@ impl<'a> EndpointBuilder<'a> { }; let addr = socket.local_addr().map_err(EndpointError::Socket)?; let socket = UdpSocket::from_std(socket, &reactor).map_err(EndpointError::Socket)?; - let rc = Arc::new(Mutex::new(EndpointInner::new( + let rc = EndpointRef::new( self.logger.clone(), socket, quinn::Endpoint::new( @@ -64,7 +64,7 @@ impl<'a> EndpointBuilder<'a> { self.server_config.map(Arc::new), )?, addr.is_ipv6(), - ))); + ); Ok(( Endpoint { inner: rc.clone(), diff --git a/quinn/src/endpoint.rs b/quinn/src/endpoint.rs index 97371a35e..f26f671aa 100644 --- a/quinn/src/endpoint.rs +++ b/quinn/src/endpoint.rs @@ -34,7 +34,7 @@ use crate::{ConnectionEvent, EndpointEvent, IO_LOOP_BOUND}; /// May be cloned to obtain another handle to the same endpoint. #[derive(Clone)] pub struct Endpoint { - pub(crate) inner: Arc>, + pub(crate) inner: EndpointRef, pub(crate) default_client_config: ClientConfig, } @@ -116,7 +116,7 @@ impl Endpoint { /// /// `Driver` instances do not terminate (always yields `NotReady`) except in case of an error. #[must_use = "endpoint drivers must be spawned for I/O to occur"] -pub struct Driver(pub(crate) Arc>); +pub struct Driver(pub(crate) EndpointRef); impl Future for Driver { type Item = (); @@ -137,7 +137,13 @@ impl Future for Driver { break; } } - Ok(Async::NotReady) + Ok( + if endpoint.unreferenced && endpoint.connections.is_empty() { + Async::Ready(()) + } else { + Async::NotReady + }, + ) } } @@ -165,27 +171,14 @@ pub(crate) struct EndpointInner { // Stored to give out clones to new ConnectionInners sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>, events: mpsc::UnboundedReceiver<(ConnectionHandle, EndpointEvent)>, + /// Whether only one reference to this endpoint remains + /// + /// We presume the final reference to always be the driver, because otherwise nothing we do will + /// have any effect regardless. + unreferenced: bool, } impl EndpointInner { - pub(crate) fn new(log: Logger, socket: UdpSocket, inner: quinn::Endpoint, ipv6: bool) -> Self { - let (sender, events) = mpsc::unbounded(); - Self { - log, - socket, - inner, - ipv6, - sender, - events, - outgoing: VecDeque::new(), - incoming: VecDeque::new(), - incoming_reader: None, - incoming_live: true, - driver: None, - connections: FnvHashMap::default(), - } - } - fn drive_recv(&mut self, now: Instant) -> Result { let mut buf = [0; 64 * 1024]; let mut recvd = 0; @@ -317,10 +310,10 @@ fn ensure_ipv6(x: SocketAddr) -> SocketAddrV6 { } /// Stream of incoming connections. -pub struct Incoming(Arc>); +pub struct Incoming(EndpointRef); impl Incoming { - pub(crate) fn new(inner: Arc>) -> Self { + pub(crate) fn new(inner: EndpointRef) -> Self { Self(inner) } } @@ -350,3 +343,53 @@ impl Drop for Incoming { } } } + +pub(crate) struct EndpointRef(Arc>); + +impl EndpointRef { + pub(crate) fn new(log: Logger, socket: UdpSocket, inner: quinn::Endpoint, ipv6: bool) -> Self { + let (sender, events) = mpsc::unbounded(); + Self(Arc::new(Mutex::new(EndpointInner { + log, + socket, + inner, + ipv6, + sender, + events, + outgoing: VecDeque::new(), + incoming: VecDeque::new(), + incoming_live: true, + incoming_reader: None, + driver: None, + connections: FnvHashMap::default(), + unreferenced: false, + }))) + } +} + +impl Clone for EndpointRef { + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl Drop for EndpointRef { + fn drop(&mut self) { + if Arc::strong_count(&self.0) == 2 { + // If the driver is about to be on its own, arrange for it to shut down once the last + // connection is gone. + let endpoint = &mut *self.0.lock().unwrap(); + endpoint.unreferenced = true; + if let Some(task) = endpoint.driver.take() { + task.notify(); + } + } + } +} + +impl std::ops::Deref for EndpointRef { + type Target = Mutex; + fn deref(&self) -> &Self::Target { + &self.0 + } +} diff --git a/quinn/src/tests.rs b/quinn/src/tests.rs index f1d12492b..93146e616 100644 --- a/quinn/src/tests.rs +++ b/quinn/src/tests.rs @@ -48,69 +48,77 @@ fn echo_dualstack() { } fn run_echo(client_addr: SocketAddr, server_addr: SocketAddr) { - let log = logger(); - let mut server_config = ServerConfigBuilder::default(); - let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]); - let key = crate::PrivateKey::from_der(&cert.serialize_private_key_der()).unwrap(); - let cert = crate::Certificate::from_der(&cert.serialize_der()).unwrap(); - let cert_chain = crate::CertificateChain::from_certs(vec![cert.clone()]); - server_config.certificate(cert_chain, key).unwrap(); + let mut runtime = tokio::runtime::Runtime::new().unwrap(); + { + let log = logger(); + let mut server_config = ServerConfigBuilder::default(); + let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]); + let key = crate::PrivateKey::from_der(&cert.serialize_private_key_der()).unwrap(); + let cert = crate::Certificate::from_der(&cert.serialize_der()).unwrap(); + let cert_chain = crate::CertificateChain::from_certs(vec![cert.clone()]); + server_config.certificate(cert_chain, key).unwrap(); - let mut server = Endpoint::new(); - server.logger(log.clone()); - server.listen(server_config.build()); - let server_sock = UdpSocket::bind(server_addr).unwrap(); - let server_addr = server_sock.local_addr().unwrap(); - let (_, server_driver, server_incoming) = server.from_socket(server_sock).unwrap(); + let mut server = Endpoint::new(); + server.logger(log.clone()); + server.listen(server_config.build()); + let server_sock = UdpSocket::bind(server_addr).unwrap(); + let server_addr = server_sock.local_addr().unwrap(); + let (_, server_driver, server_incoming) = server.from_socket(server_sock).unwrap(); - let mut client_config = ClientConfigBuilder::default(); - client_config.add_certificate_authority(cert).unwrap(); - client_config.enable_keylog(); - let mut client = Endpoint::new(); - client.logger(log.clone()); - client.default_client_config(client_config.build()); - let (client, client_driver, _) = client.bind(client_addr).unwrap(); + let mut client_config = ClientConfigBuilder::default(); + client_config.add_certificate_authority(cert).unwrap(); + client_config.enable_keylog(); + let mut client = Endpoint::new(); + client.logger(log.clone()); + client.default_client_config(client_config.build()); + let (client, client_driver, _) = client.bind(client_addr).unwrap(); - let mut runtime = tokio::runtime::Runtime::new().unwrap(); - runtime.spawn(server_driver.map_err(|e| panic!("server driver failed: {}", e))); - runtime.spawn(client_driver.map_err(|e| panic!("client driver failed: {}", e))); - runtime.spawn(server_incoming.for_each(move |conn| { - tokio::spawn(conn.driver.map_err(|_| ())); - tokio::spawn(conn.incoming.map_err(|_| ()).for_each(echo)); - Ok(()) - })); + runtime.spawn(server_driver.map_err(|e| panic!("server driver failed: {}", e))); + runtime.spawn(client_driver.map_err(|e| panic!("client driver failed: {}", e))); + runtime.spawn( + server_incoming + .into_future() + .map(move |(conn, _)| { + let conn = conn.unwrap(); + tokio::spawn(conn.driver.map_err(|_| ())); + tokio::spawn(conn.incoming.map_err(|_| ()).for_each(echo)); + }) + .map_err(|_| ()), + ); - info!(log, "connecting from {} to {}", client_addr, server_addr); - runtime - .block_on( - client - .connect(&server_addr, "localhost") - .unwrap() - .map_err(|e| panic!("connection failed: {}", e)) - .and_then(move |conn| { - tokio::spawn(conn.driver.map_err(|e| eprintln!("connection lost: {}", e))); - let conn = conn.connection; - let stream = conn.open_bi(); - stream - .map_err(|_| ()) - .and_then(move |stream| { - tokio::io::write_all(stream, b"foo".to_vec()) - .map_err(|e| panic!("write: {}", e)) - }) - .and_then(|(stream, _)| { - tokio::io::shutdown(stream).map_err(|e| panic!("finish: {}", e)) - }) - .and_then(move |stream| { - read_to_end(stream, usize::max_value()) - .map_err(|e| panic!("read: {}", e)) - }) - .and_then(move |(_, data)| { - assert_eq!(&data[..], b"foo"); - conn.close(0, b"done").map_err(|_| unreachable!()) - }) - }), - ) - .unwrap(); + info!(log, "connecting from {} to {}", client_addr, server_addr); + runtime + .block_on( + client + .connect(&server_addr, "localhost") + .unwrap() + .map_err(|e| panic!("connection failed: {}", e)) + .and_then(move |conn| { + tokio::spawn(conn.driver.map_err(|e| eprintln!("connection lost: {}", e))); + let conn = conn.connection; + let stream = conn.open_bi(); + stream + .map_err(|_| ()) + .and_then(move |stream| { + tokio::io::write_all(stream, b"foo".to_vec()) + .map_err(|e| panic!("write: {}", e)) + }) + .and_then(|(stream, _)| { + tokio::io::shutdown(stream).map_err(|e| panic!("finish: {}", e)) + }) + .and_then(move |stream| { + read_to_end(stream, usize::max_value()) + .map_err(|e| panic!("read: {}", e)) + }) + .and_then(move |(_, data)| { + assert_eq!(&data[..], b"foo"); + conn.close(0, b"done").map_err(|_| unreachable!()) + }) + }), + ) + .unwrap(); + } + runtime.shutdown_on_idle().wait().unwrap(); } fn echo(stream: NewStream) -> Box> {