diff --git a/examples/src/tls/server.rs b/examples/src/tls/server.rs index fc50c57ab..723549af0 100644 --- a/examples/src/tls/server.rs +++ b/examples/src/tls/server.rs @@ -6,7 +6,10 @@ use futures::Stream; use pb::{EchoRequest, EchoResponse}; use std::pin::Pin; use tonic::{ - transport::{Identity, Server, ServerTlsConfig}, + transport::{ + server::{TcpConnectInfo, TlsConnectInfo}, + Identity, Server, ServerTlsConfig, + }, Request, Response, Status, Streaming, }; @@ -19,6 +22,16 @@ pub struct EchoServer; #[tonic::async_trait] impl pb::echo_server::Echo for EchoServer { async fn unary_echo(&self, request: Request) -> EchoResult { + let conn_info = request + .extensions() + .get::>() + .unwrap(); + println!( + "Got a request from {:?} with info {:?}", + request.remote_addr(), + conn_info + ); + let message = request.into_inner().message; Ok(Response::new(EchoResponse { message })) } diff --git a/examples/src/uds/server.rs b/examples/src/uds/server.rs index 07be063c9..fb26bcdcf 100644 --- a/examples/src/uds/server.rs +++ b/examples/src/uds/server.rs @@ -24,7 +24,11 @@ impl Greeter for MyGreeter { &self, request: Request, ) -> Result, Status> { - println!("Got a request: {:?}", request); + #[cfg(unix)] + { + let conn_info = request.extensions().get::().unwrap(); + println!("Got a request {:?} with info {:?}", request, conn_info); + } let reply = hello_world::HelloReply { message: format!("Hello {}!", request.into_inner().name), @@ -64,6 +68,7 @@ async fn main() -> Result<(), Box> { mod unix { use std::{ pin::Pin, + sync::Arc, task::{Context, Poll}, }; @@ -73,7 +78,22 @@ mod unix { #[derive(Debug)] pub struct UnixStream(pub tokio::net::UnixStream); - impl Connected for UnixStream {} + impl Connected for UnixStream { + type ConnectInfo = UdsConnectInfo; + + fn connect_info(&self) -> Self::ConnectInfo { + UdsConnectInfo { + peer_addr: self.0.peer_addr().ok().map(Arc::new), + peer_cred: self.0.peer_cred().ok(), + } + } + } + + #[derive(Clone, Debug)] + pub struct UdsConnectInfo { + pub peer_addr: Option>, + pub peer_cred: Option, + } impl AsyncRead for UnixStream { fn poll_read( diff --git a/tests/integration_tests/tests/connect_info.rs b/tests/integration_tests/tests/connect_info.rs new file mode 100644 index 000000000..936eedac1 --- /dev/null +++ b/tests/integration_tests/tests/connect_info.rs @@ -0,0 +1,50 @@ +use futures_util::FutureExt; +use integration_tests::pb::{test_client, test_server, Input, Output}; +use std::time::Duration; +use tokio::sync::oneshot; +use tonic::{ + transport::{server::TcpConnectInfo, Endpoint, Server}, + Request, Response, Status, +}; + +#[tokio::test] +async fn getting_connect_info() { + struct Svc; + + #[tonic::async_trait] + impl test_server::Test for Svc { + async fn unary_call(&self, req: Request) -> Result, Status> { + assert!(req.remote_addr().is_some()); + assert!(req.extensions().get::().is_some()); + + Ok(Response::new(Output {})) + } + } + + let svc = test_server::TestServer::new(Svc); + + let (tx, rx) = oneshot::channel::<()>(); + + let jh = tokio::spawn(async move { + Server::builder() + .add_service(svc) + .serve_with_shutdown("127.0.0.1:1400".parse().unwrap(), rx.map(drop)) + .await + .unwrap(); + }); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let channel = Endpoint::from_static("http://127.0.0.1:1400") + .connect() + .await + .unwrap(); + + let mut client = test_client::TestClient::new(channel); + + client.unary_call(Input {}).await.unwrap(); + + tx.send(()).unwrap(); + + jh.await.unwrap(); +} diff --git a/tonic/src/request.rs b/tonic/src/request.rs index c7780ebbd..212a43b16 100644 --- a/tonic/src/request.rs +++ b/tonic/src/request.rs @@ -1,6 +1,8 @@ use crate::metadata::{MetadataMap, MetadataValue}; +#[cfg(all(feature = "transport", feature = "tls"))] +use crate::transport::server::TlsConnectInfo; #[cfg(feature = "transport")] -use crate::transport::Certificate; +use crate::transport::{server::TcpConnectInfo, Certificate}; use crate::Extensions; use futures_core::Stream; #[cfg(feature = "transport")] @@ -15,13 +17,6 @@ pub struct Request { extensions: Extensions, } -#[derive(Clone)] -pub(crate) struct ConnectionInfo { - pub(crate) remote_addr: Option, - #[cfg(feature = "transport")] - pub(crate) peer_certs: Option>>, -} - /// Trait implemented by RPC request types. /// /// Types implementing this trait can be used as arguments to client RPC @@ -203,7 +198,32 @@ impl Request { /// does not implement `Connected`. This currently, /// only works on the server side. pub fn remote_addr(&self) -> Option { - self.get::()?.remote_addr + #[cfg(feature = "transport")] + { + #[cfg(feature = "tls")] + { + self.extensions() + .get::() + .and_then(|i| i.remote_addr()) + .or_else(|| { + self.extensions() + .get::>() + .and_then(|i| i.get_ref().remote_addr()) + }) + } + + #[cfg(not(feature = "tls"))] + { + self.extensions() + .get::() + .and_then(|i| i.remote_addr()) + } + } + + #[cfg(not(feature = "transport"))] + { + None + } } /// Get the peer certificates of the connected client. @@ -215,11 +235,17 @@ impl Request { #[cfg(feature = "transport")] #[cfg_attr(docsrs, doc(cfg(feature = "transport")))] pub fn peer_certs(&self) -> Option>> { - self.get::()?.peer_certs.clone() - } + #[cfg(feature = "tls")] + { + self.extensions() + .get::>() + .and_then(|i| i.peer_certs()) + } - pub(crate) fn get(&self) -> Option<&I> { - self.extensions.get::() + #[cfg(not(feature = "tls"))] + { + None + } } /// Set the max duration the request is allowed to take. diff --git a/tonic/src/transport/server/conn.rs b/tonic/src/transport/server/conn.rs index f5bbcfc08..ea304865b 100644 --- a/tonic/src/transport/server/conn.rs +++ b/tonic/src/transport/server/conn.rs @@ -1,58 +1,158 @@ -use crate::transport::Certificate; use hyper::server::conn::AddrStream; use std::net::SocketAddr; use tokio::net::TcpStream; + +#[cfg(feature = "tls")] +use crate::transport::Certificate; +#[cfg(feature = "tls")] +use std::sync::Arc; #[cfg(feature = "tls")] use tokio_rustls::{rustls::Session, server::TlsStream}; -/// Trait that connected IO resources implement. +/// Trait that connected IO resources implement and use to produce info about the connection. /// /// The goal for this trait is to allow users to implement /// custom IO types that can still provide the same connection /// metadata. +/// +/// # Example +/// +/// The `ConnectInfo` returned will be accessible through [request extensions][ext]: +/// +/// ``` +/// use tonic::{Request, transport::server::Connected}; +/// +/// // A `Stream` that yields connections +/// struct MyConnector {} +/// +/// // Return metadata about the connection as `MyConnectInfo` +/// impl Connected for MyConnector { +/// type ConnectInfo = MyConnectInfo; +/// +/// fn connect_info(&self) -> Self::ConnectInfo { +/// MyConnectInfo {} +/// } +/// } +/// +/// #[derive(Clone)] +/// struct MyConnectInfo { +/// // Metadata about your connection +/// } +/// +/// // The connect info can be accessed through request extensions: +/// # fn foo(request: Request<()>) { +/// let connect_info: &MyConnectInfo = request +/// .extensions() +/// .get::() +/// .expect("bug in tonic"); +/// # } +/// ``` +/// +/// [ext]: crate::Request::extensions pub trait Connected { - /// Return the remote address this IO resource is connected too. - fn remote_addr(&self) -> Option { - None - } + /// The connection info type the IO resources generates. + // all these bounds are necessary to set this as a request extension + type ConnectInfo: Clone + Send + Sync + 'static; - /// Return the set of connected peer TLS certificates. - fn peer_certs(&self) -> Option> { - None + /// Create type holding information about the connection. + fn connect_info(&self) -> Self::ConnectInfo; +} + +/// Connection info for standard TCP streams. +/// +/// This type will be accessible through [request extensions][ext] if you're using the default +/// non-TLS connector. +/// +/// See [`Connected`] for more details. +/// +/// [ext]: crate::Request::extensions +#[derive(Debug, Clone)] +pub struct TcpConnectInfo { + remote_addr: Option, +} + +impl TcpConnectInfo { + /// Return the remote address the IO resource is connected too. + pub fn remote_addr(&self) -> Option { + self.remote_addr } } impl Connected for AddrStream { - fn remote_addr(&self) -> Option { - Some(self.remote_addr()) + type ConnectInfo = TcpConnectInfo; + + fn connect_info(&self) -> Self::ConnectInfo { + TcpConnectInfo { + remote_addr: Some(self.remote_addr()), + } } } impl Connected for TcpStream { - fn remote_addr(&self) -> Option { - self.peer_addr().ok() + type ConnectInfo = TcpConnectInfo; + + fn connect_info(&self) -> Self::ConnectInfo { + TcpConnectInfo { + remote_addr: self.peer_addr().ok(), + } } } #[cfg(feature = "tls")] -impl Connected for TlsStream { - fn remote_addr(&self) -> Option { - let (inner, _) = self.get_ref(); - - inner.remote_addr() - } +impl Connected for TlsStream +where + T: Connected, +{ + type ConnectInfo = TlsConnectInfo; - fn peer_certs(&self) -> Option> { - let (_, session) = self.get_ref(); + fn connect_info(&self) -> Self::ConnectInfo { + let (inner, session) = self.get_ref(); + let inner = inner.connect_info(); - if let Some(certs) = session.get_peer_certificates() { + let certs = if let Some(certs) = session.get_peer_certificates() { let certs = certs .into_iter() .map(|c| Certificate::from_pem(c.0)) .collect(); - Some(certs) + Some(Arc::new(certs)) } else { None - } + }; + + TlsConnectInfo { inner, certs } + } +} + +/// Connection info for TLS streams. +/// +/// This type will be accessible through [request extensions][ext] if you're using a TLS connector. +/// +/// See [`Connected`] for more details. +/// +/// [ext]: crate::Request::extensions +#[cfg(feature = "tls")] +#[cfg_attr(docsrs, doc(cfg(feature = "tls")))] +#[derive(Debug, Clone)] +pub struct TlsConnectInfo { + inner: T, + certs: Option>>, +} + +#[cfg(feature = "tls")] +#[cfg_attr(docsrs, doc(cfg(feature = "tls")))] +impl TlsConnectInfo { + /// Get a reference to the underlying connection info. + pub fn get_ref(&self) -> &T { + &self.inner + } + + /// Get a mutable reference to the underlying connection info. + pub fn get_mut(&mut self) -> &mut T { + &mut self.inner + } + + /// Return the set of connected peer TLS certificates. + pub fn peer_certs(&self) -> Option>> { + self.certs.clone() } } diff --git a/tonic/src/transport/server/incoming.rs b/tonic/src/transport/server/incoming.rs index f9a21a0a9..686aef197 100644 --- a/tonic/src/transport/server/incoming.rs +++ b/tonic/src/transport/server/incoming.rs @@ -18,7 +18,7 @@ use tokio::io::{AsyncRead, AsyncWrite}; pub(crate) fn tcp_incoming( incoming: impl Stream>, _server: Server, -) -> impl Stream> +) -> impl Stream, crate::Error>> where IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, IE: Into, @@ -26,10 +26,8 @@ where async_stream::try_stream! { futures_util::pin_mut!(incoming); - while let Some(stream) = incoming.try_next().await? { - - yield ServerIo::new(stream); + yield ServerIo::new_io(stream); } } } @@ -38,7 +36,7 @@ where pub(crate) fn tcp_incoming( incoming: impl Stream>, server: Server, -) -> impl Stream> +) -> impl Stream, crate::Error>> where IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, IE: Into, @@ -57,12 +55,12 @@ where let accept = tokio::spawn(async move { let io = tls.accept(stream).await?; - Ok(ServerIo::new(io)) + Ok(ServerIo::new_tls_io(io)) }); tasks.push(accept); } else { - yield ServerIo::new(stream); + yield ServerIo::new_io(stream); } } @@ -86,7 +84,7 @@ where async fn select( incoming: &mut (impl Stream> + Unpin), tasks: &mut futures_util::stream::futures_unordered::FuturesUnordered< - tokio::task::JoinHandle>, + tokio::task::JoinHandle, crate::Error>>, >, ) -> SelectOutput where @@ -124,7 +122,7 @@ where #[cfg(feature = "tls")] enum SelectOutput { Incoming(A), - Io(ServerIo), + Io(ServerIo), Err(crate::Error), Done, } diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 4ec198237..7ccbc589d 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -7,10 +7,13 @@ mod recover_error; #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] mod tls; -pub use conn::Connected; +pub use conn::{Connected, TcpConnectInfo}; #[cfg(feature = "tls")] pub use tls::ServerTlsConfig; +#[cfg(feature = "tls")] +pub use conn::TlsConnectInfo; + #[cfg(feature = "tls")] use super::service::TlsAcceptor; @@ -24,7 +27,7 @@ use crate::transport::Error; use self::recover_error::RecoverError; use super::service::{GrpcTimeout, Or, Routes, ServerIo}; -use crate::{body::BoxBody, request::ConnectionInfo}; +use crate::body::BoxBody; use bytes::Bytes; use futures_core::Stream; use futures_util::{ @@ -38,6 +41,7 @@ use pin_project::pin_project; use std::{ fmt, future::Future, + marker::PhantomData, net::SocketAddr, pin::Pin, sync::Arc, @@ -458,6 +462,7 @@ impl Server { <>::Service as Service>>::Error: Into + Send, I: Stream>, IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, + IO::ConnectInfo: Clone + Send + Sync + 'static, IE: Into, F: Future, ResBody: http_body::Body + Send + Sync + 'static, @@ -487,6 +492,7 @@ impl Server { concurrency_limit, timeout, trace_interceptor, + _io: PhantomData, }; let server = hyper::Server::builder(incoming) @@ -674,6 +680,7 @@ where where I: Stream>, IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, + IO::ConnectInfo: Clone + Send + Sync + 'static, IE: Into, L: Layer>>, L::Service: Service, Response = Response> + Clone + Send + 'static, @@ -707,6 +714,7 @@ where where I: Stream>, IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static, + IO::ConnectInfo: Clone + Send + Sync + 'static, IE: Into, F: Future, L: Layer>>, @@ -749,7 +757,6 @@ impl fmt::Debug for Server { struct Svc { inner: S, trace_interceptor: Option, - conn_info: ConnectionInfo, } impl Service> for Svc @@ -782,8 +789,6 @@ where tracing::Span::none() }; - req.extensions_mut().insert(self.conn_info.clone()); - SvcFuture { inner: self.inner.call(req), span, @@ -823,15 +828,17 @@ impl fmt::Debug for Svc { } } -struct MakeSvc { +struct MakeSvc { concurrency_limit: Option, timeout: Option, inner: S, trace_interceptor: Option, + _io: PhantomData IO>, } -impl Service<&ServerIo> for MakeSvc +impl Service<&ServerIo> for MakeSvc where + IO: Connected, S: Service, Response = Response> + Clone + Send + 'static, S::Future: Send + 'static, S::Error: Into + Send, @@ -846,11 +853,8 @@ where Ok(()).into() } - fn call(&mut self, io: &ServerIo) -> Self::Future { - let conn_info = crate::request::ConnectionInfo { - remote_addr: io.remote_addr(), - peer_certs: io.peer_certs().map(Arc::new), - }; + fn call(&mut self, io: &ServerIo) -> Self::Future { + let conn_info = io.connect_info(); let svc = self.inner.clone(); let concurrency_limit = self.concurrency_limit; @@ -863,13 +867,35 @@ where .layer_fn(|s| GrpcTimeout::new(s, timeout)) .service(svc); - let svc = Svc { - inner: svc, - trace_interceptor, - conn_info, - }; - - let svc = BoxService::new(svc); + let svc = ServiceBuilder::new() + .layer(BoxService::layer()) + .map_request(move |mut request: Request| { + match &conn_info { + tower::util::Either::A(inner) => { + request.extensions_mut().insert(inner.clone()); + } + tower::util::Either::B(inner) => { + #[cfg(feature = "tls")] + { + request.extensions_mut().insert(inner.clone()); + request.extensions_mut().insert(inner.get_ref().clone()); + } + + #[cfg(not(feature = "tls"))] + { + // just a type check to make sure we didn't forget to + // insert this into the extensions + let _: &() = inner; + } + } + } + + request + }) + .service(Svc { + inner: svc, + trace_interceptor, + }); future::ready(Ok(svc)) } diff --git a/tonic/src/transport/service/io.rs b/tonic/src/transport/service/io.rs index 761c8ece9..0419336a7 100644 --- a/tonic/src/transport/service/io.rs +++ b/tonic/src/transport/service/io.rs @@ -1,10 +1,11 @@ -use crate::transport::{server::Connected, Certificate}; +use crate::transport::server::Connected; use hyper::client::connect::{Connected as HyperConnected, Connection}; use std::io; -use std::net::SocketAddr; use std::pin::Pin; use std::task::{Context, Poll}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +#[cfg(feature = "tls")] +use tokio_rustls::server::TlsStream; pub(in crate::transport) trait Io: AsyncRead + AsyncWrite + Send + 'static @@ -27,7 +28,16 @@ impl Connection for BoxedIo { } } -impl Connected for BoxedIo {} +impl Connected for BoxedIo { + type ConnectInfo = NoneConnectInfo; + + fn connect_info(&self) -> Self::ConnectInfo { + NoneConnectInfo + } +} + +#[derive(Copy, Clone)] +pub(crate) struct NoneConnectInfo; impl AsyncRead for BoxedIo { fn poll_read( @@ -57,52 +67,100 @@ impl AsyncWrite for BoxedIo { } } -pub(in crate::transport) trait ConnectedIo: Io + Connected {} +pub(crate) enum ServerIo { + Io(IO), + #[cfg(feature = "tls")] + TlsIo(TlsStream), +} + +use tower::util::Either; -impl ConnectedIo for T where T: Io + Connected {} +#[cfg(feature = "tls")] +type ServerIoConnectInfo = + Either<::ConnectInfo, as Connected>::ConnectInfo>; -pub(crate) struct ServerIo(Pin>); +#[cfg(not(feature = "tls"))] +type ServerIoConnectInfo = Either<::ConnectInfo, ()>; -impl ServerIo { - pub(in crate::transport) fn new(io: I) -> Self { - ServerIo(Box::pin(io)) +impl ServerIo { + pub(in crate::transport) fn new_io(io: IO) -> Self { + Self::Io(io) } -} -impl Connected for ServerIo { - fn remote_addr(&self) -> Option { - (&*self.0).remote_addr() + #[cfg(feature = "tls")] + pub(in crate::transport) fn new_tls_io(io: TlsStream) -> Self { + Self::TlsIo(io) } - fn peer_certs(&self) -> Option> { - (&self.0).peer_certs() + #[cfg(feature = "tls")] + pub(in crate::transport) fn connect_info(&self) -> ServerIoConnectInfo + where + IO: Connected, + TlsStream: Connected, + { + match self { + Self::Io(io) => Either::A(io.connect_info()), + Self::TlsIo(io) => Either::B(io.connect_info()), + } + } + + #[cfg(not(feature = "tls"))] + pub(in crate::transport) fn connect_info(&self) -> ServerIoConnectInfo + where + IO: Connected, + { + match self { + Self::Io(io) => Either::A(io.connect_info()), + } } } -impl AsyncRead for ServerIo { +impl AsyncRead for ServerIo +where + IO: AsyncWrite + AsyncRead + Unpin, +{ fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - Pin::new(&mut self.0).poll_read(cx, buf) + match &mut *self { + Self::Io(io) => Pin::new(io).poll_read(cx, buf), + #[cfg(feature = "tls")] + Self::TlsIo(io) => Pin::new(io).poll_read(cx, buf), + } } } -impl AsyncWrite for ServerIo { +impl AsyncWrite for ServerIo +where + IO: AsyncWrite + AsyncRead + Unpin, +{ fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - Pin::new(&mut self.0).poll_write(cx, buf) + match &mut *self { + Self::Io(io) => Pin::new(io).poll_write(cx, buf), + #[cfg(feature = "tls")] + Self::TlsIo(io) => Pin::new(io).poll_write(cx, buf), + } } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_flush(cx) + match &mut *self { + Self::Io(io) => Pin::new(io).poll_flush(cx), + #[cfg(feature = "tls")] + Self::TlsIo(io) => Pin::new(io).poll_flush(cx), + } } fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_shutdown(cx) + match &mut *self { + Self::Io(io) => Pin::new(io).poll_shutdown(cx), + #[cfg(feature = "tls")] + Self::TlsIo(io) => Pin::new(io).poll_shutdown(cx), + } } }