diff --git a/Cargo.lock b/Cargo.lock index 893876725c..d2bf5d32ab 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1360,6 +1360,7 @@ dependencies = [ name = "linkerd2-proxy-http" version = "0.1.0" dependencies = [ + "async-trait", "bytes 0.6.0", "futures 0.3.5", "h2 0.3.0", @@ -1375,14 +1376,17 @@ dependencies = [ "linkerd2-http-box", "linkerd2-identity", "linkerd2-io", + "linkerd2-proxy-transport", "linkerd2-stack", "linkerd2-timeout", "pin-project 1.0.2", "rand 0.7.2", "tokio 0.3.5", + "tokio-test 0.3.0", "tower", "tracing", "tracing-futures", + "tracing-subscriber", "try-lock", ] @@ -1459,6 +1463,7 @@ name = "linkerd2-proxy-transport" version = "0.1.0" dependencies = [ "async-stream 0.2.1", + "async-trait", "bytes 0.6.0", "futures 0.3.5", "indexmap", diff --git a/linkerd/app/core/src/metrics.rs b/linkerd/app/core/src/metrics.rs index 39d57a8d44..b8f3605745 100644 --- a/linkerd/app/core/src/metrics.rs +++ b/linkerd/app/core/src/metrics.rs @@ -56,6 +56,7 @@ pub struct EndpointLabels { #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct StackLabels { pub direction: Direction, + pub protocol: &'static str, pub name: &'static str, } @@ -297,17 +298,19 @@ impl FmtLabels for TlsId { // === impl StackLabels === impl StackLabels { - pub fn inbound(name: &'static str) -> Self { + pub fn inbound(protocol: &'static str, name: &'static str) -> Self { Self { - direction: Direction::In, name, + protocol, + direction: Direction::In, } } - pub fn outbound(name: &'static str) -> Self { + pub fn outbound(protocol: &'static str, name: &'static str) -> Self { Self { - direction: Direction::Out, name, + protocol, + direction: Direction::Out, } } } @@ -315,6 +318,6 @@ impl StackLabels { impl FmtLabels for StackLabels { fn fmt_labels(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.direction.fmt_labels(f)?; - write!(f, ",name=\"{}\"", self.name) + write!(f, ",protocol=\"{}\",name=\"{}\"", self.protocol, self.name) } } diff --git a/linkerd/app/inbound/src/lib.rs b/linkerd/app/inbound/src/lib.rs index 7a397bdfdc..b29524358a 100644 --- a/linkerd/app/inbound/src/lib.rs +++ b/linkerd/app/inbound/src/lib.rs @@ -280,7 +280,7 @@ impl Config { svc::layers() .push_failfast(dispatch_timeout) .push_spawn_buffer(buffer_capacity) - .push(metrics.stack.layer(stack_labels("logical"))), + .push(metrics.stack.layer(stack_labels("http", "logical"))), ) .push_cache(cache_max_idle_age) .push_on_response( @@ -317,10 +317,10 @@ impl Config { where I: io::AsyncRead + io::AsyncWrite + io::PeerAddr + Unpin + Send + 'static, F: svc::NewService + Unpin + Clone + Send + 'static, - A: tower::Service, Response = ()> + Clone + Send + 'static, + A: tower::Service, Response = ()> + Clone + Send + Sync + 'static, A::Error: Into, A::Future: Send, - H: svc::NewService + Unpin + Clone + Send + 'static, + H: svc::NewService + Unpin + Clone + Send + Sync + 'static, S: tower::Service< http::Request, Response = http::Response, @@ -335,7 +335,7 @@ impl Config { dispatch_timeout, max_in_flight_requests, detect_protocol_timeout, - buffer_capacity, + cache_max_idle_age, .. } = self.proxy.clone(); @@ -371,7 +371,7 @@ impl Config { .push(TraceContext::layer(span_sink.map(|span_sink| { SpanConverter::server(span_sink, trace_labels()) }))) - .push(metrics.stack.layer(stack_labels("source"))) + .push(metrics.stack.layer(stack_labels("http", "server"))) .box_http_request() .box_http_response(), ) @@ -389,10 +389,12 @@ impl Config { .into_inner(), drain.clone(), )) - .push_on_response(svc::layers().push_spawn_buffer(buffer_capacity).push( - transport::Prefix::layer( - http::Version::DETECT_BUFFER_CAPACITY, + .check_new_clone::<(Option, TcpAccept)>() + .push_cache(cache_max_idle_age) + .push(transport::NewDetectService::layer( + transport::detect::DetectTimeout::new( detect_protocol_timeout, + http::DetectHttp::default(), ), )) .into_inner() @@ -458,8 +460,8 @@ pub fn trace_labels() -> HashMap { l } -fn stack_labels(name: &'static str) -> metrics::StackLabels { - metrics::StackLabels::inbound(name) +fn stack_labels(proto: &'static str, name: &'static str) -> metrics::StackLabels { + metrics::StackLabels::inbound(proto, name) } // === impl SkipByPort === diff --git a/linkerd/app/integration/src/tests/identity.rs b/linkerd/app/integration/src/tests/identity.rs index bbcc72dd9b..d236ae5ac9 100644 --- a/linkerd/app/integration/src/tests/identity.rs +++ b/linkerd/app/integration/src/tests/identity.rs @@ -25,7 +25,7 @@ async fn nonblocking_identity_detection() { .await; let proxy = proxy::new().identity(id_svc); - let msg1 = "custom tcp hello"; + let msg1 = "custom tcp hello\n"; let msg2 = "custom tcp bye"; let srv = server::tcp() .accept(move |read| { diff --git a/linkerd/app/integration/src/tests/shutdown.rs b/linkerd/app/integration/src/tests/shutdown.rs index dfa4c3b0fa..ab52bb0b5d 100644 --- a/linkerd/app/integration/src/tests/shutdown.rs +++ b/linkerd/app/integration/src/tests/shutdown.rs @@ -100,7 +100,7 @@ async fn tcp_waits_for_proxies_to_close() { let _trace = trace_init(); let (shdn, rx) = shutdown_signal(); - let msg1 = "custom tcp hello"; + let msg1 = "custom tcp hello\n"; let msg2 = "custom tcp bye"; let srv = server::tcp() diff --git a/linkerd/app/integration/src/tests/telemetry.rs b/linkerd/app/integration/src/tests/telemetry.rs index b9901cee6b..ac08c8b27f 100644 --- a/linkerd/app/integration/src/tests/telemetry.rs +++ b/linkerd/app/integration/src/tests/telemetry.rs @@ -76,7 +76,7 @@ impl Fixture { } impl TcpFixture { - const HELLO_MSG: &'static str = "custom tcp hello"; + const HELLO_MSG: &'static str = "custom tcp hello\n"; const BYE_MSG: &'static str = "custom tcp bye"; async fn server() -> server::Listening { diff --git a/linkerd/app/integration/src/tests/transparency.rs b/linkerd/app/integration/src/tests/transparency.rs index de229d8499..24aee7b0ff 100644 --- a/linkerd/app/integration/src/tests/transparency.rs +++ b/linkerd/app/integration/src/tests/transparency.rs @@ -50,7 +50,7 @@ async fn inbound_http1() { async fn outbound_tcp() { let _trace = trace_init(); - let msg1 = "custom tcp hello"; + let msg1 = "custom tcp hello\n"; let msg2 = "custom tcp bye"; let srv = server::tcp() @@ -88,7 +88,7 @@ async fn outbound_tcp() { async fn outbound_tcp_external() { let _trace = trace_init(); - let msg1 = "custom tcp hello"; + let msg1 = "custom tcp hello\n"; let msg2 = "custom tcp bye"; let srv = server::tcp() @@ -127,7 +127,7 @@ async fn outbound_tcp_external() { async fn inbound_tcp() { let _trace = trace_init(); - let msg1 = "custom tcp hello"; + let msg1 = "custom tcp hello\n"; let msg2 = "custom tcp bye"; let srv = server::tcp() @@ -296,7 +296,7 @@ async fn tcp_server_first_tls() { async fn tcp_connections_close_if_client_closes() { let _trace = trace_init(); - let msg1 = "custom tcp hello"; + let msg1 = "custom tcp hello\n"; let msg2 = "custom tcp bye"; let (mut tx, mut rx) = mpsc::channel(1); diff --git a/linkerd/app/outbound/src/http/logical.rs b/linkerd/app/outbound/src/http/logical.rs index 23cf19d65d..f8053ec20a 100644 --- a/linkerd/app/outbound/src/http/logical.rs +++ b/linkerd/app/outbound/src/http/logical.rs @@ -55,7 +55,11 @@ where .push_on_response( svc::layers() .push(svc::layer::mk(svc::SpawnReady::new)) - .push(metrics.stack.layer(stack_labels("balance.endpoint"))) + .push( + metrics + .stack + .layer(stack_labels("http", "balance.endpoint")), + ) .box_http_request(), ) .check_new_service::>() @@ -71,7 +75,7 @@ where // If the balancer has been empty/unavailable for 10s, eagerly fail // requests. .push_failfast(dispatch_timeout) - .push(metrics.stack.layer(stack_labels("concrete"))), + .push(metrics.stack.layer(stack_labels("http", "concrete"))), ) .into_new_service() .check_new_service::>() diff --git a/linkerd/app/outbound/src/ingress.rs b/linkerd/app/outbound/src/ingress.rs index bfd89d047b..35805225d9 100644 --- a/linkerd/app/outbound/src/ingress.rs +++ b/linkerd/app/outbound/src/ingress.rs @@ -46,6 +46,7 @@ where Error = Error, > + Clone + Send + + Sync + 'static, TSvc::Future: Send, H: svc::NewService + Unpin + Clone + Send + Sync + 'static, @@ -110,7 +111,7 @@ where .push(TraceContext::layer(span_sink.clone().map(|span_sink| { SpanConverter::server(span_sink, trace_labels()) }))) - .push(metrics.stack.layer(stack_labels("source"))) + .push(metrics.stack.layer(stack_labels("http", "server"))) .push_failfast(dispatch_timeout) .push_spawn_buffer(buffer_capacity) .box_http_response(), @@ -130,11 +131,13 @@ where .into_inner(); svc::stack(http::NewServeHttp::new(h2_settings, http, tcp, drain)) - .check_new_service::>>() - .push_on_response(svc::layers().push_spawn_buffer(buffer_capacity).push( - transport::Prefix::layer( - http::Version::DETECT_BUFFER_CAPACITY, + .check_new_service::<(Option, tcp::Accept), io::PrefixedIo>>() + .check_new_clone::<(Option, tcp::Accept)>() + .push_cache(cache_max_idle_age) + .push(transport::NewDetectService::layer( + transport::detect::DetectTimeout::new( detect_protocol_timeout, + http::DetectHttp::default(), ), )) .check_new_service::>() diff --git a/linkerd/app/outbound/src/lib.rs b/linkerd/app/outbound/src/lib.rs index c4135975ef..acdbcb64c8 100644 --- a/linkerd/app/outbound/src/lib.rs +++ b/linkerd/app/outbound/src/lib.rs @@ -26,8 +26,8 @@ pub struct Config { pub allow_discovery: AddrMatch, } -fn stack_labels(name: &'static str) -> metrics::StackLabels { - metrics::StackLabels::outbound(name) +fn stack_labels(proto: &'static str, name: &'static str) -> metrics::StackLabels { + metrics::StackLabels::outbound(proto, name) } pub fn trace_labels() -> HashMap { diff --git a/linkerd/app/outbound/src/server.rs b/linkerd/app/outbound/src/server.rs index 661e1703d4..68fcb7334b 100644 --- a/linkerd/app/outbound/src/server.rs +++ b/linkerd/app/outbound/src/server.rs @@ -156,6 +156,7 @@ where max_in_flight_requests, detect_protocol_timeout, buffer_capacity, + cache_max_idle_age, .. } = config.proxy.clone(); @@ -176,7 +177,7 @@ where .push(TraceContext::layer(span_sink.clone().map(|span_sink| { SpanConverter::server(span_sink, trace_labels()) }))) - .push(metrics.stack.layer(stack_labels("source"))) + .push(metrics.stack.layer(stack_labels("http", "server"))) .push_failfast(dispatch_timeout) .push_spawn_buffer(buffer_capacity) .box_http_response(), @@ -193,6 +194,7 @@ where .check_make_service::() .push_on_response(svc::layer::mk(tcp::Forward::new)) .into_new_service() + .push_on_response(metrics.stack.layer(stack_labels("tcp", "forward"))) .check_new_service::>>() .push_map_target(tcp::Endpoint::from_logical( tls::ReasonForNoPeerName::NotProvidedByServiceDiscovery, @@ -221,11 +223,13 @@ where tcp_balance, drain.clone(), )) - .check_new_service::>>() - .push_on_response(svc::layers().push_spawn_buffer(buffer_capacity).push( - transport::Prefix::layer( - http::Version::DETECT_BUFFER_CAPACITY, + .check_new_clone::<(Option, tcp::Logical)>() + .check_new_service::<(Option, tcp::Logical), transport::io::PrefixedIo>>() + .push_cache(cache_max_idle_age) + .push(transport::NewDetectService::layer( + transport::detect::DetectTimeout::new( detect_protocol_timeout, + http::DetectHttp::default(), ), )) .check_new_service::>() @@ -235,6 +239,7 @@ where .push_map_target(tcp::Endpoint::from_logical( tls::ReasonForNoPeerName::PortSkipped, )) + .push_on_response(metrics.stack.layer(stack_labels("tcp", "opaque"))) .check_new_service::>() .into_inner(); diff --git a/linkerd/app/outbound/src/target.rs b/linkerd/app/outbound/src/target.rs index fe3f583180..eaa3c95a37 100644 --- a/linkerd/app/outbound/src/target.rs +++ b/linkerd/app/outbound/src/target.rs @@ -127,6 +127,21 @@ impl

Logical

{ } } +impl PartialEq> for Logical

{ + fn eq(&self, other: &Logical

) -> bool { + self.orig_dst == other.orig_dst && self.protocol == other.protocol + } +} + +impl Eq for Logical

{} + +impl std::hash::Hash for Logical

{ + fn hash(&self, state: &mut H) { + self.orig_dst.hash(state); + self.protocol.hash(state); + } +} + impl std::fmt::Debug for Logical

{ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Logical") diff --git a/linkerd/app/outbound/src/tcp/tests.rs b/linkerd/app/outbound/src/tcp/tests.rs index d8423ca67f..4cbbfa66af 100644 --- a/linkerd/app/outbound/src/tcp/tests.rs +++ b/linkerd/app/outbound/src/tcp/tests.rs @@ -865,7 +865,7 @@ where svc }; async move { - let io = support::io().read(b"hello\r\n").write(b"world").build(); + let io = support::io().read(b"hello\n").write(b"world").build(); let res = svc.oneshot(io).err_into::().await; tracing::trace!(?res); if let Err(err) = res { diff --git a/linkerd/app/src/env.rs b/linkerd/app/src/env.rs index e2a89d6ee1..20451d3cf3 100644 --- a/linkerd/app/src/env.rs +++ b/linkerd/app/src/env.rs @@ -170,14 +170,14 @@ pub const DEFAULT_INBOUND_LISTEN_ADDR: &str = "0.0.0.0:4143"; pub const DEFAULT_CONTROL_LISTEN_ADDR: &str = "0.0.0.0:4190"; const DEFAULT_ADMIN_LISTEN_ADDR: &str = "127.0.0.1:4191"; const DEFAULT_METRICS_RETAIN_IDLE: Duration = Duration::from_secs(10 * 60); -const DEFAULT_INBOUND_DISPATCH_TIMEOUT: Duration = Duration::from_secs(1); +const DEFAULT_INBOUND_DISPATCH_TIMEOUT: Duration = Duration::from_secs(5); const DEFAULT_INBOUND_CONNECT_TIMEOUT: Duration = Duration::from_millis(100); const DEFAULT_INBOUND_CONNECT_BACKOFF: ExponentialBackoff = ExponentialBackoff { min: Duration::from_millis(100), max: Duration::from_millis(500), jitter: 0.1, }; -const DEFAULT_OUTBOUND_DISPATCH_TIMEOUT: Duration = Duration::from_secs(3); +const DEFAULT_OUTBOUND_DISPATCH_TIMEOUT: Duration = Duration::from_secs(5); const DEFAULT_OUTBOUND_CONNECT_TIMEOUT: Duration = Duration::from_secs(1); const DEFAULT_OUTBOUND_CONNECT_BACKOFF: ExponentialBackoff = ExponentialBackoff { min: Duration::from_millis(100), diff --git a/linkerd/concurrency-limit/src/lib.rs b/linkerd/concurrency-limit/src/lib.rs index 6095b6f59c..736ff922de 100644 --- a/linkerd/concurrency-limit/src/lib.rs +++ b/linkerd/concurrency-limit/src/lib.rs @@ -33,7 +33,7 @@ pub struct ConcurrencyLimit { } enum State { - Waiting(Pin + Send + 'static>>), + Waiting(Pin + Send + Sync + 'static>>), Ready(OwnedSemaphorePermit), Empty, } diff --git a/linkerd/io/src/lib.rs b/linkerd/io/src/lib.rs index 7ff84f6467..60e46c87e1 100644 --- a/linkerd/io/src/lib.rs +++ b/linkerd/io/src/lib.rs @@ -1,11 +1,9 @@ mod boxed; -mod peek; mod prefixed; mod sensor; pub use self::{ boxed::BoxedIo, - peek::{Peek, Peekable}, prefixed::PrefixedIo, sensor::{Sensor, SensorIo}, }; diff --git a/linkerd/io/src/peek.rs b/linkerd/io/src/peek.rs deleted file mode 100644 index b9d97b18a6..0000000000 --- a/linkerd/io/src/peek.rs +++ /dev/null @@ -1,69 +0,0 @@ -use crate::{AsyncRead, AsyncWrite, PrefixedIo}; -use bytes::BytesMut; -use pin_project::pin_project; -use std::future::Future; -use std::io; -use std::pin::Pin; -use std::task::{Context, Poll}; - -/// A future of when some `Peek` fulfills with some bytes. -#[derive(Debug)] -pub struct Peek(Option>); - -#[pin_project] -#[derive(Debug)] -struct Inner { - buf: BytesMut, - - #[pin] - io: T, -} - -pub trait Peekable: AsyncRead + AsyncWrite + Unpin { - fn peek(self, capacity: usize) -> Peek - where - Self: Sized, - { - Peek::with_capacity(capacity, self) - } -} - -impl Peekable for I {} - -// === impl Peek === - -impl Peek { - pub fn with_capacity(capacity: usize, io: T) -> Self - where - Self: Sized + Future, - { - let buf = BytesMut::with_capacity(capacity); - Peek(Some(Inner { buf, io })) - } -} - -impl Future for Peek { - type Output = Result, io::Error>; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.as_mut(); - futures::ready!(this - .0 - .as_mut() - .expect("polled after complete") - .poll_peek(cx))?; - let Inner { buf, io } = this.0.take().expect("polled after complete"); - Poll::Ready(Ok(PrefixedIo::new(buf.freeze(), io))) - } -} - -// === impl Inner === - -impl Inner { - fn poll_peek(&mut self, cx: &mut Context<'_>) -> Poll> { - if self.buf.capacity() == 0 { - return Poll::Ready(Ok(self.buf.len())); - } - crate::poll_read_buf(Pin::new(&mut self.io), cx, &mut self.buf) - } -} diff --git a/linkerd/io/src/prefixed.rs b/linkerd/io/src/prefixed.rs index 2fa93456cb..47efc91db9 100644 --- a/linkerd/io/src/prefixed.rs +++ b/linkerd/io/src/prefixed.rs @@ -15,7 +15,7 @@ pub struct PrefixedIo { io: S, } -impl PrefixedIo { +impl PrefixedIo { pub fn new(prefix: impl Into, io: S) -> Self { let prefix = prefix.into(); Self { prefix, io } diff --git a/linkerd/proxy/http/Cargo.toml b/linkerd/proxy/http/Cargo.toml index 1b793138c1..100da96a18 100644 --- a/linkerd/proxy/http/Cargo.toml +++ b/linkerd/proxy/http/Cargo.toml @@ -11,6 +11,7 @@ This should probably be decomposed into smaller, decoupled crates. """ [dependencies] +async-trait = "0.1" bytes = "0.6" futures = { package = "futures", version = "0.3" } h2 = { git = "https://github.com/hyperium/h2" } @@ -26,6 +27,7 @@ linkerd2-error = { path = "../../error" } linkerd2-http-box = { path = "../../http-box" } linkerd2-identity = { path = "../../identity" } linkerd2-io = { path = "../../io" } +linkerd2-proxy-transport = { path = "../transport" } linkerd2-stack = { path = "../../stack" } linkerd2-timeout = { path = "../../timeout" } rand = "0.7" @@ -35,3 +37,7 @@ tracing = "0.1.22" tracing-futures = { version = "0.2", features = ["std-future"] } try-lock = "0.2" pin-project = "1" + +[dev-dependencies] +tokio-test = "0.3" +tracing-subscriber = "0.2" diff --git a/linkerd/proxy/http/src/detect.rs b/linkerd/proxy/http/src/detect.rs new file mode 100644 index 0000000000..37c835d4aa --- /dev/null +++ b/linkerd/proxy/http/src/detect.rs @@ -0,0 +1,195 @@ +use crate::Version; +use bytes::BytesMut; +use linkerd2_error::Error; +use linkerd2_io::{self as io, AsyncReadExt}; +use linkerd2_proxy_transport::Detect; +use tracing::{debug, trace}; + +const H2_PREFACE: &'static [u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; + +#[derive(Clone, Debug, Default)] +pub struct DetectHttp(()); + +#[async_trait::async_trait] +impl Detect for DetectHttp { + type Protocol = Version; + + async fn detect( + &self, + io: &mut I, + buf: &mut BytesMut, + ) -> Result, Error> { + let mut scan_idx = 0; + let mut maybe_h1 = true; + let mut maybe_h2 = true; + + loop { + // Read data from the socket or timeout detection. + trace!( + capacity = buf.capacity(), + scan_idx, + maybe_h1, + maybe_h2, + "Reading" + ); + let sz = io.read_buf(buf).await?; + if sz == 0 { + // No data was read because the socket closed or the + // buffer capacity was exhausted. + debug!(read = buf.len(), "Could not detect protocol"); + return Ok(None); + } + + // HTTP/2 checking is faster because it's a simple string match. If + // we have enough data, check it first. In almost all cases, the + // whole preface should be available from the first read. + if maybe_h2 { + if buf.len() < H2_PREFACE.len() { + // Check the prefix we have already read to see if it looks likely to be HTTP/2. + maybe_h2 = buf[..] == H2_PREFACE[..buf.len()]; + } else { + trace!("Checking H2 preface"); + if &buf[..H2_PREFACE.len()] == H2_PREFACE { + trace!("Matched HTTP/2 prefix"); + return Ok(Some(Version::H2)); + } + + // Not a match. Don't check for an HTTP/2 preface again. + maybe_h2 = false; + } + } + + if maybe_h1 { + // Scan up to the first line ending to determine whether the + // request is HTTP/1.1. HTTP expects \r\n, so we just look for + // any \n to indicate that we can make a determination. + if buf[scan_idx..].contains(&b'\n') { + trace!("Found newline"); + // If the first line looks like an HTTP/2 first line, + // then we almost definitely got a fragmented first + // read. Only try HTTP/1 parsing if it doesn't look like + // HTTP/2. + if !maybe_h2 { + trace!("Parsing HTTP/1 message"); + // If we get to reading headers (and fail), the + // first line looked like an HTTP/1 request; so + // handle the stream as HTTP/1. + if let Ok(_) | Err(httparse::Error::TooManyHeaders) = + httparse::Request::new(&mut [httparse::EMPTY_HEADER; 0]).parse(&buf[..]) + { + trace!("Matched HTTP/1"); + return Ok(Some(Version::Http1)); + } + } + + // We found the EOL and it wasn't an HTTP/1.x request; + // stop scanning and don't scan again. + maybe_h1 = false; + } + + // Advance our scan index to the end of buffer so the next + // iteration starts scanning where we left off. + scan_idx = buf.len() - 1; + } + + if !maybe_h1 && !maybe_h2 { + trace!("Not HTTP"); + return Ok(None); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use bytes::BufMut; + use tokio_test::io; + + const HTTP11_LINE: &'static [u8] = b"GET / HTTP/1.1\r\n"; + const H2_AND_GARBAGE: &'static [u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\ngarbage"; + const GARBAGE: &'static [u8] = + b"garbage garbage garbage garbage garbage garbage garbage garbage garbage garbage garbage garbage garbage garbage garbage garbage garbage"; + + #[tokio::test(flavor = "current_thread")] + async fn h2() { + let _ = tracing_subscriber::fmt().with_test_writer().try_init(); + + for i in 1..H2_PREFACE.len() { + let mut buf = BytesMut::with_capacity(H2_PREFACE.len()); + buf.put(H2_AND_GARBAGE); + debug!(read0 = ?std::str::from_utf8(&H2_AND_GARBAGE[..i]).unwrap()); + debug!(read1 = ?std::str::from_utf8(&H2_AND_GARBAGE[i..]).unwrap()); + let mut buf = BytesMut::with_capacity(1024); + let mut io = io::Builder::new() + .read(&H2_AND_GARBAGE[..i]) + .read(&H2_AND_GARBAGE[i..]) + .build(); + let kind = DetectHttp(()).detect(&mut io, &mut buf).await.unwrap(); + assert_eq!(kind, Some(Version::H2)); + } + } + + #[tokio::test(flavor = "current_thread")] + async fn http1() { + let _ = tracing_subscriber::fmt().with_test_writer().try_init(); + + for i in 1..HTTP11_LINE.len() { + debug!(read0 = ?std::str::from_utf8(&HTTP11_LINE[..i]).unwrap()); + debug!(read1 = ?std::str::from_utf8(&HTTP11_LINE[i..]).unwrap()); + let mut buf = BytesMut::with_capacity(1024); + let mut io = io::Builder::new() + .read(&HTTP11_LINE[..i]) + .read(&HTTP11_LINE[i..]) + .build(); + let kind = DetectHttp(()).detect(&mut io, &mut buf).await.unwrap(); + assert_eq!(kind, Some(Version::Http1)); + } + + const REQ: &'static [u8] = + b"GET /foo/bar/bar/blah HTTP/1.1\r\nHost: foob.example.com\r\n\r\n"; + let mut buf = BytesMut::with_capacity(1024); + let mut io = io::Builder::new().read(&REQ).build(); + let kind = DetectHttp(()).detect(&mut io, &mut buf).await.unwrap(); + assert_eq!(kind, Some(Version::Http1)); + assert_eq!(&buf[..], REQ); + + // Starts with a P, like the h2 preface. + const POST: &'static [u8] = b"POST /foo HTTP/1.1\r\n"; + for i in 1..POST.len() { + let mut buf = BytesMut::with_capacity(1024); + let mut io = io::Builder::new().read(&POST[..i]).read(&POST[i..]).build(); + println!("buf0 = {:?}", &POST[..i]); + println!("buf1 = {:?}", &POST[i..]); + let kind = DetectHttp(()).detect(&mut io, &mut buf).await.unwrap(); + assert_eq!(kind, Some(Version::Http1)); + assert_eq!(&buf[..], POST); + } + } + + #[tokio::test(flavor = "current_thread")] + async fn unknown() { + let _ = tracing_subscriber::fmt().with_test_writer().try_init(); + + let mut buf = BytesMut::with_capacity(1024); + let mut io = io::Builder::new() + .read(b"foo.bar.blah\r") + .read(b"\nbobo") + .build(); + let kind = DetectHttp(()).detect(&mut io, &mut buf).await.unwrap(); + assert_eq!(kind, None); + assert_eq!(&buf[..], b"foo.bar.blah\r\nbobo"); + + let mut buf = BytesMut::with_capacity(1024); + let mut io = io::Builder::new().read(GARBAGE).build(); + let kind = DetectHttp(()).detect(&mut io, &mut buf).await.unwrap(); + assert_eq!(kind, None); + assert_eq!(&buf[..], GARBAGE); + + let mut buf = BytesMut::with_capacity(1024); + let mut io = io::Builder::new().read(&HTTP11_LINE[..14]).build(); + let kind = DetectHttp(()).detect(&mut io, &mut buf).await.unwrap(); + assert_eq!(kind, None); + assert_eq!(&buf[..14], &HTTP11_LINE[..14]); + } +} diff --git a/linkerd/proxy/http/src/lib.rs b/linkerd/proxy/http/src/lib.rs index 88faf49973..7252b02baf 100644 --- a/linkerd/proxy/http/src/lib.rs +++ b/linkerd/proxy/http/src/lib.rs @@ -8,6 +8,7 @@ pub mod add_header; pub mod balance; pub mod client; pub mod client_handle; +mod detect; mod glue; pub mod h1; pub mod h2; @@ -26,6 +27,7 @@ mod version; pub use self::{ client_handle::{ClientHandle, SetClientHandle}, + detect::DetectHttp, glue::{HyperServerSvc, UpgradeBody}, override_authority::CanOverrideAuthority, retain::Retain, diff --git a/linkerd/proxy/http/src/server.rs b/linkerd/proxy/http/src/server.rs index a916684e46..7b79b0e876 100644 --- a/linkerd/proxy/http/src/server.rs +++ b/linkerd/proxy/http/src/server.rs @@ -3,7 +3,7 @@ use crate::{ client_handle::SetClientHandle, glue::{HyperServerSvc, UpgradeBody}, h2::Settings as H2Settings, - trace, upgrade, Version as HttpVersion, + trace, upgrade, Version, }; use futures::prelude::*; use linkerd2_drain as drain; @@ -16,7 +16,7 @@ use std::{ task::{Context, Poll}, }; use tower::{util::ServiceExt, Service}; -use tracing::{debug, trace}; +use tracing::debug; type Server = hyper::server::conn::Http; @@ -28,24 +28,15 @@ pub struct NewServeHttp { drain: drain::Watch, } -/// Accepts HTTP connections. -/// -/// The server accepts TCP connections with their detected protocol. If the -/// protocol is known to be HTTP, a server is built with a new HTTP service -/// (built using the `H`-typed NewService). -/// -/// Otherwise, the `F` type forwarding service is used to handle the TCP -/// connection. -#[derive(Debug)] -pub struct ServeHttp, H: NewService<(HttpVersion, T)>> { - target: T, - new_tcp: F, - tcp: Option, - new_http: H, - http1: Option, - h2: Option, - server: hyper::server::conn::Http, - drain: drain::Watch, +#[derive(Clone, Debug)] +pub enum ServeHttp { + Opaque(F, drain::Watch), + Http { + version: Version, + service: H, + server: Server, + drain: drain::Watch, + }, } // === impl NewServeHttp === @@ -77,55 +68,43 @@ impl NewServeHttp { } } -impl NewService for NewServeHttp +impl NewService<(Option, T)> for NewServeHttp where F: NewService + Clone, - H: NewService<(HttpVersion, T)> + Clone, + H: NewService<(Version, T)> + Clone, { - type Service = ServeHttp; - - fn new_service(&mut self, target: T) -> Self::Service { - ServeHttp::new( - target, - self.server.clone(), - self.http.clone(), - self.tcp.clone(), - self.drain.clone(), - ) + type Service = ServeHttp; + + fn new_service(&mut self, (v, target): (Option, T)) -> Self::Service { + match v { + Some(version) => { + debug!(?version, "Creating HTTP service"); + let service = self.http.new_service((version, target)); + ServeHttp::Http { + version, + service, + server: self.server.clone(), + drain: self.drain.clone(), + } + } + None => { + debug!("Creating TCP service"); + let svc = self.tcp.new_service(target); + ServeHttp::Opaque(svc, self.drain.clone()) + } + } } } // === impl ServeHttp === -impl ServeHttp +impl Service> for ServeHttp where - F: NewService, - H: NewService<(HttpVersion, T)>, -{ - pub fn new(target: T, server: Server, new_http: H, new_tcp: F, drain: drain::Watch) -> Self { - Self { - target, - server, - new_tcp, - tcp: None, - new_http, - http1: None, - h2: None, - drain, - } - } -} - -impl Service> for ServeHttp -where - T: Clone, I: io::AsyncRead + io::AsyncWrite + PeerAddr + Send + Unpin + 'static, - F: NewService + Clone, - FSvc: tower::Service, Response = ()> + Clone + Send + 'static, - FSvc::Error: Into, - FSvc::Future: Send + 'static, - H: NewService<(HttpVersion, T), Service = HSvc> + Clone, - HSvc: Service< + F: tower::Service, Response = ()> + Clone + Send + 'static, + F::Error: Into, + F::Future: Send + 'static, + H: Service< http::Request, Response = http::Response, Error = Error, @@ -133,7 +112,7 @@ where + Unpin + Send + 'static, - HSvc::Future: Send + 'static, + H::Future: Send + 'static, { type Response = (); type Error = Error; @@ -145,120 +124,72 @@ where } fn call(&mut self, io: PrefixedIo) -> Self::Future { - let version = HttpVersion::from_prefix(io.prefix()); - match version { - Some(HttpVersion::Http1) => { - debug!("Handling as HTTP"); - let http1 = if let Some(svc) = self.http1.clone() { - trace!("HTTP service already exists"); - svc - } else { - trace!("Building new HTTP service"); - let svc = self - .new_http - .new_service((HttpVersion::Http1, self.target.clone())); - self.http1 = Some(svc.clone()); - svc - }; - - let mut server = self.server.clone(); - let drain = self.drain.clone(); - Box::pin(async move { - let (svc, closed) = SetClientHandle::new(io.peer_addr()?, http1); - - let mut conn = server - .http1_only(true) - .serve_connection( - io, - // Enable support for HTTP upgrades (CONNECT and websockets). - upgrade::Service::new(svc, drain.clone()), - ) - .with_upgrades(); - - tokio::select! { - res = &mut conn => { - debug!(?res, "The client is shutting down the connection"); - res? - } - shutdown = drain.signal() => { - debug!("The process is shutting down the connection"); - Pin::new(&mut conn).graceful_shutdown(); - shutdown.release_after(conn).await?; - } - () = closed => { - debug!("The stack is tearing down the connection"); - Pin::new(&mut conn).graceful_shutdown(); - conn.await?; + match self.clone() { + Self::Http { + version, + service, + drain, + mut server, + } => Box::pin(async move { + debug!(?version, "Handling as HTTP"); + let (svc, closed) = SetClientHandle::new(io.peer_addr()?, service); + match version { + Version::Http1 => { + // Enable support for HTTP upgrades (CONNECT and websockets). + let mut conn = server + .http1_only(true) + .serve_connection(io, upgrade::Service::new(svc, drain.clone())) + .with_upgrades(); + + tokio::select! { + res = &mut conn => { + debug!(?res, "The client is shutting down the connection"); + res? + } + shutdown = drain.signal() => { + debug!("The process is shutting down the connection"); + Pin::new(&mut conn).graceful_shutdown(); + shutdown.release_after(conn).await?; + } + () = closed => { + debug!("The stack is tearing down the connection"); + Pin::new(&mut conn).graceful_shutdown(); + conn.await?; + } } } - - Ok(()) - }) - } - - Some(HttpVersion::H2) => { - debug!("Handling as H2"); - let h2 = if let Some(svc) = self.h2.clone() { - trace!("H2 service already exists"); - svc - } else { - trace!("Building new H2 service"); - let svc = self - .new_http - .new_service((HttpVersion::H2, self.target.clone())); - self.h2 = Some(svc.clone()); - svc - }; - - let mut server = self.server.clone(); - let drain = self.drain.clone(); - Box::pin(async move { - let (svc, closed) = SetClientHandle::new(io.peer_addr()?, h2); - - let mut conn = server - .http2_only(true) - .serve_connection(io, HyperServerSvc::new(svc)); - - tokio::select! { - res = &mut conn => { - debug!(?res, "The client is shutting down the connection"); - res? - } - shutdown = drain.signal() => { - debug!("The process is shutting down the connection"); - Pin::new(&mut conn).graceful_shutdown(); - shutdown.release_after(conn).await?; - } - () = closed => { - debug!("The stack is tearing down the connection"); - Pin::new(&mut conn).graceful_shutdown(); - conn.await?; + Version::H2 => { + let mut conn = server + .http2_only(true) + .serve_connection(io, HyperServerSvc::new(svc)); + + tokio::select! { + res = &mut conn => { + debug!(?res, "The client is shutting down the connection"); + res? + } + shutdown = drain.signal() => { + debug!("The process is shutting down the connection"); + Pin::new(&mut conn).graceful_shutdown(); + shutdown.release_after(conn).await?; + } + () = closed => { + debug!("The stack is tearing down the connection"); + Pin::new(&mut conn).graceful_shutdown(); + conn.await?; + } } } + } - Ok(()) - }) - } - - None => { + Ok(()) + }), + Self::Opaque(tcp, drain) => Box::pin({ debug!("Forwarding TCP"); - let tcp = if let Some(svc) = self.tcp.clone() { - trace!("TCP service already exists"); - svc - } else { - trace!("Building new TCP service"); - let svc = self.new_tcp.new_service(self.target.clone()); - self.tcp = Some(svc.clone()); - svc - }; - - Box::pin( - self.drain - .clone() - .ignore_signal() - .release_after(tcp.oneshot(io).err_into::()), - ) - } + drain + .ignore_signal() + .release_after(tcp.oneshot(io).err_into::()) + }), } } } diff --git a/linkerd/proxy/http/src/version.rs b/linkerd/proxy/http/src/version.rs index 364160cefb..64f01a17a5 100644 --- a/linkerd/proxy/http/src/version.rs +++ b/linkerd/proxy/http/src/version.rs @@ -1,5 +1,3 @@ -use tracing::{debug, trace}; - #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] pub enum Version { Http1, @@ -20,48 +18,6 @@ impl std::convert::TryFrom for Version { } } -impl Version { - pub const DETECT_BUFFER_CAPACITY: usize = 8192; - const H2_PREFACE: &'static [u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; - - /// Tries to detect a known protocol in the peeked bytes. - /// - /// If no protocol can be determined, returns `None`. - pub fn from_prefix(bytes: &[u8]) -> Option { - // http2 is easiest to detect - if bytes.len() >= Self::H2_PREFACE.len() { - if &bytes[..Self::H2_PREFACE.len()] == Self::H2_PREFACE { - trace!("Detected H2"); - return Some(Self::H2); - } - } - - // http1 can have a really long first line, but if the bytes so far - // look like http1, we'll assume it is. a different protocol - // should look different in the first few bytes - - let mut headers = [httparse::EMPTY_HEADER; 0]; - let mut req = httparse::Request::new(&mut headers); - match req.parse(bytes) { - // Ok(Complete) or Ok(Partial) both mean it looks like HTTP1! - // - // If we got past the first line, we'll see TooManyHeaders, - // because we passed an array of 0 headers to parse into. That's fine! - // We didn't want to keep parsing headers, just validate that - // the first line is HTTP1. - Ok(_) | Err(httparse::Error::TooManyHeaders) => { - trace!("Detected H1"); - return Some(Self::Http1); - } - _ => {} - } - - debug!("Not HTTP"); - trace!(?bytes); - None - } -} - // A convenience for tracing contexts. impl std::fmt::Display for Version { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -79,21 +35,3 @@ impl std::fmt::Display for Unsupported { } impl std::error::Error for Unsupported {} - -#[cfg(test)] -#[test] -fn from_prefix() { - assert_eq!(Version::from_prefix(Version::H2_PREFACE), Some(Version::H2)); - assert_eq!( - Version::from_prefix("GET /foo/bar/bah/baz HTTP/1.1".as_ref()), - Some(Version::Http1) - ); - assert_eq!( - Version::from_prefix("GET /foo".as_ref()), - Some(Version::Http1) - ); - assert_eq!( - Version::from_prefix("GET /foo/barbasdklfja\n".as_ref()), - None - ); -} diff --git a/linkerd/proxy/transport/Cargo.toml b/linkerd/proxy/transport/Cargo.toml index cb28f04c3a..0f1b9d9305 100644 --- a/linkerd/proxy/transport/Cargo.toml +++ b/linkerd/proxy/transport/Cargo.toml @@ -15,6 +15,7 @@ mock-orig-dst = [] [dependencies] async-stream = "0.2.1" +async-trait = "0.1" bytes = "0.6" futures = "0.3" indexmap = "1.0.0" @@ -27,13 +28,13 @@ linkerd2-io = { path = "../../io" } linkerd2-metrics = { path = "../../metrics" } linkerd2-stack = { path = "../../stack" } rustls = "0.18" -tokio = { version = "0.3", features = ["net", "io-util"]} +tokio = { version = "0.3", features = ["io-util", "net", "time"]} tokio-rustls = "0.20" +tokio-util = { version = "0.5", features = ["compat"]} tracing = "0.1.22" webpki = "0.21" untrusted = "0.7" pin-project = "0.4" -tokio-util = { version = "0.5", features = ["compat"]} socket2 = "0.3" [dependencies.tower] diff --git a/linkerd/proxy/transport/src/detect/mod.rs b/linkerd/proxy/transport/src/detect/mod.rs new file mode 100644 index 0000000000..89ef645f18 --- /dev/null +++ b/linkerd/proxy/transport/src/detect/mod.rs @@ -0,0 +1,131 @@ +mod timeout; + +pub use self::timeout::{DetectTimeout, DetectTimeoutError}; +use crate::io; +use bytes::BytesMut; +use futures::prelude::*; +use linkerd2_error::Error; +use linkerd2_stack::{layer, NewService}; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; +use tokio::time; +use tower::util::ServiceExt; +use tracing::{debug, trace}; + +#[async_trait::async_trait] +pub trait Detect: Clone + Send + Sync + 'static { + type Protocol: Send; + + async fn detect( + &self, + io: &mut I, + buf: &mut BytesMut, + ) -> Result, Error>; +} + +#[derive(Copy, Clone)] +pub struct NewDetectService { + new_accept: N, + detect: D, + capacity: usize, +} + +#[derive(Copy, Clone)] +pub struct DetectService { + target: T, + new_accept: N, + detect: D, + capacity: usize, +} + +// === impl NewDetectService === + +impl NewDetectService { + const BUFFER_CAPACITY: usize = 1024; + + pub fn new(new_accept: N, detect: D) -> Self { + Self { + detect, + new_accept, + capacity: Self::BUFFER_CAPACITY, + } + } + + pub fn layer(detect: D) -> impl layer::Layer + Clone { + layer::mk(move |new| Self::new(new, detect.clone())) + } +} + +impl NewService for NewDetectService { + type Service = DetectService; + + fn new_service(&mut self, target: T) -> DetectService { + DetectService { + target, + new_accept: self.new_accept.clone(), + detect: self.detect.clone(), + capacity: self.capacity, + } + } +} + +// === impl DetectService === + +impl tower::Service for DetectService +where + T: Clone + Send + 'static, + I: io::AsyncRead + Send + Unpin + 'static, + D: Detect, + D::Protocol: std::fmt::Debug, + N: NewService<(Option, T), Service = S> + Clone + Send + 'static, + S: tower::Service, Response = ()> + Send, + S::Error: Into, + S::Future: Send, +{ + type Response = (); + type Error = Error; + type Future = Pin> + Send + 'static>>; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(().into())) + } + + fn call(&mut self, mut io: I) -> Self::Future { + let mut new_accept = self.new_accept.clone(); + let mut buf = BytesMut::with_capacity(self.capacity); + let detect = self.detect.clone(); + let target = self.target.clone(); + Box::pin(async move { + trace!("Starting protocol detection"); + let t0 = time::Instant::now(); + let protocol = detect.detect(&mut io, &mut buf).await?; + debug!( + ?protocol, + elapsed = ?(time::Instant::now() - t0), + "Detected" + ); + + let mut accept = new_accept + .new_service((protocol, target)) + .ready_oneshot() + .err_into::() + .await?; + + trace!("Dispatching connection"); + accept + .call(io::PrefixedIo::new(buf.freeze(), io)) + .err_into::() + .await?; + + trace!("Connection completed"); + // Hold the service until it's done being used so that cache + // idleness is reset. + drop(accept); + + Ok(()) + }) + } +} diff --git a/linkerd/proxy/transport/src/detect/timeout.rs b/linkerd/proxy/transport/src/detect/timeout.rs new file mode 100644 index 0000000000..4254593ccb --- /dev/null +++ b/linkerd/proxy/transport/src/detect/timeout.rs @@ -0,0 +1,67 @@ +use super::Detect; +use crate::io; +use bytes::BytesMut; +use futures::prelude::*; +use linkerd2_error::Error; +use tokio::time; + +#[derive(Copy, Clone, Debug)] +pub struct DetectTimeout { + inner: D, + timeout: time::Duration, +} + +#[derive(Debug)] +pub struct DetectTimeoutError { + bytes: usize, + elapsed: time::Duration, +} + +// === impl DetectTimeout === + +impl DetectTimeout { + pub fn new(timeout: time::Duration, inner: D) -> Self { + Self { inner, timeout } + } +} + +#[async_trait::async_trait] +impl Detect for DetectTimeout +where + D: Detect, + D::Protocol: std::fmt::Debug, +{ + type Protocol = D::Protocol; + + async fn detect( + &self, + io: &mut I, + buf: &mut BytesMut, + ) -> Result, Error> { + let t0 = time::Instant::now(); + let timeout = time::sleep(self.timeout); + let detect = self.inner.detect(io, buf); + futures::select_biased! { + res = detect.fuse() => res, + _ = timeout.fuse() => { + let bytes = buf.len(); + let elapsed = time::Instant::now() - t0; + Err(DetectTimeoutError { bytes, elapsed }.into()) + } + } + } +} + +// === impl DetectTimeout === + +impl std::fmt::Display for DetectTimeoutError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Protocol detection timeout after: {}B after {:?}", + self.bytes, self.elapsed + ) + } +} + +impl std::error::Error for DetectTimeoutError {} diff --git a/linkerd/proxy/transport/src/lib.rs b/linkerd/proxy/transport/src/lib.rs index 1a8caff4c3..57f3c86373 100644 --- a/linkerd/proxy/transport/src/lib.rs +++ b/linkerd/proxy/transport/src/lib.rs @@ -5,17 +5,17 @@ use std::time::Duration; use tokio::net::TcpStream; pub mod connect; +pub mod detect; pub use linkerd2_io as io; pub mod listen; pub mod metrics; -pub mod prefix; pub mod tls; pub use self::{ connect::Connect, + detect::{Detect, DetectService, NewDetectService}, io::BoxedIo, listen::{Bind, DefaultOrigDstAddr, NoOrigDstAddr, OrigDstAddr}, - prefix::Prefix, }; // Misc. diff --git a/linkerd/proxy/transport/src/prefix.rs b/linkerd/proxy/transport/src/prefix.rs deleted file mode 100644 index 8ce14162dc..0000000000 --- a/linkerd/proxy/transport/src/prefix.rs +++ /dev/null @@ -1,76 +0,0 @@ -use crate::io; -use futures::prelude::*; -use linkerd2_error::Error; -use linkerd2_stack::layer; -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, - time::Duration, -}; -use tokio::time; -use tower::util::ServiceExt; -use tracing::{debug, trace}; - -#[derive(Copy, Clone)] -pub struct Prefix { - inner: S, - capacity: usize, - timeout: Duration, -} - -#[derive(Debug)] -pub struct ReadTimeout(()); - -impl Prefix { - pub fn new(inner: S, capacity: usize, timeout: Duration) -> Self { - Self { - inner, - capacity, - timeout, - } - } - - pub fn layer( - capacity: usize, - timeout: Duration, - ) -> impl layer::Layer> + Clone { - layer::mk(move |inner| Self::new(inner, capacity, timeout)) - } -} - -impl tower::Service for Prefix -where - I: io::Peekable + Send + 'static, - S: tower::Service, Response = ()> + Clone + Send + 'static, - S::Error: Into, - S::Future: Send, -{ - type Response = (); - type Error = Error; - type Future = Pin> + Send + 'static>>; - - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(().into())) - } - - fn call(&mut self, io: I) -> Self::Future { - debug!(capacity = self.capacity, "Buffering prefix"); - let accept = self.inner.clone(); - let peek = time::timeout(self.timeout, io.peek(self.capacity)).map_err(|_| ReadTimeout(())); - Box::pin(async move { - let io = peek.await??; - trace!(read = %io.prefix().len()); - accept.oneshot(io).err_into::().await?; - Ok(()) - }) - } -} - -impl std::fmt::Display for ReadTimeout { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Timed out while reading client stream prefix") - } -} - -impl std::error::Error for ReadTimeout {}