Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(transport): Fix TLS accept w/ peer certs #535

Merged
merged 4 commits into from
Jan 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions examples/src/tls_client_auth/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ pub struct EchoServer;
#[tonic::async_trait]
impl pb::echo_server::Echo for EchoServer {
async fn unary_echo(&self, request: Request<EchoRequest>) -> EchoResult<EchoResponse> {
if let Some(certs) = request.peer_certs() {
println!("Got {} peer certs!", certs.len());
}
let certs = request
.peer_certs()
.expect("Client did not send its certs!");

println!("Got {} peer certs!", certs.len());

let message = request.into_inner().message;
Ok(Response::new(EchoResponse { message }))
Expand Down
1 change: 1 addition & 0 deletions tonic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ transport = [
"tokio",
"tower",
"tracing-futures",
"tokio/macros"
]
tls = ["transport", "tokio-rustls"]
tls-roots = ["tls", "rustls-native-certs"]
Expand Down
30 changes: 12 additions & 18 deletions tonic/src/transport/server/conn.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
#[cfg(feature = "tls")]
use super::TlsStream;
use crate::transport::Certificate;
use hyper::server::conn::AddrStream;
use std::net::SocketAddr;
use tokio::net::TcpStream;
#[cfg(feature = "tls")]
use tokio_rustls::rustls::Session;
use tokio_rustls::{rustls::Session, server::TlsStream};

/// Trait that connected IO resources implement.
///
Expand Down Expand Up @@ -39,24 +37,20 @@ impl Connected for TcpStream {
#[cfg(feature = "tls")]
impl<T: Connected> Connected for TlsStream<T> {
fn remote_addr(&self) -> Option<SocketAddr> {
if let Some((inner, _)) = self.get_ref() {
inner.remote_addr()
} else {
None
}
let (inner, _) = self.get_ref();

inner.remote_addr()
}

fn peer_certs(&self) -> Option<Vec<Certificate>> {
if let Some((_, session)) = self.get_ref() {
if let Some(certs) = session.get_peer_certificates() {
let certs = certs
.into_iter()
.map(|c| Certificate::from_pem(c.0))
.collect();
Some(certs)
} else {
None
}
let (_, session) = self.get_ref();

if let Some(certs) = session.get_peer_certificates() {
let certs = certs
.into_iter()
.map(|c| Certificate::from_pem(c.0))
.collect();
Some(certs)
} else {
None
}
Expand Down
209 changes: 97 additions & 112 deletions tonic/src/transport/server/incoming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ use std::{
};
use tokio::io::{AsyncRead, AsyncWrite};

#[cfg_attr(not(feature = "tls"), allow(unused_variables))]
#[cfg(not(feature = "tls"))]
pub(crate) fn tcp_incoming<IO, IE>(
incoming: impl Stream<Item = Result<IO, IE>>,
server: Server,
_server: Server,
) -> impl Stream<Item = Result<ServerIo, crate::Error>>
where
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
Expand All @@ -26,145 +26,130 @@ where
async_stream::try_stream! {
futures_util::pin_mut!(incoming);


while let Some(stream) = incoming.try_next().await? {
#[cfg(feature = "tls")]
{
if let Some(tls) = &server.tls {
let io = tls.accept(stream);
yield ServerIo::new(io);
continue;
}
}

yield ServerIo::new(stream);
}
}
}

pub(crate) struct TcpIncoming {
inner: AddrIncoming,
}
#[cfg(feature = "tls")]
pub(crate) fn tcp_incoming<IO, IE>(
incoming: impl Stream<Item = Result<IO, IE>>,
server: Server,
) -> impl Stream<Item = Result<ServerIo, crate::Error>>
where
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
IE: Into<crate::Error>,
{
async_stream::try_stream! {
futures_util::pin_mut!(incoming);

impl TcpIncoming {
pub(crate) fn new(
addr: SocketAddr,
nodelay: bool,
keepalive: Option<Duration>,
) -> Result<Self, crate::Error> {
let mut inner = AddrIncoming::bind(&addr)?;
inner.set_nodelay(nodelay);
inner.set_keepalive(keepalive);
Ok(TcpIncoming { inner })
}
}
#[cfg(feature = "tls")]
let mut tasks = futures_util::stream::futures_unordered::FuturesUnordered::new();

impl Stream for TcpIncoming {
type Item = Result<AddrStream, std::io::Error>;
loop {
match select(&mut incoming, &mut tasks).await {
SelectOutput::Incoming(stream) => {
if let Some(tls) = &server.tls {
let tls = tls.clone();

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.inner).poll_accept(cx)
}
}
let accept = tokio::spawn(async move {
let io = tls.accept(stream).await?;
Ok(ServerIo::new(io))
});

// tokio_rustls::server::TlsStream doesn't expose constructor methods,
// so we have to TlsAcceptor::accept and handshake to have access to it
// TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first
#[cfg(feature = "tls")]
pub(crate) struct TlsStream<IO> {
state: State<IO>,
}
tasks.push(accept);
} else {
yield ServerIo::new(stream);
}
}

#[cfg(feature = "tls")]
enum State<IO> {
Handshaking(tokio_rustls::Accept<IO>),
Streaming(tokio_rustls::server::TlsStream<IO>),
}
SelectOutput::Io(io) => {
yield io;
}

#[cfg(feature = "tls")]
impl<IO> TlsStream<IO> {
pub(crate) fn new(accept: tokio_rustls::Accept<IO>) -> Self {
TlsStream {
state: State::Handshaking(accept),
}
}
SelectOutput::Err(e) => {
tracing::error!(message = "Accept loop error.", error = %e);
}

pub(crate) fn get_ref(&self) -> Option<(&IO, &tokio_rustls::rustls::ServerSession)> {
if let State::Streaming(tls) = &self.state {
Some(tls.get_ref())
} else {
None
SelectOutput::Done => {
break;
}
}
}
}
}

#[cfg(feature = "tls")]
impl<IO> AsyncRead for TlsStream<IO>
async fn select<IO, IE>(
incoming: &mut (impl Stream<Item = Result<IO, IE>> + Unpin),
tasks: &mut futures_util::stream::futures_unordered::FuturesUnordered<
tokio::task::JoinHandle<Result<ServerIo, crate::Error>>,
>,
) -> SelectOutput<IO>
where
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
IE: Into<crate::Error>,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
use std::future::Future;

let pin = self.get_mut();
match pin.state {
State::Handshaking(ref mut accept) => {
match futures_core::ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_read(cx, buf);
pin.state = State::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
}
use futures_util::StreamExt;

if tasks.is_empty() {
return match incoming.try_next().await {
Ok(Some(stream)) => SelectOutput::Incoming(stream),
Ok(None) => SelectOutput::Done,
Err(e) => SelectOutput::Err(e.into()),
};
}

tokio::select! {
stream = incoming.try_next() => {
match stream {
Ok(Some(stream)) => SelectOutput::Incoming(stream),
Ok(None) => SelectOutput::Done,
Err(e) => SelectOutput::Err(e.into()),
}
State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}

#[cfg(feature = "tls")]
impl<IO> AsyncWrite for TlsStream<IO>
where
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
use std::future::Future;

let pin = self.get_mut();
match pin.state {
State::Handshaking(ref mut accept) => {
match futures_core::ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let result = Pin::new(&mut stream).poll_write(cx, buf);
pin.state = State::Streaming(stream);
result
}
Err(err) => Poll::Ready(Err(err)),
}
accept = tasks.next() => {
match accept.expect("FuturesUnordered stream should never end") {
Ok(Ok(io)) => SelectOutput::Io(io),
Ok(Err(e)) => SelectOutput::Err(e),
Err(e) => SelectOutput::Err(e.into()),
}
State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf),
}
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match self.state {
State::Handshaking(_) => Poll::Ready(Ok(())),
State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx),
}
#[cfg(feature = "tls")]
enum SelectOutput<A> {
Incoming(A),
Io(ServerIo),
Err(crate::Error),
Done,
}

pub(crate) struct TcpIncoming {
inner: AddrIncoming,
}

impl TcpIncoming {
pub(crate) fn new(
addr: SocketAddr,
nodelay: bool,
keepalive: Option<Duration>,
) -> Result<Self, crate::Error> {
let mut inner = AddrIncoming::bind(&addr)?;
inner.set_nodelay(nodelay);
inner.set_keepalive(keepalive);
Ok(TcpIncoming { inner })
}
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match self.state {
State::Handshaking(_) => Poll::Ready(Ok(())),
State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx),
}
impl Stream for TcpIncoming {
type Item = Result<AddrStream, std::io::Error>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.inner).poll_accept(cx)
}
}
2 changes: 1 addition & 1 deletion tonic/src/transport/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use super::service::TlsAcceptor;
use incoming::TcpIncoming;

#[cfg(feature = "tls")]
pub(crate) use incoming::TlsStream;
pub(crate) use tokio_rustls::server::TlsStream;

#[cfg(feature = "tls")]
use crate::transport::Error;
Expand Down
6 changes: 2 additions & 4 deletions tonic/src/transport/service/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,12 @@ impl TlsAcceptor {
})
}

pub(crate) fn accept<IO>(&self, io: IO) -> TlsStream<IO>
pub(crate) async fn accept<IO>(&self, io: IO) -> Result<TlsStream<IO>, crate::Error>
where
IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
{
let acceptor = RustlsAcceptor::from(self.inner.clone());
let accept = acceptor.accept(io);

TlsStream::new(accept)
acceptor.accept(io).await.map_err(Into::into)
}
}

Expand Down