diff --git a/linkerd/app/core/src/svc.rs b/linkerd/app/core/src/svc.rs index 94e143fbfb..9fccdbc766 100644 --- a/linkerd/app/core/src/svc.rs +++ b/linkerd/app/core/src/svc.rs @@ -7,8 +7,8 @@ use linkerd_exp_backoff::{ExponentialBackoff, ExponentialBackoffStream}; pub use linkerd_reconnect::NewReconnect; pub use linkerd_stack::{ self as stack, layer, ArcNewService, BoxCloneService, BoxService, BoxServiceLayer, Either, - ExtractParam, Fail, FailFast, Filter, InsertParam, MapErr, MapTargetLayer, NewRouter, - NewService, Param, Predicate, UnwrapOr, + ExtractParam, Fail, FailFast, Filter, InsertParam, MakeConnection, MapErr, MapTargetLayer, + NewRouter, NewService, Param, Predicate, UnwrapOr, }; pub use linkerd_stack_tracing::{NewInstrument, NewInstrumentLayer}; use std::{ diff --git a/linkerd/app/gateway/src/lib.rs b/linkerd/app/gateway/src/lib.rs index dfb9be0484..5e4081b925 100644 --- a/linkerd/app/gateway/src/lib.rs +++ b/linkerd/app/gateway/src/lib.rs @@ -17,7 +17,7 @@ use linkerd_app_core::{ }, svc::{self, Param}, tls, - transport::{ClientAddr, OrigDstAddr, Remote}, + transport::{ClientAddr, Local, OrigDstAddr, Remote}, transport_header::SessionProtocol, Error, Infallible, NameAddr, NameMatch, }; @@ -70,9 +70,8 @@ pub fn stack( where I: io::AsyncRead + io::AsyncWrite + io::PeerAddr + fmt::Debug + Send + Sync + Unpin + 'static, O: Clone + Send + Sync + Unpin + 'static, - O: svc::Service, - O::Response: - io::AsyncRead + io::AsyncWrite + tls::HasNegotiatedProtocol + Send + Unpin + 'static, + O: svc::MakeConnection, Error = io::Error>, + O::Connection: Send + Unpin, O::Future: Send + Unpin + 'static, P: profiles::GetProfile + Clone + Send + Sync + Unpin + 'static, P::Future: Send + 'static, diff --git a/linkerd/app/inbound/src/http/router.rs b/linkerd/app/inbound/src/http/router.rs index 21b05ae6e6..38535ca7fe 100644 --- a/linkerd/app/inbound/src/http/router.rs +++ b/linkerd/app/inbound/src/http/router.rs @@ -1,6 +1,6 @@ use crate::{policy, stack_labels, Inbound}; use linkerd_app_core::{ - classify, errors, http_tracing, io, metrics, + classify, errors, http_tracing, metrics, profiles::{self, DiscoveryRejected}, proxy::{http, tap}, svc::{self, ExtractParam, Param}, @@ -84,9 +84,9 @@ impl Inbound { P: profiles::GetProfile + Clone + Send + Sync + 'static, P::Future: Send, P::Error: Send, - C: svc::Service + Clone + Send + Sync + Unpin + 'static, - C::Response: io::AsyncRead + io::AsyncWrite + Send + Unpin + 'static, - C::Error: Into, + C: svc::MakeConnection + Clone + Send + Sync + Unpin + 'static, + C::Connection: Send + Unpin, + C::Metadata: Send, C::Future: Send, { self.map_stack(|config, rt, connect| { @@ -94,12 +94,17 @@ impl Inbound { // Creates HTTP clients for each inbound port & HTTP settings. let http = connect + .push(svc::layer::mk(|inner: C| inner.into_service())) + .check_service::() .push(svc::stack::BoxFuture::layer()) + .check_service::() .push(transport::metrics::Client::layer(rt.metrics.proxy.transport.clone())) + .check_service::() .push(http::client::layer( config.proxy.connect.h1_settings, config.proxy.connect.h2_settings, )) + .check_service::() .push_on_service(svc::MapErr::layer(Into::into)) .into_new_service() .push_new_reconnect(config.proxy.connect.backoff) diff --git a/linkerd/app/inbound/src/lib.rs b/linkerd/app/inbound/src/lib.rs index 30553b9008..5cdacd5d7a 100644 --- a/linkerd/app/inbound/src/lib.rs +++ b/linkerd/app/inbound/src/lib.rs @@ -155,9 +155,10 @@ impl Inbound<()> { self, proxy_port: u16, ) -> Inbound< - impl svc::Service< + impl svc::MakeConnection< T, - Response = impl io::AsyncRead + io::AsyncWrite + Send, + Connection = impl Send + Unpin, + Metadata = impl Send + Unpin, Error = Error, Future = impl Send, > + Clone, @@ -214,9 +215,9 @@ impl Inbound { T: svc::Param + Clone + Send + 'static, I: io::AsyncRead + io::AsyncWrite, I: Debug + Send + Unpin + 'static, - S: svc::Service + Clone + Send + Sync + Unpin + 'static, - S::Response: io::AsyncRead + io::AsyncWrite + Send + Unpin + 'static, - S::Error: Into, + S: svc::MakeConnection + Clone + Send + Sync + Unpin + 'static, + S::Connection: Send + Unpin, + S::Metadata: Send + Unpin, S::Future: Send, { self.map_stack(|_, rt, connect| { @@ -224,6 +225,7 @@ impl Inbound { .push(transport::metrics::Client::layer( rt.metrics.proxy.transport.clone(), )) + .push(svc::stack::WithoutConnectionMetadata::layer()) .push_make_thunk() .push_on_service( svc::layers() diff --git a/linkerd/app/outbound/src/endpoint.rs b/linkerd/app/outbound/src/endpoint.rs index e46e3ce62c..f987bfdde1 100644 --- a/linkerd/app/outbound/src/endpoint.rs +++ b/linkerd/app/outbound/src/endpoint.rs @@ -209,10 +209,10 @@ impl Outbound { pub fn push_endpoint(self) -> Outbound> where Self: Clone + 'static, - S: svc::Service + Clone + Send + Sync + Unpin + 'static, - S::Response: - tls::HasNegotiatedProtocol + io::AsyncRead + io::AsyncWrite + Send + Unpin + 'static, - S::Future: Send + Unpin, + S: svc::MakeConnection, Error = io::Error>, + S: Clone + Send + Sync + Unpin + 'static, + S::Connection: Send + Unpin + 'static, + S::Future: Send, I: io::AsyncRead + io::AsyncWrite + io::PeerAddr, I: fmt::Debug + Send + Sync + Unpin + 'static, { diff --git a/linkerd/app/outbound/src/http/endpoint.rs b/linkerd/app/outbound/src/http/endpoint.rs index 80d8ff4bc2..698afe3084 100644 --- a/linkerd/app/outbound/src/http/endpoint.rs +++ b/linkerd/app/outbound/src/http/endpoint.rs @@ -6,7 +6,6 @@ use linkerd_app_core::{ svc::{self, ExtractParam}, tls, Error, Result, CANONICAL_DST_HEADER, }; -use tokio::io; #[derive(Copy, Clone, Debug)] struct ClientRescue { @@ -24,9 +23,9 @@ impl Outbound { + tap::Inspect, B: http::HttpBody + std::fmt::Debug + Default + Send + 'static, B::Data: Send + 'static, - C: svc::Service + Clone + Send + Sync + Unpin + 'static, - C::Response: io::AsyncRead + io::AsyncWrite + Send + Unpin, - C::Error: Into, + C: svc::MakeConnection + Clone + Send + Sync + Unpin + 'static, + C::Connection: Send + Unpin, + C::Metadata: Send + Unpin, C::Future: Send + Unpin + 'static, { self.map_stack(|config, rt, connect| { diff --git a/linkerd/app/outbound/src/lib.rs b/linkerd/app/outbound/src/lib.rs index 848e97a8d3..e8ee1210b5 100644 --- a/linkerd/app/outbound/src/lib.rs +++ b/linkerd/app/outbound/src/lib.rs @@ -84,6 +84,8 @@ pub struct Accept

{ pub protocol: P, } +pub type ConnectMeta = tls::ConnectMeta>; + // === impl Outbound === impl Outbound<()> { diff --git a/linkerd/app/outbound/src/logical.rs b/linkerd/app/outbound/src/logical.rs index 8daa362a2c..f62bbcd93a 100644 --- a/linkerd/app/outbound/src/logical.rs +++ b/linkerd/app/outbound/src/logical.rs @@ -3,7 +3,9 @@ pub use linkerd_app_core::proxy::api_resolve::ConcreteAddr; use linkerd_app_core::{ io, profiles, proxy::{api_resolve::Metadata, core::Resolve}, - svc, tls, Addr, Error, + svc, + transport::{ClientAddr, Local}, + Addr, Error, }; pub use profiles::LogicalAddr; use std::fmt; @@ -118,9 +120,8 @@ impl Outbound { where Self: Clone + 'static, C: Clone + Send + Sync + Unpin + 'static, - C: svc::Service, - C::Response: - tls::HasNegotiatedProtocol + io::AsyncRead + io::AsyncWrite + Send + Unpin + 'static, + C: svc::MakeConnection, Error = io::Error>, + C::Connection: Send + Unpin, C::Future: Send + Unpin, R: Clone + Send + 'static, R: Resolve + Sync, diff --git a/linkerd/app/outbound/src/tcp/connect.rs b/linkerd/app/outbound/src/tcp/connect.rs index 84ec8dc503..0fe5ec7e38 100644 --- a/linkerd/app/outbound/src/tcp/connect.rs +++ b/linkerd/app/outbound/src/tcp/connect.rs @@ -1,11 +1,11 @@ use super::opaque_transport::{self, OpaqueTransport}; -use crate::Outbound; +use crate::{ConnectMeta, Outbound}; use futures::future; use linkerd_app_core::{ io, proxy::http, svc, tls, - transport::{self, ConnectTcp, Remote, ServerAddr}, + transport::{self, ClientAddr, ConnectTcp, Local, Remote, ServerAddr}, transport_header::SessionProtocol, Error, }; @@ -36,9 +36,10 @@ impl Outbound { pub fn push_tcp_endpoint( self, ) -> Outbound< - impl svc::Service< + impl svc::MakeConnection< T, - Response = impl io::AsyncRead + io::AsyncWrite + Send + Unpin, + Connection = impl Send + Unpin, + Metadata = ConnectMeta, Error = Error, Future = impl Send, > + Clone, @@ -50,9 +51,10 @@ impl Outbound { + svc::Param> + svc::Param> + svc::Param, - C: svc::Service + Clone + Send + 'static, - C::Response: tls::HasNegotiatedProtocol, - C::Response: io::AsyncRead + io::AsyncWrite + Send + Unpin + 'static, + C: svc::MakeConnection, Error = io::Error>, + C: Clone + Send + 'static, + C::Connection: Send + Unpin, + C::Metadata: Send + Unpin, C::Future: Send + 'static, { self.map_stack(|config, rt, connect| { @@ -85,13 +87,14 @@ impl Outbound { where T: Clone + Send + 'static, I: io::AsyncRead + io::AsyncWrite + io::PeerAddr + std::fmt::Debug + Send + Unpin + 'static, - C: svc::Service + Clone + Send + Sync + 'static, - C::Response: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin, - C::Error: Into, + C: svc::MakeConnection + Clone + Send + Sync + 'static, + C::Connection: Send + Unpin, + C::Metadata: Send + Unpin, C::Future: Send, { self.map_stack(|_, _, conn| { - conn.push_make_thunk() + conn.push(svc::stack::WithoutConnectionMetadata::layer()) + .push_make_thunk() .push_on_service(super::Forward::layer()) .instrument(|_: &_| debug_span!("tcp.forward")) .push(svc::ArcNewService::layer()) @@ -166,6 +169,7 @@ mod tests { use crate::{ svc::{self, NewService, ServiceExt}, test_util::*, + transport::{ClientAddr, Local}, }; use std::net::SocketAddr; @@ -180,7 +184,10 @@ mod tests { assert_eq!(a, addr); let mut io = support::io(); io.write(b"hello").read(b"world"); - future::ok::<_, support::io::Error>(io.build()) + future::ok::<_, support::io::Error>(( + io.build(), + Local(ClientAddr(([0, 0, 0, 0], 0).into())), + )) })) .push_tcp_forward() .into_inner(); diff --git a/linkerd/app/outbound/src/tcp/logical.rs b/linkerd/app/outbound/src/tcp/logical.rs index a84808f66c..c823b27c66 100644 --- a/linkerd/app/outbound/src/tcp/logical.rs +++ b/linkerd/app/outbound/src/tcp/logical.rs @@ -12,13 +12,7 @@ use linkerd_app_core::{ }; use tracing::debug_span; -impl Outbound -where - C: svc::Service + Clone + Send + 'static, - C::Response: io::AsyncRead + io::AsyncWrite + Send + Unpin, - C::Error: Into, - C::Future: Send, -{ +impl Outbound { /// Constructs a TCP load balancer. pub fn push_tcp_logical( self, @@ -30,6 +24,11 @@ where >, > where + C: svc::MakeConnection + Clone + Send + 'static, + C::Connection: Send + Unpin, + C::Metadata: Send + Unpin, + C::Future: Send, + C: Send + Sync + 'static, I: io::AsyncRead + io::AsyncWrite + std::fmt::Debug + Send + Unpin + 'static, R: Resolve + Clone @@ -38,7 +37,6 @@ where + 'static, R::Resolution: Send, R::Future: Send + Unpin, - C: Send + Sync + 'static, { self.map_stack(|config, rt, connect| { let config::ProxyConfig { @@ -63,6 +61,7 @@ where .into_inner(); connect + .push(svc::stack::WithoutConnectionMetadata::layer()) .push_make_thunk() .instrument(|t: &Endpoint| { debug_span!( @@ -159,7 +158,8 @@ mod tests { assert_eq!(*ep.addr.as_ref(), ep_addr); let mut io = support::io(); io.write(b"hola").read(b"mundo"); - future::ok::<_, support::io::Error>(io.build()) + let local = Local(ClientAddr(([0, 0, 0, 0], 4444).into())); + future::ok::<_, support::io::Error>((io.build(), local)) })) .push_tcp_logical(resolve) .into_inner(); @@ -225,13 +225,15 @@ mod tests { tracing::debug!(%addr, "writing ep0"); let mut io = support::io(); io.write(b"who r u?").read(b"ep0"); - future::ok::<_, support::io::Error>(io.build()) + let local = Local(ClientAddr(([0, 0, 0, 0], 4444).into())); + future::ok::<_, support::io::Error>((io.build(), local)) } Remote(ServerAddr(addr)) if addr == ep1_addr => { tracing::debug!(%addr, "writing ep1"); let mut io = support::io(); io.write(b"who r u?").read(b"ep1"); - future::ok::<_, support::io::Error>(io.build()) + let local = Local(ClientAddr(([0, 0, 0, 0], 4444).into())); + future::ok::<_, support::io::Error>((io.build(), local)) } addr => unreachable!("unexpected endpoint: {}", addr), })) diff --git a/linkerd/app/outbound/src/tcp/opaque_transport.rs b/linkerd/app/outbound/src/tcp/opaque_transport.rs index ed929b17aa..873b236c1c 100644 --- a/linkerd/app/outbound/src/tcp/opaque_transport.rs +++ b/linkerd/app/outbound/src/tcp/opaque_transport.rs @@ -1,12 +1,12 @@ -use crate::tcp::Connect; +use crate::{tcp::Connect, ConnectMeta}; use futures::prelude::*; use linkerd_app_core::{ - dns, io, + dns, proxy::http, svc, tls, transport::{Remote, ServerAddr}, transport_header::{SessionProtocol, TransportHeader, PROTOCOL}, - Error, + Conditional, Error, Result, }; use std::{ future::Future, @@ -34,12 +34,12 @@ impl OpaqueTransport { /// Determines whether the connection has negotiated support for the /// transport header. #[inline] - fn header_negotiated(io: &I) -> bool { - if let Some(tls::NegotiatedProtocolRef(protocol)) = io.negotiated_protocol() { - protocol == PROTOCOL - } else { - false + fn header_negotiated(meta: &ConnectMeta) -> bool { + if let Conditional::Some(Some(np)) = meta.tls.as_ref() { + let tls::NegotiatedProtocolRef(protocol) = np.as_ref(); + return protocol == PROTOCOL; } + false } } @@ -50,14 +50,14 @@ where + svc::Param> + svc::Param> + svc::Param>, - S: svc::Service + Send + 'static, - S::Error: Into, - S::Response: io::AsyncWrite + tls::HasNegotiatedProtocol + Send + Unpin, + S: svc::MakeConnection + Send + 'static, + S::Connection: Send + Unpin, S::Future: Send + 'static, { - type Response = S::Response; + type Response = (S::Connection, S::Metadata); type Error = Error; - type Future = Pin> + Send + 'static>>; + type Future = + Pin> + Send + 'static>>; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { @@ -72,7 +72,7 @@ where addr: ep.param(), tls, }; - return Box::pin(self.inner.call(target).err_into::()); + return Box::pin(self.inner.connect(target).err_into::()); } // Configure the target port from the endpoint. In opaque cases, this is @@ -110,16 +110,16 @@ where let protocol: Option = ep.param(); - let connect = self.inner.call(Connect { + let connect = self.inner.connect(Connect { addr: Remote(ServerAddr((addr.ip(), connect_port).into())), tls, }); Box::pin(async move { - let mut io = connect.await.map_err(Into::into)?; + let (mut io, meta) = connect.await.map_err(Into::into)?; // If transport header support has been negotiated via ALPN, encode // the header and then return the socket. - if Self::header_negotiated(&io) { + if Self::header_negotiated(&meta) { let header = TransportHeader { port: target_port, name, @@ -132,7 +132,7 @@ where trace!("Connection does not expect a transport header"); } - Ok(io) + Ok((io, meta)) }) } } @@ -147,11 +147,9 @@ mod test { io::{self, AsyncWriteExt}, proxy::api_resolve::{Metadata, ProtocolHint}, tls, - transport::{Remote, ServerAddr}, + transport::{ClientAddr, Local}, transport_header::TransportHeader, }; - use pin_project::pin_project; - use std::task::Context; use tower::util::{service_fn, ServiceExt}; fn ep(metadata: Metadata) -> Endpoint<()> { @@ -173,13 +171,15 @@ mod test { let Remote(ServerAddr(sa)) = ep.addr; assert_eq!(sa.port(), 4321); assert!(ep.tls.is_none()); - future::ready(Ok::<_, io::Error>(Io { - io: tokio_test::io::Builder::new().write(b"hello").build(), - alpn: None, - })) + let io = tokio_test::io::Builder::new().write(b"hello").build(); + let meta = tls::ConnectMeta { + socket: Local(ClientAddr(([0, 0, 0, 0], 0).into())), + tls: Conditional::Some(None), + }; + future::ready(Ok::<_, io::Error>((io, meta))) }), }; - let mut io = svc + let (mut io, _meta) = svc .oneshot(ep(Metadata::default())) .await .expect("Connect must not fail"); @@ -201,13 +201,15 @@ mod test { protocol: None, }; let buf = hdr.encode_prefaced_buf().expect("Must encode"); - future::ready(Ok::<_, io::Error>(Io { - alpn: Some(tls::NegotiatedProtocolRef(PROTOCOL)), - io: tokio_test::io::Builder::new() - .write(&buf[..]) - .write(b"hello") - .build(), - })) + let io = tokio_test::io::Builder::new() + .write(&buf[..]) + .write(b"hello") + .build(); + let meta = tls::ConnectMeta { + socket: Local(ClientAddr(([0, 0, 0, 0], 0).into())), + tls: Conditional::Some(Some(tls::NegotiatedProtocolRef(PROTOCOL).into())), + }; + future::ready(Ok::<_, io::Error>((io, meta))) }), }; @@ -220,7 +222,7 @@ mod test { )), None, )); - let mut io = svc.oneshot(e).await.expect("Connect must not fail"); + let (mut io, _meta) = svc.oneshot(e).await.expect("Connect must not fail"); io.write_all(b"hello").await.expect("Write must succeed"); } @@ -239,13 +241,15 @@ mod test { protocol: None, }; let buf = hdr.encode_prefaced_buf().expect("Must encode"); - future::ready(Ok::<_, io::Error>(Io { - alpn: Some(tls::NegotiatedProtocolRef(PROTOCOL)), - io: tokio_test::io::Builder::new() - .write(&buf[..]) - .write(b"hello") - .build(), - })) + let io = tokio_test::io::Builder::new() + .write(&buf[..]) + .write(b"hello") + .build(); + let meta = tls::ConnectMeta { + socket: Local(ClientAddr(([0, 0, 0, 0], 0).into())), + tls: Conditional::Some(Some(tls::NegotiatedProtocolRef(PROTOCOL).into())), + }; + future::ready(Ok::<_, io::Error>((io, meta))) }), }; @@ -258,7 +262,7 @@ mod test { )), Some(http::uri::Authority::from_str("foo.bar.example.com:5555").unwrap()), )); - let mut io = svc.oneshot(e).await.expect("Connect must not fail"); + let (mut io, _meta) = svc.oneshot(e).await.expect("Connect must not fail"); io.write_all(b"hello").await.expect("Write must succeed"); } @@ -277,13 +281,15 @@ mod test { protocol: None, }; let buf = hdr.encode_prefaced_buf().expect("Must encode"); - future::ready(Ok::<_, io::Error>(Io { - alpn: Some(tls::NegotiatedProtocolRef(PROTOCOL)), - io: tokio_test::io::Builder::new() - .write(&buf[..]) - .write(b"hello") - .build(), - })) + let io = tokio_test::io::Builder::new() + .write(&buf[..]) + .write(b"hello") + .build(); + let meta = tls::ConnectMeta { + socket: Local(ClientAddr(([0, 0, 0, 0], 0).into())), + tls: Conditional::Some(Some(tls::NegotiatedProtocolRef(PROTOCOL).into())), + }; + future::ready(Ok::<_, io::Error>((io, meta))) }), }; @@ -296,62 +302,7 @@ mod test { )), None, )); - let mut io = svc.oneshot(e).await.expect("Connect must not fail"); + let (mut io, _meta) = svc.oneshot(e).await.expect("Connect must not fail"); io.write_all(b"hello").await.expect("Write must succeed"); } - - #[pin_project] - pub struct Io { - #[pin] - io: tokio_test::io::Mock, - alpn: Option>, - } - - impl tls::HasNegotiatedProtocol for Io { - fn negotiated_protocol(&self) -> Option> { - self.alpn - } - } - - impl io::AsyncRead for Io { - #[inline] - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &mut io::ReadBuf<'_>, - ) -> io::Poll<()> { - self.project().io.poll_read(cx, buf) - } - } - - impl io::AsyncWrite for Io { - #[inline] - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> { - self.project().io.poll_shutdown(cx) - } - - #[inline] - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> { - self.project().io.poll_flush(cx) - } - - #[inline] - fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> io::Poll { - self.project().io.poll_write(cx, buf) - } - - #[inline] - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[io::IoSlice<'_>], - ) -> io::Poll { - self.project().io.poll_write_vectored(cx, buf) - } - - #[inline] - fn is_write_vectored(&self) -> bool { - self.io.is_write_vectored() - } - } } diff --git a/linkerd/app/test/src/connect.rs b/linkerd/app/test/src/connect.rs index 2e4f897dd5..c57a321ef0 100644 --- a/linkerd/app/test/src/connect.rs +++ b/linkerd/app/test/src/connect.rs @@ -1,14 +1,16 @@ use linkerd_app_core::{ - svc::Param, - transport::{Remote, ServerAddr}, + svc::{Param, Service}, + transport::{ClientAddr, Local, Remote, ServerAddr}, +}; +use std::{ + collections::HashMap, + fmt, + future::Future, + net::SocketAddr, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, }; -use std::collections::HashMap; -use std::fmt; -use std::future::Future; -use std::net::SocketAddr; -use std::pin::Pin; -use std::sync::{Arc, Mutex}; -use std::task::{Context, Poll}; use tracing::instrument::{Instrument, Instrumented}; mod io { @@ -16,9 +18,10 @@ mod io { pub use tokio_test::io::*; } -type ConnectFn = Box ConnectFuture + Send>; +type ConnectFn = Box ConnectFuture + Send>; -pub type ConnectFuture = Pin> + Send + 'static>>; +pub type ConnectFuture = + Pin)>> + Send + 'static>>; #[derive(Clone)] pub struct Connect { @@ -28,11 +31,11 @@ pub struct Connect { #[derive(Clone)] pub struct NoRawTcp; -impl tower::Service for Connect +impl Service for Connect where - E: Clone + fmt::Debug + Param>, + T: Clone + fmt::Debug + Param>, { - type Response = io::BoxedIo; + type Response = (io::BoxedIo, Local); type Future = Instrumented; type Error = io::Error; @@ -40,14 +43,14 @@ where Poll::Ready(Ok(())) } - fn call(&mut self, endpoint: E) -> Self::Future { - let Remote(ServerAddr(addr)) = endpoint.param(); + fn call(&mut self, target: T) -> Self::Future { + let Remote(ServerAddr(addr)) = target.param(); let span = tracing::info_span!("connect", %addr); let f = span.in_scope(|| { tracing::trace!("connecting..."); let mut endpoints = self.endpoints.lock().unwrap(); match endpoints.get_mut(&addr) { - Some(f) => (f)(endpoint), + Some(f) => (f)(target), None => panic!( "did not expect to connect to the endpoint {} not in {:?}", addr, @@ -60,7 +63,7 @@ where } impl tower::Service for NoRawTcp { - type Response = io::BoxedIo; + type Response = (io::BoxedIo, Local); type Future = Instrumented; type Error = io::Error; @@ -106,7 +109,8 @@ impl Connect { endpoint.into(), Box::new(move |endpoint| { let conn = on_connect(endpoint); - Box::pin(async move { conn }) + let local = Local(ClientAddr(([0, 0, 0, 0], 0).into())); + Box::pin(async move { conn.map(move |c| (c, local)) }) }), ); self diff --git a/linkerd/meshtls/boring/src/client.rs b/linkerd/meshtls/boring/src/client.rs index 723836c4c8..2806e34ce4 100644 --- a/linkerd/meshtls/boring/src/client.rs +++ b/linkerd/meshtls/boring/src/client.rs @@ -2,9 +2,7 @@ use crate::creds::CredsRx; use linkerd_identity::Name; use linkerd_io as io; use linkerd_stack::{NewService, Service}; -use linkerd_tls::{ - client::AlpnProtocols, ClientTls, HasNegotiatedProtocol, NegotiatedProtocolRef, ServerId, -}; +use linkerd_tls::{client::AlpnProtocols, ClientTls, NegotiatedProtocolRef, ServerId}; use std::{future::Future, pin::Pin, sync::Arc, task::Context}; use tracing::debug; @@ -148,9 +146,9 @@ impl io::AsyncWrite for ClientIo { } } -impl HasNegotiatedProtocol for ClientIo { +impl ClientIo { #[inline] - fn negotiated_protocol(&self) -> Option> { + pub fn negotiated_protocol(&self) -> Option> { self.0 .ssl() .selected_alpn_protocol() diff --git a/linkerd/meshtls/rustls/src/client.rs b/linkerd/meshtls/rustls/src/client.rs index 37c8e5a050..a5eef085fc 100644 --- a/linkerd/meshtls/rustls/src/client.rs +++ b/linkerd/meshtls/rustls/src/client.rs @@ -1,7 +1,7 @@ use futures::prelude::*; use linkerd_io as io; use linkerd_stack::{NewService, Service}; -use linkerd_tls::{client::AlpnProtocols, ClientTls, HasNegotiatedProtocol, NegotiatedProtocolRef}; +use linkerd_tls::{client::AlpnProtocols, ClientTls, NegotiatedProtocolRef}; use std::{convert::TryFrom, pin::Pin, sync::Arc, task::Context}; use tokio::sync::watch; use tokio_rustls::rustls::{self, ClientConfig}; @@ -139,9 +139,9 @@ impl io::AsyncWrite for ClientIo { } } -impl HasNegotiatedProtocol for ClientIo { +impl ClientIo { #[inline] - fn negotiated_protocol(&self) -> Option> { + pub fn negotiated_protocol(&self) -> Option> { self.0 .get_ref() .1 diff --git a/linkerd/meshtls/rustls/src/server.rs b/linkerd/meshtls/rustls/src/server.rs index 592e9f0963..2cdcb6a381 100644 --- a/linkerd/meshtls/rustls/src/server.rs +++ b/linkerd/meshtls/rustls/src/server.rs @@ -2,9 +2,7 @@ use futures::prelude::*; use linkerd_identity::{LocalId, Name}; use linkerd_io as io; use linkerd_stack::{Param, Service}; -use linkerd_tls::{ - ClientId, HasNegotiatedProtocol, NegotiatedProtocol, NegotiatedProtocolRef, ServerTls, -}; +use linkerd_tls::{ClientId, NegotiatedProtocol, NegotiatedProtocolRef, ServerTls}; use std::{convert::TryFrom, pin::Pin, sync::Arc, task::Context}; use thiserror::Error; use tokio::sync::watch; @@ -190,9 +188,9 @@ impl io::AsyncWrite for ServerIo { } } -impl HasNegotiatedProtocol for ServerIo { +impl ServerIo { #[inline] - fn negotiated_protocol(&self) -> Option> { + pub fn negotiated_protocol(&self) -> Option> { self.0 .get_ref() .1 diff --git a/linkerd/meshtls/src/client.rs b/linkerd/meshtls/src/client.rs index a55bfe5d7d..42a80d299b 100644 --- a/linkerd/meshtls/src/client.rs +++ b/linkerd/meshtls/src/client.rs @@ -1,6 +1,6 @@ use linkerd_io as io; use linkerd_stack::{NewService, Service}; -use linkerd_tls::{ClientTls, HasNegotiatedProtocol, NegotiatedProtocolRef}; +use linkerd_tls::{ClientTls, NegotiatedProtocol}; use std::{ future::Future, pin::Pin, @@ -91,7 +91,7 @@ impl Service for Connect where I: io::AsyncRead + io::AsyncWrite + Send + Unpin + 'static, { - type Response = ClientIo; + type Response = (ClientIo, Option); type Error = io::Error; type Future = ConnectFuture; @@ -103,6 +103,7 @@ where #[cfg(feature = "rustls")] Self::Rustls(connect) => >::poll_ready(connect, cx), + #[cfg(not(feature = "__has_any_tls_impls"))] _ => crate::no_tls!(cx), } @@ -116,6 +117,7 @@ where #[cfg(feature = "rustls")] Self::Rustls(connect) => ConnectFuture::Rustls(connect.call(io)), + #[cfg(not(feature = "__has_any_tls_impls"))] _ => crate::no_tls!(io), } @@ -128,7 +130,7 @@ impl Future for ConnectFuture where I: io::AsyncRead + io::AsyncWrite + Unpin, { - type Output = io::Result>; + type Output = io::Result<(ClientIo, Option)>; #[inline] fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -136,14 +138,21 @@ where #[cfg(feature = "boring")] ConnectFutureProj::Boring(f) => { let res = futures::ready!(f.poll(cx)); - Poll::Ready(res.map(ClientIo::Boring)) + Poll::Ready(res.map(|io| { + let np = io.negotiated_protocol().map(|np| np.to_owned()); + (ClientIo::Boring(io), np) + })) } #[cfg(feature = "rustls")] ConnectFutureProj::Rustls(f) => { let res = futures::ready!(f.poll(cx)); - Poll::Ready(res.map(ClientIo::Rustls)) + Poll::Ready(res.map(|io| { + let np = io.negotiated_protocol().map(|np| np.to_owned()); + (ClientIo::Rustls(io), np) + })) } + #[cfg(not(feature = "__has_any_tls_impls"))] _ => crate::no_tls!(cx), } @@ -242,21 +251,6 @@ impl io::AsyncWrite for ClientIo { } } -impl HasNegotiatedProtocol for ClientIo { - #[inline] - fn negotiated_protocol(&self) -> Option> { - match self { - #[cfg(feature = "boring")] - Self::Boring(io) => io.negotiated_protocol(), - - #[cfg(feature = "rustls")] - Self::Rustls(io) => io.negotiated_protocol(), - #[cfg(not(feature = "__has_any_tls_impls"))] - _ => crate::no_tls!(), - } - } -} - impl io::PeerAddr for ClientIo { #[inline] fn peer_addr(&self) -> io::Result { diff --git a/linkerd/meshtls/tests/util.rs b/linkerd/meshtls/tests/util.rs index fdd8a7a1d7..bb0efcc793 100644 --- a/linkerd/meshtls/tests/util.rs +++ b/linkerd/meshtls/tests/util.rs @@ -245,7 +245,7 @@ where }) .expect("send result"); } - Ok(conn) => { + Ok((conn, _)) => { let result = client(conn).instrument(tracing::info_span!("client")).await; sender .send(Transported { tls, result }) diff --git a/linkerd/proxy/http/src/client.rs b/linkerd/proxy/http/src/client.rs index 11c01b57ad..dcf142da5e 100644 --- a/linkerd/proxy/http/src/client.rs +++ b/linkerd/proxy/http/src/client.rs @@ -9,13 +9,12 @@ use crate::{h1, h2, orig_proto}; use futures::prelude::*; use linkerd_error::{Error, Result}; use linkerd_http_box::BoxBody; -use linkerd_stack::{layer, Param}; +use linkerd_stack::{layer, MakeConnection, Param, Service, ServiceExt}; use std::{ marker::PhantomData, pin::Pin, task::{Context, Poll}, }; -use tower::ServiceExt; use tracing::instrument::{Instrument, Instrumented}; use tracing::{debug, debug_span}; @@ -68,10 +67,10 @@ impl tower::Service for MakeClient where T: Clone + Send + Sync + 'static, T: Param, - C: tower::make::MakeConnection + Clone + Unpin + Send + Sync + 'static, + C: MakeConnection + Clone + Unpin + Send + Sync + 'static, + C::Connection: Unpin + Send, + C::Metadata: Send, C::Future: Unpin + Send + 'static, - C::Error: Into, - C::Connection: Unpin + Send + 'static, B: hyper::body::HttpBody + Send + 'static, B::Data: Send, B::Error: Into + Send + Sync, @@ -131,11 +130,11 @@ impl Clone for MakeClient { type RspFuture = Pin>> + Send + 'static>>; -impl tower::Service> for Client +impl Service> for Client where T: Clone + Send + Sync + 'static, - C: tower::make::MakeConnection + Clone + Send + Sync + 'static, - C::Connection: Unpin + Send + 'static, + C: MakeConnection + Clone + Send + Sync + 'static, + C::Connection: Unpin + Send, C::Future: Unpin + Send + 'static, C::Error: Into, B: hyper::body::HttpBody + Send + 'static, diff --git a/linkerd/proxy/http/src/glue.rs b/linkerd/proxy/http/src/glue.rs index ef64a897f0..3609318d58 100644 --- a/linkerd/proxy/http/src/glue.rs +++ b/linkerd/proxy/http/src/glue.rs @@ -3,12 +3,15 @@ use bytes::Bytes; use futures::TryFuture; use hyper::body::HttpBody; use hyper::client::connect as hyper_connect; -use linkerd_error::Error; +use linkerd_error::{Error, Result}; use linkerd_io::{self as io, AsyncRead, AsyncWrite}; +use linkerd_stack::{MakeConnection, Service}; use pin_project::{pin_project, pinned_drop}; -use std::future::Future; -use std::pin::Pin; -use std::task::{Context, Poll}; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; use tracing::debug; /// Provides optional HTTP/1.1 upgrade support on the body. @@ -166,13 +169,11 @@ impl HyperConnect { } } -impl tower::Service for HyperConnect +impl Service for HyperConnect where - C: tower::make::MakeConnection + Clone + Send + Sync, - C::Error: Into, - C::Future: TryFuture + Unpin + Send + 'static, - ::Error: Into, - C::Connection: Unpin + Send + 'static, + C: MakeConnection + Clone + Send + Sync, + C::Connection: Unpin + Send, + C::Future: Unpin + Send + 'static, T: Clone + Send + Sync, { type Response = Connection; @@ -185,22 +186,22 @@ where fn call(&mut self, _dst: hyper::Uri) -> Self::Future { HyperConnectFuture { - inner: self.connect.make_connection(self.target.clone()), + inner: self.connect.connect(self.target.clone()), absolute_form: self.absolute_form, } } } -impl Future for HyperConnectFuture +impl Future for HyperConnectFuture where - F: TryFuture + 'static, + F: TryFuture + 'static, F::Error: Into, { - type Output = Result, Error>; + type Output = Result>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); - let transport = futures::ready!(this.inner.try_poll(cx)).map_err(Into::into)?; + let (transport, _) = futures::ready!(this.inner.try_poll(cx)).map_err(Into::into)?; Poll::Ready(Ok(Connection { transport, absolute_form: *this.absolute_form, diff --git a/linkerd/proxy/http/src/h1.rs b/linkerd/proxy/http/src/h1.rs index b0e87c69ef..73803faff4 100644 --- a/linkerd/proxy/http/src/h1.rs +++ b/linkerd/proxy/http/src/h1.rs @@ -9,6 +9,7 @@ use http::{ }; use linkerd_error::{Error, Result}; use linkerd_http_box::BoxBody; +use linkerd_stack::MakeConnection; use std::{future::Future, mem, pin::Pin, time::Duration}; use tracing::{debug, trace}; @@ -65,10 +66,9 @@ type RspFuture = Pin>> + impl Client where T: Clone + Send + Sync + 'static, - C: tower::make::MakeConnection + Clone + Send + Sync + 'static, - C::Connection: Unpin + Send + 'static, + C: MakeConnection + Clone + Send + Sync + 'static, + C::Connection: Unpin + Send, C::Future: Unpin + Send + 'static, - C::Error: Into, B: hyper::body::HttpBody + Send + 'static, B::Data: Send, B::Error: Into + Send + Sync, diff --git a/linkerd/proxy/http/src/h2.rs b/linkerd/proxy/http/src/h2.rs index 785c83f391..7eb67a2854 100644 --- a/linkerd/proxy/http/src/h2.rs +++ b/linkerd/proxy/http/src/h2.rs @@ -6,14 +6,14 @@ use hyper::{ client::conn::{self, SendRequest}, }; use linkerd_error::{Error, Result}; -use std::time::Duration; +use linkerd_stack::{MakeConnection, Service}; use std::{ future::Future, marker::PhantomData, pin::Pin, task::{Context, Poll}, + time::Duration, }; -use tokio::io::{AsyncRead, AsyncWrite}; use tracing::instrument::Instrument; use tracing::{debug, debug_span, trace_span}; @@ -60,12 +60,12 @@ impl Clone for Connect { type ConnectFuture = Pin>> + Send + 'static>>; -impl tower::Service for Connect +impl Service for Connect where - C: tower::make::MakeConnection, + C: MakeConnection, + C::Connection: Send + Unpin + 'static, + C::Metadata: Send, C::Future: Send + 'static, - C::Connection: AsyncRead + AsyncWrite + Unpin + Send + 'static, - C::Error: Into, B: HttpBody + Send + 'static, B::Data: Send, B::Error: Into + Send + Sync, @@ -88,12 +88,12 @@ where let connect = self .connect - .make_connection(target) + .connect(target) .instrument(trace_span!("connect")); Box::pin( async move { - let io = connect.err_into::().await?; + let (io, _meta) = connect.err_into::().await?; let mut builder = conn::Builder::new(); builder .http2_only(true) diff --git a/linkerd/proxy/http/src/orig_proto.rs b/linkerd/proxy/http/src/orig_proto.rs index 4f4a385bed..2f2c126041 100644 --- a/linkerd/proxy/http/src/orig_proto.rs +++ b/linkerd/proxy/http/src/orig_proto.rs @@ -4,7 +4,7 @@ use http::header::{HeaderValue, TRANSFER_ENCODING}; use hyper::body::HttpBody; use linkerd_error::{Error, Result}; use linkerd_http_box::BoxBody; -use linkerd_stack::layer; +use linkerd_stack::{layer, MakeConnection, Service}; use std::{ future::Future, pin::Pin, @@ -47,13 +47,12 @@ impl Upgrade { } } -impl tower::Service> for Upgrade +impl Service> for Upgrade where T: Clone + Send + Sync + 'static, - C: tower::make::MakeConnection + Clone + Send + Sync + 'static, - C::Connection: Unpin + Send + 'static, + C: MakeConnection + Clone + Send + Sync + 'static, + C::Connection: Unpin + Send, C::Future: Unpin + Send + 'static, - C::Error: Into, B: hyper::body::HttpBody + Send + 'static, B::Data: Send, B::Error: Into + Send + Sync, @@ -235,9 +234,9 @@ impl Downgrade { type DowngradeFuture = future::MapOk T>; -impl tower::Service> for Downgrade +impl Service> for Downgrade where - S: tower::Service, Response = http::Response>, + S: Service, Response = http::Response>, { type Response = S::Response; type Error = S::Error; diff --git a/linkerd/proxy/tcp/src/forward.rs b/linkerd/proxy/tcp/src/forward.rs index ed2a493cb9..9b43e2ad46 100644 --- a/linkerd/proxy/tcp/src/forward.rs +++ b/linkerd/proxy/tcp/src/forward.rs @@ -1,6 +1,6 @@ use futures::prelude::*; use linkerd_duplex::Duplex; -use linkerd_error::Error; +use linkerd_error::{Error, Result}; use linkerd_stack::layer; use std::{ future::Future, @@ -35,9 +35,9 @@ where { type Response = (); type Error = Error; - type Future = - Pin> + Send + 'static>>; + type Future = Pin> + Send + 'static>>; + #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.connect.poll_ready(cx).map_err(Into::into) } diff --git a/linkerd/proxy/transport/src/connect.rs b/linkerd/proxy/transport/src/connect.rs index 9f721341af..4989ac2129 100644 --- a/linkerd/proxy/transport/src/connect.rs +++ b/linkerd/proxy/transport/src/connect.rs @@ -1,4 +1,4 @@ -use crate::{Keepalive, Remote, ServerAddr}; +use crate::{ClientAddr, Keepalive, Local, Remote, ServerAddr}; use linkerd_io as io; use linkerd_stack::{Param, Service}; use std::{ @@ -21,10 +21,9 @@ impl ConnectTcp { } impl>> Service for ConnectTcp { - type Response = io::ScopedIo; + type Response = (io::ScopedIo, Local); type Error = io::Error; - type Future = - Pin>> + Send + Sync + 'static>>; + type Future = Pin> + Send + Sync + 'static>>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) @@ -38,12 +37,13 @@ impl>> Service for ConnectTcp { let io = TcpStream::connect(&addr).await?; super::set_nodelay_or_warn(&io); let io = super::set_keepalive_or_warn(io, keepalive)?; + let local_addr = io.local_addr()?; debug!( - local.addr = %io.local_addr().expect("cannot load local addr"), + local.addr = %local_addr, ?keepalive, "Connected", ); - Ok(io::ScopedIo::client(io)) + Ok((io::ScopedIo::client(io), Local(ClientAddr(local_addr)))) }) } } diff --git a/linkerd/stack/src/connect.rs b/linkerd/stack/src/connect.rs new file mode 100644 index 0000000000..49ce69c638 --- /dev/null +++ b/linkerd/stack/src/connect.rs @@ -0,0 +1,122 @@ +use crate::{layer, Service}; +use futures::prelude::*; +use linkerd_error::Error; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite}; + +/// A helper `Service` that drops metadata from a `MakeConnection` +#[derive(Clone, Debug)] +pub struct WithoutConnectionMetadata(S); + +/// A helper that coerces a `MakeConnection` into a `Service` +#[derive(Clone, Debug)] +pub struct MakeConnectionService(S); + +/// A helper trait that models a `Service` that creates client connections. +/// +/// Implementers should implement `Service` and not `MakeConnection`. `MakeConnection` should only +/// be used by consumers of these services. +pub trait MakeConnection { + /// An I/O type that represents a connection to the remote endpoint. + type Connection: AsyncRead + AsyncWrite; + + /// Metadata associated with the established connection. + type Metadata; + + type Error: Into; + + type Future: Future>; + + /// Determines whether the connector is ready to establish a connection. + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll>; + + /// Establishes a connection. + fn connect(&mut self, t: T) -> Self::Future; + + /// Returns a new `Service` that drops the connection metadata from returned values. + fn without_connection_metadata(self) -> WithoutConnectionMetadata + where + Self: Sized, + { + WithoutConnectionMetadata(self) + } + + /// Coerces a `MakeConnection` into a `Service`. + fn into_service(self) -> MakeConnectionService + where + Self: Sized, + { + MakeConnectionService(self) + } +} + +impl MakeConnection for S +where + S: Service, + S::Error: Into, + I: AsyncRead + AsyncWrite, +{ + type Connection = I; + type Metadata = M; + type Error = S::Error; + type Future = S::Future; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + Service::poll_ready(self, cx) + } + + #[inline] + fn connect(&mut self, t: T) -> Self::Future { + Service::call(self, t) + } +} + +// === impl MakeConnectionService === + +impl Service for MakeConnectionService +where + S: MakeConnection, +{ + type Response = (S::Connection, S::Metadata); + type Error = S::Error; + type Future = S::Future; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.0.poll_ready(cx) + } + + #[inline] + fn call(&mut self, t: T) -> Self::Future { + self.0.connect(t) + } +} + +// === impl WithoutConnectionMetadata === + +impl WithoutConnectionMetadata { + pub fn layer() -> impl layer::Layer + Clone { + layer::mk(WithoutConnectionMetadata) + } +} + +impl Service for WithoutConnectionMetadata +where + S: MakeConnection, +{ + type Response = S::Connection; + type Error = S::Error; + type Future = + futures::future::MapOk S::Connection>; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.0.poll_ready(cx) + } + + #[inline] + fn call(&mut self, t: T) -> Self::Future { + self.0.connect(t).map_ok(|(conn, _)| conn) + } +} diff --git a/linkerd/stack/src/lib.rs b/linkerd/stack/src/lib.rs index 011d7456fc..a83531599d 100644 --- a/linkerd/stack/src/lib.rs +++ b/linkerd/stack/src/lib.rs @@ -6,6 +6,7 @@ mod arc_new_service; mod box_future; mod box_service; +mod connect; mod either; mod fail; mod fail_on_error; @@ -29,6 +30,7 @@ pub use self::{ arc_new_service::ArcNewService, box_future::BoxFuture, box_service::{BoxService, BoxServiceLayer}, + connect::{MakeConnection, WithoutConnectionMetadata}, either::{Either, NewEither}, fail::Fail, fail_on_error::FailOnError, diff --git a/linkerd/tls/src/client.rs b/linkerd/tls/src/client.rs index 340534f8a0..3243393d1e 100644 --- a/linkerd/tls/src/client.rs +++ b/linkerd/tls/src/client.rs @@ -1,9 +1,9 @@ -use crate::{HasNegotiatedProtocol, NegotiatedProtocolRef}; +use crate::NegotiatedProtocol; use futures::prelude::*; use linkerd_conditional::Conditional; use linkerd_identity as id; use linkerd_io as io; -use linkerd_stack::{layer, NewService, Oneshot, Param, Service, ServiceExt}; +use linkerd_stack::{layer, MakeConnection, NewService, Oneshot, Param, Service, ServiceExt}; use std::{ fmt, future::Future, @@ -58,9 +58,19 @@ pub struct Client { #[pin_project::pin_project(project = ConnectProj)] #[derive(Debug)] -pub enum Connect> { - Connect(#[pin] F, Option), - Handshake(#[pin] Oneshot), +pub enum Connect, M> { + Connect(#[pin] F, Option>), + Handshake { + #[pin] + inner: Oneshot, + state: Option<(Conditional<(), NoClientTls>, M)>, + }, +} + +#[derive(Clone, Debug)] +pub struct ConnectMeta { + pub socket: M, + pub tls: Conditional, NoClientTls>, } // === impl ClientTls === @@ -85,20 +95,23 @@ impl Client { } } -impl Service for Client +impl Service for Client where T: Param, L: NewService, - C: Service, - C::Response: io::AsyncRead + io::AsyncWrite + Send + Unpin, + C: MakeConnection, + C::Connection: io::AsyncRead + io::AsyncWrite + Send + Unpin, + C::Metadata: Send + Unpin, C::Future: Send + 'static, - H: Service + Send + 'static, - H::Response: io::AsyncRead + io::AsyncWrite + Send + Unpin + HasNegotiatedProtocol, + H: Service), Error = io::Error> + + Send + + 'static, H::Future: Send + 'static, + I: io::AsyncRead + io::AsyncWrite + Send + Unpin, { - type Response = io::EitherIo; + type Response = (io::EitherIo, ConnectMeta); type Error = io::Error; - type Future = Connect; + type Future = Connect; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { @@ -107,45 +120,58 @@ where fn call(&mut self, target: T) -> Self::Future { let handshake = match target.param() { - Conditional::Some(tls) => Some(self.identity.new_service(tls)), + Conditional::Some(tls) => Conditional::Some(self.identity.new_service(tls)), Conditional::None(reason) => { debug!(%reason, "Peer does not support TLS"); - None + Conditional::None(reason) } }; - let connect = self.inner.call(target); - Connect::Connect(connect, handshake) + let connect = self.inner.connect(target); + Connect::Connect(connect, Some(handshake)) } } -impl Future for Connect +impl Future for Connect where - F: TryFuture, - H: Service, - H::Response: HasNegotiatedProtocol, + F: TryFuture, + H: Service), Error = io::Error>, { - type Output = io::Result>; + type Output = io::Result<(io::EitherIo, ConnectMeta)>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { loop { match self.as_mut().project() { ConnectProj::Connect(fut, tls) => { - let io = futures::ready!(fut.try_poll(cx))?; - match tls.take() { - None => return Poll::Ready(Ok(io::EitherIo::Left(io))), - Some(tls) => self.set(Connect::Handshake(tls.oneshot(io))), + let (io, socket) = futures::ready!(fut.try_poll(cx))?; + match tls.take().expect("tls handshake must be set") { + Conditional::Some(tls) => self.set(Connect::Handshake { + inner: tls.oneshot(io), + state: Some((Conditional::Some(()), socket)), + }), + Conditional::None(reason) => { + let meta = ConnectMeta { + socket, + tls: Conditional::None(reason), + }; + return Poll::Ready(Ok((io::EitherIo::Left(io), meta))); + } } } - ConnectProj::Handshake(fut) => { - let io = futures::ready!(fut.try_poll(cx))?; + ConnectProj::Handshake { inner, state } => { + let (io, alpn) = futures::ready!(inner.try_poll(cx))?; debug!( - alpn = io - .negotiated_protocol() - .and_then(|NegotiatedProtocolRef(p)| std::str::from_utf8(p).ok()) + alpn = alpn + .as_ref() + .and_then(|NegotiatedProtocol(ref p)| std::str::from_utf8(p).ok()) .map(tracing::field::display) ); - return Poll::Ready(Ok(io::EitherIo::Right(io))); + let (tls, socket) = state.take().expect("metadata must be set"); + let meta = ConnectMeta { + socket, + tls: tls.map(move |()| alpn), + }; + return Poll::Ready(Ok((io::EitherIo::Right(io), meta))); } } } diff --git a/linkerd/tls/src/lib.rs b/linkerd/tls/src/lib.rs index baa1c8b805..f8a1542037 100755 --- a/linkerd/tls/src/lib.rs +++ b/linkerd/tls/src/lib.rs @@ -5,18 +5,12 @@ pub mod client; pub mod server; pub use linkerd_identity::LocalId; -use linkerd_io as io; pub use self::{ - client::{Client, ClientTls, ConditionalClientTls, NoClientTls, ServerId}, + client::{Client, ClientTls, ConditionalClientTls, ConnectMeta, NoClientTls, ServerId}, server::{ClientId, ConditionalServerTls, NewDetectTls, NoServerTls, ServerTls}, }; -/// A trait implemented by transport streams to indicate its negotiated protocol. -pub trait HasNegotiatedProtocol { - fn negotiated_protocol(&self) -> Option>; -} - #[derive(Clone, Eq, PartialEq, Hash)] pub struct NegotiatedProtocol(pub Vec); @@ -56,39 +50,3 @@ impl std::fmt::Debug for NegotiatedProtocolRef<'_> { } } } - -impl HasNegotiatedProtocol for tokio::net::TcpStream { - #[inline] - fn negotiated_protocol(&self) -> Option> { - None - } -} - -impl HasNegotiatedProtocol for io::ScopedIo { - #[inline] - fn negotiated_protocol(&self) -> Option> { - self.get_ref().negotiated_protocol() - } -} - -impl HasNegotiatedProtocol for io::EitherIo -where - L: HasNegotiatedProtocol, - R: HasNegotiatedProtocol, -{ - #[inline] - fn negotiated_protocol(&self) -> Option> { - match self { - io::EitherIo::Left(l) => l.negotiated_protocol(), - io::EitherIo::Right(r) => r.negotiated_protocol(), - } - } -} - -/// Needed for tests. -impl HasNegotiatedProtocol for io::BoxedIo { - #[inline] - fn negotiated_protocol(&self) -> Option> { - None - } -} diff --git a/linkerd/transport-metrics/src/client.rs b/linkerd/transport-metrics/src/client.rs index a3bf5755d3..00f844eb3d 100644 --- a/linkerd/transport-metrics/src/client.rs +++ b/linkerd/transport-metrics/src/client.rs @@ -1,6 +1,6 @@ use super::{Metrics, Sensor, SensorIo}; use futures::{ready, TryFuture}; -use linkerd_stack::{layer, ExtractParam, Service}; +use linkerd_stack::{layer, ExtractParam, MakeConnection, Service}; use pin_project::pin_project; use std::{ future::Future, @@ -17,7 +17,7 @@ pub struct Client { } #[pin_project] -pub struct Connect { +pub struct ConnectFuture { #[pin] inner: F, metrics: Option>, @@ -37,11 +37,11 @@ impl Client { impl Service for Client where P: ExtractParam, T>, - S: Service, + S: MakeConnection, { - type Response = SensorIo; + type Response = (SensorIo, S::Metadata); type Error = S::Error; - type Future = Connect; + type Future = ConnectFuture; #[inline] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { @@ -50,29 +50,29 @@ where fn call(&mut self, target: T) -> Self::Future { let metrics = self.params.extract_param(&target); - let inner = self.inner.call(target); - Connect { + let inner = self.inner.connect(target); + ConnectFuture { metrics: Some(metrics), inner, } } } -// === impl Connect === +// === impl ConnectFuture === -impl Future for Connect { - type Output = Result, F::Error>; +impl> Future for ConnectFuture { + type Output = Result<(SensorIo, M), F::Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); - let io = ready!(this.inner.try_poll(cx))?; + let (io, meta) = ready!(this.inner.try_poll(cx))?; debug!("client connection open"); let metrics = this .metrics .take() .expect("future must not be polled after ready"); - let t = SensorIo::new(io, Sensor::open(metrics)); - Poll::Ready(Ok(t)) + let io = SensorIo::new(io, Sensor::open(metrics)); + Poll::Ready(Ok((io, meta))) } }