diff --git a/linkerd/proxy/http/src/detect.rs b/linkerd/proxy/http/src/detect.rs index ed064c1556..b13a66ff95 100644 --- a/linkerd/proxy/http/src/detect.rs +++ b/linkerd/proxy/http/src/detect.rs @@ -5,8 +5,20 @@ use linkerd_error::Error; use linkerd_io::{self as io, AsyncReadExt}; use tracing::{debug, trace}; -const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; - +const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0"; +const SMALLEST_POSSIBLE_HTTP1_REQ: &str = "GET / HTTP/1.1"; + +/// Attempts to detect the HTTP version of a stream. +/// +/// This module biases towards availability instead of correctness. I.e. instead +/// of buffering until we can be sure that we're dealing with an HTTP stream, we +/// instead perform only a single read and use that data to inform protocol +/// hinting. If a single read doesn't provide enough data to make a decision, we +/// treat the protocol as unknown. +/// +/// This allows us to interoperate with protocols that send very small initial +/// messages. In rare situations, we may fail to properly detect that a stream is +/// HTTP. #[derive(Clone, Debug, Default)] pub struct DetectHttp(()); @@ -19,91 +31,47 @@ impl Detect for DetectHttp { io: &mut I, buf: &mut BytesMut, ) -> Result<Option<Version>, Error> { - let mut scan_idx = 0; - let mut maybe_h1 = true; - let mut maybe_h2 = true; - - loop { - // Read data from the socket or timeout detection. - trace!( - capacity = buf.capacity(), - scan_idx, - maybe_h1, - maybe_h2, - "Reading" - ); - let sz = io.read_buf(buf).await?; - if sz == 0 { - // No data was read because the socket closed or the - // buffer capacity was exhausted. - debug!(read = buf.len(), "Could not detect protocol"); - return Ok(None); - } - - // HTTP/2 checking is faster because it's a simple string match. If - // we have enough data, check it first. In almost all cases, the - // whole preface should be available from the first read. 
- if maybe_h2 { - if buf.len() < H2_PREFACE.len() { - // Check the prefix we have already read to see if it looks likely to be HTTP/2. - maybe_h2 = buf[..] == H2_PREFACE[..buf.len()]; - } else { - trace!("Checking H2 preface"); - if &buf[..H2_PREFACE.len()] == H2_PREFACE { - trace!("Matched HTTP/2 prefix"); - return Ok(Some(Version::H2)); - } - - // Not a match. Don't check for an HTTP/2 preface again. - maybe_h2 = false; - } - } + trace!(capacity = buf.capacity(), "Reading"); + let sz = io.read_buf(buf).await?; + trace!(sz, "Read"); + if sz == 0 { + // No data was read because the socket closed or the + // buffer capacity was exhausted. + debug!(read = buf.len(), "Could not detect protocol"); + return Ok(None); + } - if maybe_h1 { - // Scan up to the first line ending to determine whether the - // request is HTTP/1.1. HTTP expects \r\n, so we just look for - // any \n to indicate that we can make a determination. - if buf[scan_idx..].contains(&b'\n') { - trace!("Found newline"); - // If the first line looks like an HTTP/2 first line, - // then we almost definitely got a fragmented first - // read. Only try HTTP/1 parsing if it doesn't look like - // HTTP/2. - if !maybe_h2 { - trace!("Parsing HTTP/1 message"); - // If we get to reading headers (and fail), the - // first line looked like an HTTP/1 request; so - // handle the stream as HTTP/1. - if let Ok(_) | Err(httparse::Error::TooManyHeaders) = - httparse::Request::new(&mut [httparse::EMPTY_HEADER; 0]).parse(&buf[..]) - { - trace!("Matched HTTP/1"); - return Ok(Some(Version::Http1)); - } - } - - // We found the EOL and it wasn't an HTTP/1.x request; - // stop scanning and don't scan again. - maybe_h1 = false; - } - - // Advance our scan index to the end of buffer so the next - // iteration starts scanning where we left off. - scan_idx = buf.len() - 1; + // HTTP/2 checking is faster because it's a simple string match. If we + // have enough data, check it first. 
We don't bother matching on the + // entire H2 preface because the first part is enough to get a clear + // signal. + if buf.len() >= H2_PREFACE.len() { + trace!("Checking H2 preface"); + if &buf[..H2_PREFACE.len()] == H2_PREFACE { + trace!("Matched HTTP/2 prefix"); + return Ok(Some(Version::H2)); } + } - if !maybe_h1 && !maybe_h2 { - trace!("Not HTTP"); - return Ok(None); + // Otherwise, we try to parse the data as an HTTP/1 message. + if buf.len() >= SMALLEST_POSSIBLE_HTTP1_REQ.len() { + trace!("Parsing HTTP/1 message"); + if let Ok(_) | Err(httparse::Error::TooManyHeaders) = + httparse::Request::new(&mut [httparse::EMPTY_HEADER; 0]).parse(&buf[..]) + { + trace!("Matched HTTP/1"); + return Ok(Some(Version::Http1)); } } + + trace!("Not HTTP"); + Ok(None) } } #[cfg(test)] mod tests { use super::*; - use bytes::BufMut; use tokio_test::io; const HTTP11_LINE: &[u8] = b"GET / HTTP/1.1\r\n"; @@ -111,70 +79,66 @@ mod tests { const GARBAGE: &[u8] = b"garbage garbage garbage garbage garbage garbage garbage garbage garbage garbage garbage garbage garbage garbage garbage garbage garbage"; - #[tokio::test(flavor = "current_thread")] + #[tokio::test] async fn h2() { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); - for i in 1..H2_PREFACE.len() { - let mut buf = BytesMut::with_capacity(H2_PREFACE.len()); - buf.put(H2_AND_GARBAGE); - debug!(read0 = ?std::str::from_utf8(&H2_AND_GARBAGE[..i]).unwrap()); - debug!(read1 = ?std::str::from_utf8(&H2_AND_GARBAGE[i..]).unwrap()); + for read in &[H2_PREFACE, H2_AND_GARBAGE] { + debug!(read = ?std::str::from_utf8(&read).unwrap()); let mut buf = BytesMut::with_capacity(1024); - let mut io = io::Builder::new() - .read(&H2_AND_GARBAGE[..i]) - .read(&H2_AND_GARBAGE[i..]) - .build(); + let mut io = io::Builder::new().read(&read).build(); let kind = DetectHttp(()).detect(&mut io, &mut buf).await.unwrap(); assert_eq!(kind, Some(Version::H2)); } } - #[tokio::test(flavor = "current_thread")] + #[tokio::test] async fn http1() { let 
_ = tracing_subscriber::fmt().with_test_writer().try_init(); - for i in 1..HTTP11_LINE.len() { - debug!(read0 = ?std::str::from_utf8(&HTTP11_LINE[..i]).unwrap()); - debug!(read1 = ?std::str::from_utf8(&HTTP11_LINE[i..]).unwrap()); + // If we don't read enough to know + for i in 1..SMALLEST_POSSIBLE_HTTP1_REQ.len() { + debug!(read = ?std::str::from_utf8(&HTTP11_LINE[..i]).unwrap()); let mut buf = BytesMut::with_capacity(1024); - let mut io = io::Builder::new() - .read(&HTTP11_LINE[..i]) - .read(&HTTP11_LINE[i..]) - .build(); + let mut io = io::Builder::new().read(&HTTP11_LINE[..i]).build(); let kind = DetectHttp(()).detect(&mut io, &mut buf).await.unwrap(); - assert_eq!(kind, Some(Version::Http1)); + assert_eq!(kind, None); } - const REQ: &[u8] = b"GET /foo/bar/bar/blah HTTP/1.1\r\nHost: foob.example.com\r\n\r\n"; + debug!(read = ?std::str::from_utf8(&HTTP11_LINE).unwrap()); let mut buf = BytesMut::with_capacity(1024); - let mut io = io::Builder::new().read(&REQ).build(); + let mut io = io::Builder::new().read(&HTTP11_LINE).build(); let kind = DetectHttp(()).detect(&mut io, &mut buf).await.unwrap(); assert_eq!(kind, Some(Version::Http1)); - assert_eq!(&buf[..], REQ); + + const REQ: &[u8] = b"GET /foo/bar/bar/blah HTTP/1.1\r\nHost: foob.example.com\r\n\r\n"; + for i in SMALLEST_POSSIBLE_HTTP1_REQ.len()..REQ.len() { + debug!(read = ?std::str::from_utf8(&REQ[..i]).unwrap()); + let mut buf = BytesMut::with_capacity(1024); + let mut io = io::Builder::new().read(&REQ[..i]).build(); + let kind = DetectHttp(()).detect(&mut io, &mut buf).await.unwrap(); + assert_eq!(kind, Some(Version::Http1)); + assert_eq!(buf[..], REQ[..i]); + } // Starts with a P, like the h2 preface. 
const POST: &[u8] = b"POST /foo HTTP/1.1\r\n"; - for i in 1..POST.len() { + for i in SMALLEST_POSSIBLE_HTTP1_REQ.len()..POST.len() { let mut buf = BytesMut::with_capacity(1024); - let mut io = io::Builder::new().read(&POST[..i]).read(&POST[i..]).build(); - println!("buf0 = {:?}", &POST[..i]); - println!("buf1 = {:?}", &POST[i..]); + let mut io = io::Builder::new().read(&POST[..i]).build(); + debug!(read = ?std::str::from_utf8(&POST[..i]).unwrap()); let kind = DetectHttp(()).detect(&mut io, &mut buf).await.unwrap(); assert_eq!(kind, Some(Version::Http1)); - assert_eq!(&buf[..], POST); + assert_eq!(buf[..], POST[..i]); } } - #[tokio::test(flavor = "current_thread")] + #[tokio::test] async fn unknown() { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); let mut buf = BytesMut::with_capacity(1024); - let mut io = io::Builder::new() - .read(b"foo.bar.blah\r") - .read(b"\nbobo") - .build(); + let mut io = io::Builder::new().read(b"foo.bar.blah\r\nbobo").build(); let kind = DetectHttp(()).detect(&mut io, &mut buf).await.unwrap(); assert_eq!(kind, None); assert_eq!(&buf[..], b"foo.bar.blah\r\nbobo"); @@ -184,11 +148,5 @@ mod tests { let kind = DetectHttp(()).detect(&mut io, &mut buf).await.unwrap(); assert_eq!(kind, None); assert_eq!(&buf[..], GARBAGE); - - let mut buf = BytesMut::with_capacity(1024); - let mut io = io::Builder::new().read(&HTTP11_LINE[..14]).build(); - let kind = DetectHttp(()).detect(&mut io, &mut buf).await.unwrap(); - assert_eq!(kind, None); - assert_eq!(&buf[..14], &HTTP11_LINE[..14]); } }