Skip to content

Commit

Permalink
Shut down endpoint driver when all references to it are dead
Browse files Browse the repository at this point in the history
  • Loading branch information
Ralith committed Mar 30, 2019
1 parent ba71632 commit a260e4f
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 86 deletions.
8 changes: 4 additions & 4 deletions quinn/src/builders.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::borrow::Cow;
use std::io;
use std::net::ToSocketAddrs;
use std::str;
use std::sync::{Arc, Mutex};
use std::sync::Arc;

use err_derive::Error;
use quinn_proto as quinn;
Expand All @@ -11,7 +11,7 @@ use slog::Logger;

use quinn_proto::{EndpointConfig, ServerConfig, TransportConfig};

use crate::endpoint::{Driver, Endpoint, EndpointInner, Incoming};
use crate::endpoint::{Driver, Endpoint, EndpointRef, Incoming};
use crate::tls::{Certificate, CertificateChain, PrivateKey};
use crate::udp::UdpSocket;

Expand Down Expand Up @@ -55,7 +55,7 @@ impl<'a> EndpointBuilder<'a> {
};
let addr = socket.local_addr().map_err(EndpointError::Socket)?;
let socket = UdpSocket::from_std(socket, &reactor).map_err(EndpointError::Socket)?;
let rc = Arc::new(Mutex::new(EndpointInner::new(
let rc = EndpointRef::new(
self.logger.clone(),
socket,
quinn::Endpoint::new(
Expand All @@ -64,7 +64,7 @@ impl<'a> EndpointBuilder<'a> {
self.server_config.map(Arc::new),
)?,
addr.is_ipv6(),
)));
);
Ok((
Endpoint {
inner: rc.clone(),
Expand Down
89 changes: 66 additions & 23 deletions quinn/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use crate::{ConnectionEvent, EndpointEvent, IO_LOOP_BOUND};
/// May be cloned to obtain another handle to the same endpoint.
#[derive(Clone)]
pub struct Endpoint {
pub(crate) inner: Arc<Mutex<EndpointInner>>,
pub(crate) inner: EndpointRef,
pub(crate) default_client_config: ClientConfig,
}

Expand Down Expand Up @@ -116,7 +116,7 @@ impl Endpoint {
///
/// `Driver` instances do not terminate (always yields `NotReady`) except in case of an error.
#[must_use = "endpoint drivers must be spawned for I/O to occur"]
pub struct Driver(pub(crate) Arc<Mutex<EndpointInner>>);
pub struct Driver(pub(crate) EndpointRef);

impl Future for Driver {
type Item = ();
Expand All @@ -137,7 +137,13 @@ impl Future for Driver {
break;
}
}
Ok(Async::NotReady)
Ok(
if endpoint.unreferenced && endpoint.connections.is_empty() {
Async::Ready(())
} else {
Async::NotReady
},
)
}
}

Expand Down Expand Up @@ -165,27 +171,14 @@ pub(crate) struct EndpointInner {
// Stored to give out clones to new ConnectionInners
sender: mpsc::UnboundedSender<(ConnectionHandle, EndpointEvent)>,
events: mpsc::UnboundedReceiver<(ConnectionHandle, EndpointEvent)>,
/// Whether only one reference to this endpoint remains
///
/// We presume the final reference to always be the driver, because otherwise nothing we do will
/// have any effect regardless.
unreferenced: bool,
}

impl EndpointInner {
pub(crate) fn new(log: Logger, socket: UdpSocket, inner: quinn::Endpoint, ipv6: bool) -> Self {
let (sender, events) = mpsc::unbounded();
Self {
log,
socket,
inner,
ipv6,
sender,
events,
outgoing: VecDeque::new(),
incoming: VecDeque::new(),
incoming_reader: None,
incoming_live: true,
driver: None,
connections: FnvHashMap::default(),
}
}

fn drive_recv(&mut self, now: Instant) -> Result<bool, io::Error> {
let mut buf = [0; 64 * 1024];
let mut recvd = 0;
Expand Down Expand Up @@ -317,10 +310,10 @@ fn ensure_ipv6(x: SocketAddr) -> SocketAddrV6 {
}

/// Stream of incoming connections.
pub struct Incoming(Arc<Mutex<EndpointInner>>);
pub struct Incoming(EndpointRef);

impl Incoming {
pub(crate) fn new(inner: Arc<Mutex<EndpointInner>>) -> Self {
pub(crate) fn new(inner: EndpointRef) -> Self {
Self(inner)
}
}
Expand Down Expand Up @@ -350,3 +343,53 @@ impl Drop for Incoming {
}
}
}

pub(crate) struct EndpointRef(Arc<Mutex<EndpointInner>>);

impl EndpointRef {
pub(crate) fn new(log: Logger, socket: UdpSocket, inner: quinn::Endpoint, ipv6: bool) -> Self {
let (sender, events) = mpsc::unbounded();
Self(Arc::new(Mutex::new(EndpointInner {
log,
socket,
inner,
ipv6,
sender,
events,
outgoing: VecDeque::new(),
incoming: VecDeque::new(),
incoming_live: true,
incoming_reader: None,
driver: None,
connections: FnvHashMap::default(),
unreferenced: false,
})))
}
}

impl Clone for EndpointRef {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}

impl Drop for EndpointRef {
fn drop(&mut self) {
if Arc::strong_count(&self.0) == 2 {
// If the driver is about to be on its own, arrange for it to shut down once the last
// connection is gone.
let endpoint = &mut *self.0.lock().unwrap();
endpoint.unreferenced = true;
if let Some(task) = endpoint.driver.take() {
task.notify();
}
}
}
}

impl std::ops::Deref for EndpointRef {
type Target = Mutex<EndpointInner>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
126 changes: 67 additions & 59 deletions quinn/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,69 +48,77 @@ fn echo_dualstack() {
}

fn run_echo(client_addr: SocketAddr, server_addr: SocketAddr) {
let log = logger();
let mut server_config = ServerConfigBuilder::default();
let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]);
let key = crate::PrivateKey::from_der(&cert.serialize_private_key_der()).unwrap();
let cert = crate::Certificate::from_der(&cert.serialize_der()).unwrap();
let cert_chain = crate::CertificateChain::from_certs(vec![cert.clone()]);
server_config.certificate(cert_chain, key).unwrap();
let mut runtime = tokio::runtime::Runtime::new().unwrap();
{
let log = logger();
let mut server_config = ServerConfigBuilder::default();
let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]);
let key = crate::PrivateKey::from_der(&cert.serialize_private_key_der()).unwrap();
let cert = crate::Certificate::from_der(&cert.serialize_der()).unwrap();
let cert_chain = crate::CertificateChain::from_certs(vec![cert.clone()]);
server_config.certificate(cert_chain, key).unwrap();

let mut server = Endpoint::new();
server.logger(log.clone());
server.listen(server_config.build());
let server_sock = UdpSocket::bind(server_addr).unwrap();
let server_addr = server_sock.local_addr().unwrap();
let (_, server_driver, server_incoming) = server.from_socket(server_sock).unwrap();
let mut server = Endpoint::new();
server.logger(log.clone());
server.listen(server_config.build());
let server_sock = UdpSocket::bind(server_addr).unwrap();
let server_addr = server_sock.local_addr().unwrap();
let (_, server_driver, server_incoming) = server.from_socket(server_sock).unwrap();

let mut client_config = ClientConfigBuilder::default();
client_config.add_certificate_authority(cert).unwrap();
client_config.enable_keylog();
let mut client = Endpoint::new();
client.logger(log.clone());
client.default_client_config(client_config.build());
let (client, client_driver, _) = client.bind(client_addr).unwrap();
let mut client_config = ClientConfigBuilder::default();
client_config.add_certificate_authority(cert).unwrap();
client_config.enable_keylog();
let mut client = Endpoint::new();
client.logger(log.clone());
client.default_client_config(client_config.build());
let (client, client_driver, _) = client.bind(client_addr).unwrap();

let mut runtime = tokio::runtime::Runtime::new().unwrap();
runtime.spawn(server_driver.map_err(|e| panic!("server driver failed: {}", e)));
runtime.spawn(client_driver.map_err(|e| panic!("client driver failed: {}", e)));
runtime.spawn(server_incoming.for_each(move |conn| {
tokio::spawn(conn.driver.map_err(|_| ()));
tokio::spawn(conn.incoming.map_err(|_| ()).for_each(echo));
Ok(())
}));
runtime.spawn(server_driver.map_err(|e| panic!("server driver failed: {}", e)));
runtime.spawn(client_driver.map_err(|e| panic!("client driver failed: {}", e)));
runtime.spawn(
server_incoming
.into_future()
.map(move |(conn, _)| {
let conn = conn.unwrap();
tokio::spawn(conn.driver.map_err(|_| ()));
tokio::spawn(conn.incoming.map_err(|_| ()).for_each(echo));
})
.map_err(|_| ()),
);

info!(log, "connecting from {} to {}", client_addr, server_addr);
runtime
.block_on(
client
.connect(&server_addr, "localhost")
.unwrap()
.map_err(|e| panic!("connection failed: {}", e))
.and_then(move |conn| {
tokio::spawn(conn.driver.map_err(|e| eprintln!("connection lost: {}", e)));
let conn = conn.connection;
let stream = conn.open_bi();
stream
.map_err(|_| ())
.and_then(move |stream| {
tokio::io::write_all(stream, b"foo".to_vec())
.map_err(|e| panic!("write: {}", e))
})
.and_then(|(stream, _)| {
tokio::io::shutdown(stream).map_err(|e| panic!("finish: {}", e))
})
.and_then(move |stream| {
read_to_end(stream, usize::max_value())
.map_err(|e| panic!("read: {}", e))
})
.and_then(move |(_, data)| {
assert_eq!(&data[..], b"foo");
conn.close(0, b"done").map_err(|_| unreachable!())
})
}),
)
.unwrap();
info!(log, "connecting from {} to {}", client_addr, server_addr);
runtime
.block_on(
client
.connect(&server_addr, "localhost")
.unwrap()
.map_err(|e| panic!("connection failed: {}", e))
.and_then(move |conn| {
tokio::spawn(conn.driver.map_err(|e| eprintln!("connection lost: {}", e)));
let conn = conn.connection;
let stream = conn.open_bi();
stream
.map_err(|_| ())
.and_then(move |stream| {
tokio::io::write_all(stream, b"foo".to_vec())
.map_err(|e| panic!("write: {}", e))
})
.and_then(|(stream, _)| {
tokio::io::shutdown(stream).map_err(|e| panic!("finish: {}", e))
})
.and_then(move |stream| {
read_to_end(stream, usize::max_value())
.map_err(|e| panic!("read: {}", e))
})
.and_then(move |(_, data)| {
assert_eq!(&data[..], b"foo");
conn.close(0, b"done").map_err(|_| unreachable!())
})
}),
)
.unwrap();
}
runtime.shutdown_on_idle().wait().unwrap();
}

fn echo(stream: NewStream) -> Box<impl Future<Item = (), Error = ()>> {
Expand Down

0 comments on commit a260e4f

Please sign in to comment.