Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(server): Allow keep alive to be turned off for a connection #1390

Merged
merged 1 commit into from
Dec 4, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/proto/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,14 @@ where I: AsyncRead + AsyncWrite,
pub fn close_write(&mut self) {
self.state.close_write();
}

pub fn disable_keep_alive(&mut self) {
if self.state.is_idle() {
self.state.close_read();
} else {
self.state.disable_keep_alive();
}
}
}

// ==== tokio_proto impl ====
Expand Down Expand Up @@ -700,6 +708,10 @@ impl<B, K: KeepAlive> State<B, K> {
}
}

fn disable_keep_alive(&mut self) {
self.keep_alive.disable()
}

fn busy(&mut self) {
if let KA::Disabled = self.keep_alive.status() {
return;
Expand Down Expand Up @@ -869,7 +881,7 @@ mod tests {
other => panic!("unexpected frame: {:?}", other)
}

// client
// client
let io = AsyncIo::new_buf(vec![], 1);
let mut conn = Conn::<_, proto::Chunk, ClientTransaction>::new(io, Default::default());
conn.state.busy();
Expand Down
4 changes: 4 additions & 0 deletions src/proto/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ where
}
}

pub fn disable_keep_alive(&mut self) {
self.conn.disable_keep_alive()
}

fn poll_read(&mut self) -> Poll<(), ::Error> {
loop {
if self.conn.can_read_head() {
Expand Down
12 changes: 12 additions & 0 deletions src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,18 @@ where
}
}

impl<I, B, S> Connection<I, S>
where S: Service<Request = Request, Response = Response<B>, Error = ::Error> + 'static,
I: AsyncRead + AsyncWrite + 'static,
B: Stream<Error=::Error> + 'static,
B::Item: AsRef<[u8]>,
{
/// Disables keep-alive for this connection.
pub fn disable_keep_alive(&mut self) {
self.conn.disable_keep_alive()
}
}

mod unnameable {
// This type is specifically not exported outside the crate,
// so no one can actually name the type. With no methods, we make no
Expand Down
109 changes: 107 additions & 2 deletions tests/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ extern crate pretty_env_logger;
extern crate tokio_core;

use futures::{Future, Stream};
use futures::future::{self, FutureResult};
use futures::future::{self, FutureResult, Either};
use futures::sync::oneshot;

use tokio_core::net::TcpListener;
Expand Down Expand Up @@ -551,6 +551,106 @@ fn pipeline_enabled() {
assert_eq!(n, 0);
}

#[test]
fn disable_keep_alive_mid_request() {
let mut core = Core::new().unwrap();
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap();
let addr = listener.local_addr().unwrap();

let (tx1, rx1) = oneshot::channel();
let (tx2, rx2) = oneshot::channel();

let child = thread::spawn(move || {
let mut req = connect(&addr);
req.write_all(b"GET / HTTP/1.1\r\n").unwrap();
tx1.send(()).unwrap();
rx2.wait().unwrap();
req.write_all(b"Host: localhost\r\n\r\n").unwrap();
let mut buf = vec![];
req.read_to_end(&mut buf).unwrap();
});

let fut = listener.incoming()
.into_future()
.map_err(|_| unreachable!())
.and_then(|(item, _incoming)| {
let (socket, _) = item.unwrap();
Http::<hyper::Chunk>::new().serve_connection(socket, HelloWorld)
.select2(rx1)
.then(|r| {
match r {
Ok(Either::A(_)) => panic!("expected rx first"),
Ok(Either::B(((), mut conn))) => {
conn.disable_keep_alive();
tx2.send(()).unwrap();
conn
}
Err(Either::A((e, _))) => panic!("unexpected error {}", e),
Err(Either::B((e, _))) => panic!("unexpected error {}", e),
}
})
});

core.run(fut).unwrap();
child.join().unwrap();
}

#[test]
fn disable_keep_alive_post_request() {
let mut core = Core::new().unwrap();
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap(), &core.handle()).unwrap();
let addr = listener.local_addr().unwrap();

let (tx1, rx1) = oneshot::channel();

let child = thread::spawn(move || {
let mut req = connect(&addr);
req.write_all(b"\
GET / HTTP/1.1\r\n\
Host: localhost\r\n\
\r\n\
").unwrap();

let mut buf = [0; 1024 * 8];
loop {
let n = req.read(&mut buf).expect("reading 1");
if n < buf.len() {
if &buf[n - HELLO.len()..n] == HELLO.as_bytes() {
break;
}
}
}

tx1.send(()).unwrap();

let nread = req.read(&mut buf).unwrap();
assert_eq!(nread, 0);
});

let fut = listener.incoming()
.into_future()
.map_err(|_| unreachable!())
.and_then(|(item, _incoming)| {
let (socket, _) = item.unwrap();
Http::<hyper::Chunk>::new().serve_connection(socket, HelloWorld)
.select2(rx1)
.then(|r| {
match r {
Ok(Either::A(_)) => panic!("expected rx first"),
Ok(Either::B(((), mut conn))) => {
conn.disable_keep_alive();
conn
}
Err(Either::A((e, _))) => panic!("unexpected error {}", e),
Err(Either::B((e, _))) => panic!("unexpected error {}", e),
}
})
});

core.run(fut).unwrap();
child.join().unwrap();
}

#[test]
fn no_proto_empty_parse_eof_does_not_return_error() {
let mut core = Core::new().unwrap();
Expand Down Expand Up @@ -719,6 +819,8 @@ impl Service for TestService {

}

const HELLO: &'static str = "hello";

struct HelloWorld;

impl Service for HelloWorld {
Expand All @@ -728,7 +830,10 @@ impl Service for HelloWorld {
type Future = FutureResult<Self::Response, Self::Error>;

fn call(&self, _req: Request) -> Self::Future {
future::ok(Response::new())
let mut response = Response::new();
response.headers_mut().set(hyper::header::ContentLength(HELLO.len() as u64));
response.set_body(HELLO);
future::ok(response)
}
}

Expand Down