Skip to content

Commit

Permalink
feat(server): Allow keep alive to be turned off for a connection
Browse files Browse the repository at this point in the history
  • Loading branch information
sfackler committed Dec 1, 2017
1 parent cecef9d commit 187f3a3
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 3 deletions.
14 changes: 13 additions & 1 deletion src/proto/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,14 @@ where I: AsyncRead + AsyncWrite,
pub fn close_write(&mut self) {
self.state.close_write();
}

pub fn disable_keep_alive(&mut self) {
if self.state.is_idle() {
self.state.close_read();
} else {
self.state.disable_keep_alive();
}
}
}

// ==== tokio_proto impl ====
Expand Down Expand Up @@ -700,6 +708,10 @@ impl<B, K: KeepAlive> State<B, K> {
}
}

fn disable_keep_alive(&mut self) {
self.keep_alive.disable()
}

fn busy(&mut self) {
if let KA::Disabled = self.keep_alive.status() {
return;
Expand Down Expand Up @@ -869,7 +881,7 @@ mod tests {
other => panic!("unexpected frame: {:?}", other)
}

// client
// client
let io = AsyncIo::new_buf(vec![], 1);
let mut conn = Conn::<_, proto::Chunk, ClientTransaction>::new(io, Default::default());
conn.state.busy();
Expand Down
4 changes: 4 additions & 0 deletions src/proto/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ where
}
}

pub fn disable_keep_alive(&mut self) {
self.conn.disable_keep_alive()
}

fn poll_read(&mut self) -> Poll<(), ::Error> {
loop {
if self.conn.can_read_head() {
Expand Down
12 changes: 12 additions & 0 deletions src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,18 @@ where
}
}

impl<I, B, S> Connection<I, S>
where S: Service<Request = Request, Response = Response<B>, Error = ::Error> + 'static,
I: AsyncRead + AsyncWrite + 'static,
B: Stream<Error=::Error> + 'static,
B::Item: AsRef<[u8]>,
{
/// Disables keep-alive for this connection.
pub fn disable_keep_alive(&mut self) {
self.conn.disable_keep_alive()
}
}

mod unnameable {
// This type is specifically not exported outside the crate,
// so no one can actually name the type. With no methods, we make no
Expand Down
109 changes: 107 additions & 2 deletions tests/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ extern crate pretty_env_logger;
extern crate tokio_core;

use futures::{Future, Stream};
use futures::future::{self, FutureResult};
use futures::future::{self, FutureResult, Either};
use futures::sync::oneshot;

use tokio_core::net::TcpListener;
Expand Down Expand Up @@ -551,6 +551,106 @@ fn pipeline_enabled() {
assert_eq!(n, 0);
}

#[test]
fn disable_keep_alive_mid_request() {
let mut core = Core::new().unwrap();
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap();
let addr = listener.local_addr().unwrap();

let (tx1, rx1) = oneshot::channel();
let (tx2, rx2) = oneshot::channel();

let child = thread::spawn(move || {
let mut req = connect(&addr);
req.write_all(b"GET / HTTP/1.1\r\n").unwrap();
tx1.send(()).unwrap();
rx2.wait().unwrap();
req.write_all(b"Host: localhost\r\n\r\n").unwrap();
let mut buf = vec![];
req.read_to_end(&mut buf).unwrap();
});

let fut = listener.incoming()
.into_future()
.map_err(|_| unreachable!())
.and_then(|(item, _incoming)| {
let (socket, _) = item.unwrap();
Http::<hyper::Chunk>::new().serve_connection(socket, HelloWorld)
.select2(rx1)
.then(|r| {
match r {
Ok(Either::A(_)) => panic!("expected rx first"),
Ok(Either::B(((), mut conn))) => {
conn.disable_keep_alive();
tx2.send(()).unwrap();
conn
}
Err(Either::A((e, _))) => panic!("unexpected error {}", e),
Err(Either::B((e, _))) => panic!("unexpected error {}", e),
}
})
});

core.run(fut).unwrap();
child.join().unwrap();
}

#[test]
fn disable_keep_alive_post_request() {
let mut core = Core::new().unwrap();
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap();
let addr = listener.local_addr().unwrap();

let (tx1, rx1) = oneshot::channel();

let child = thread::spawn(move || {
let mut req = connect(&addr);
req.write_all(b"\
GET / HTTP/1.1\r\n\
Host: localhost\r\n\
\r\n\
").unwrap();

let mut buf = [0; 1024 * 8];
loop {
let n = req.read(&mut buf).expect("reading 1");
if n < buf.len() {
if &buf[n - HELLO.len()..n] == HELLO.as_bytes() {
break;
}
}
}

tx1.send(()).unwrap();

let nread = req.read(&mut buf).unwrap();
assert_eq!(nread, 0);
});

let fut = listener.incoming()
.into_future()
.map_err(|_| unreachable!())
.and_then(|(item, _incoming)| {
let (socket, _) = item.unwrap();
Http::<hyper::Chunk>::new().serve_connection(socket, HelloWorld)
.select2(rx1)
.then(|r| {
match r {
Ok(Either::A(_)) => panic!("expected rx first"),
Ok(Either::B(((), mut conn))) => {
conn.disable_keep_alive();
conn
}
Err(Either::A((e, _))) => panic!("unexpected error {}", e),
Err(Either::B((e, _))) => panic!("unexpected error {}", e),
}
})
});

core.run(fut).unwrap();
child.join().unwrap();
}

#[test]
fn no_proto_empty_parse_eof_does_not_return_error() {
let mut core = Core::new().unwrap();
Expand Down Expand Up @@ -719,6 +819,8 @@ impl Service for TestService {

}

const HELLO: &'static str = "hello";

struct HelloWorld;

impl Service for HelloWorld {
Expand All @@ -728,7 +830,10 @@ impl Service for HelloWorld {
type Future = FutureResult<Self::Response, Self::Error>;

fn call(&self, _req: Request) -> Self::Future {
future::ok(Response::new())
let mut response = Response::new();
response.headers_mut().set(hyper::header::ContentLength(HELLO.len() as u64));
response.set_body(HELLO);
future::ok(response)
}
}

Expand Down

0 comments on commit 187f3a3

Please sign in to comment.