Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decouple tls::accept from TcpStream #853

Merged
merged 5 commits into from
Jan 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock
Original file line number Diff line number Diff line change
Expand Up @@ -1046,6 +1046,7 @@ dependencies = [
name = "linkerd-io"
version = "0.1.0"
dependencies = [
"async-trait",
"bytes",
"futures",
"linkerd-errno",
Expand Down
10 changes: 2 additions & 8 deletions linkerd/app/inbound/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,8 @@ impl Config {
Service = impl svc::Service<I, Response = (), Error = Error, Future = impl Send>,
> + Clone
where
I: tls::accept::Detectable
+ io::AsyncRead
+ io::AsyncWrite
+ io::PeerAddr
+ Debug
+ Send
+ Unpin
+ 'static,
I: io::AsyncRead + io::AsyncWrite + io::Peek + io::PeerAddr,
I: Debug + Send + Sync + Unpin + 'static,
C: svc::Service<TcpEndpoint> + Clone + Send + Sync + Unpin + 'static,
C::Response: io::AsyncRead + io::AsyncWrite + Send + Unpin + 'static,
C::Error: Into<Error>,
Expand Down
1 change: 1 addition & 0 deletions linkerd/io/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ General I/O primitives.
default = []

[dependencies]
async-trait = "0.1"
futures = "0.3.9"
bytes = "1"
linkerd-errno = { path = "../errno" }
Expand Down
40 changes: 29 additions & 11 deletions linkerd/io/src/either.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{AsyncRead, AsyncWrite, IoSlice, PeerAddr, Poll, ReadBuf, Result};
use crate as io;
use pin_project::pin_project;
use std::{pin::Pin, task::Context};

Expand All @@ -9,45 +9,63 @@ pub enum EitherIo<L, R> {
Right(#[pin] R),
}

impl<L: PeerAddr, R: PeerAddr> PeerAddr for EitherIo<L, R> {
#[async_trait::async_trait]
impl<L, R> io::Peek for EitherIo<L, R>
where
L: io::Peek + Send + Sync,
R: io::Peek + Send + Sync,
{
async fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
match self {
Self::Left(l) => l.peek(buf).await,
Self::Right(r) => r.peek(buf).await,
}
}
}

impl<L: io::PeerAddr, R: io::PeerAddr> io::PeerAddr for EitherIo<L, R> {
#[inline]
fn peer_addr(&self) -> Result<std::net::SocketAddr> {
fn peer_addr(&self) -> io::Result<std::net::SocketAddr> {
match self {
Self::Left(l) => l.peer_addr(),
Self::Right(r) => r.peer_addr(),
}
}
}

impl<L: AsyncRead, R: AsyncRead> AsyncRead for EitherIo<L, R> {
impl<L: io::AsyncRead, R: io::AsyncRead> io::AsyncRead for EitherIo<L, R> {
#[inline]
fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut ReadBuf<'_>) -> Poll<()> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut io::ReadBuf<'_>,
) -> io::Poll<()> {
match self.project() {
EitherIoProj::Left(l) => l.poll_read(cx, buf),
EitherIoProj::Right(r) => r.poll_read(cx, buf),
}
}
}

impl<L: AsyncWrite, R: AsyncWrite> AsyncWrite for EitherIo<L, R> {
impl<L: io::AsyncWrite, R: io::AsyncWrite> io::AsyncWrite for EitherIo<L, R> {
#[inline]
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> {
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> io::Poll<()> {
match self.project() {
EitherIoProj::Left(l) => l.poll_shutdown(cx),
EitherIoProj::Right(r) => r.poll_shutdown(cx),
}
}

#[inline]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<()> {
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> io::Poll<()> {
match self.project() {
EitherIoProj::Left(l) => l.poll_flush(cx),
EitherIoProj::Right(r) => r.poll_flush(cx),
}
}

#[inline]
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<usize> {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> io::Poll<usize> {
match self.project() {
EitherIoProj::Left(l) => l.poll_write(cx, buf),
EitherIoProj::Right(r) => r.poll_write(cx, buf),
Expand All @@ -58,8 +76,8 @@ impl<L: AsyncWrite, R: AsyncWrite> AsyncWrite for EitherIo<L, R> {
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &[IoSlice<'_>],
) -> Poll<usize> {
buf: &[io::IoSlice<'_>],
) -> io::Poll<usize> {
match self.project() {
EitherIoProj::Left(l) => l.poll_write_vectored(cx, buf),
EitherIoProj::Right(r) => r.poll_write_vectored(cx, buf),
Expand Down
24 changes: 24 additions & 0 deletions linkerd/io/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,30 @@ pub use tokio_util::io::{poll_read_buf, poll_write_buf};

pub type Poll<T> = std::task::Poll<Result<T>>;

// === Peek ===

#[async_trait::async_trait]
pub trait Peek {
/// Receives data on the socket from the remote address to which it is
/// connected, without removing that data from the queue. On success,
/// returns the number of bytes peeked. A return value of zero bytes does not
/// necessarily indicate that the underlying socket has closed.
///
/// Successive calls return the same data. This is accomplished by passing
/// `MSG_PEEK` as a flag to the underlying recv system call.
async fn peek(&self, buf: &mut [u8]) -> Result<usize>;
}

// Special-case a wrapper for TcpStream::peek.
#[async_trait::async_trait]
impl Peek for tokio::net::TcpStream {
async fn peek(&self, buf: &mut [u8]) -> Result<usize> {
tokio::net::TcpStream::peek(self, buf).await
}
}

// === PeerAddr ===

pub trait PeerAddr {
fn peer_addr(&self) -> Result<SocketAddr>;
}
Expand Down
63 changes: 39 additions & 24 deletions linkerd/io/src/prefixed.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
use crate::{IoSlice, PeerAddr, Poll};
use crate::{self as io};
use bytes::{Buf, Bytes};
use pin_project::pin_project;
use std::{cmp, io};
use std::{pin::Pin, task::Context};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, Result};
use std::{cmp, pin::Pin, task::Context};

/// A TcpStream where the initial reads will be served from `prefix`.
#[pin_project]
#[derive(Debug)]
pub struct PrefixedIo<S> {
pub struct PrefixedIo<I> {
prefix: Bytes,

#[pin]
io: S,
io: I,
}

impl<S> PrefixedIo<S> {
pub fn new(prefix: impl Into<Bytes>, io: S) -> Self {
impl<I> PrefixedIo<I> {
pub fn new(prefix: impl Into<Bytes>, io: I) -> Self {
let prefix = prefix.into();
Self { prefix, io }
}
Expand All @@ -26,21 +24,38 @@ impl<S> PrefixedIo<S> {
}
}

impl<S> From<S> for PrefixedIo<S> {
fn from(io: S) -> Self {
impl<I> From<I> for PrefixedIo<I> {
fn from(io: I) -> Self {
Self::new(Bytes::default(), io)
}
}

impl<S: PeerAddr> PeerAddr for PrefixedIo<S> {
#[async_trait::async_trait]
impl<I: Send + Sync> io::Peek for PrefixedIo<I> {
async fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
let sz = self.prefix.len().min(buf.len());
if sz == 0 {
return Ok(0);
}

(&mut buf[..sz]).clone_from_slice(&self.prefix[..sz]);
Ok(sz)
}
}

impl<I: io::PeerAddr> io::PeerAddr for PrefixedIo<I> {
#[inline]
fn peer_addr(&self) -> Result<std::net::SocketAddr> {
fn peer_addr(&self) -> io::Result<std::net::SocketAddr> {
self.io.peer_addr()
}
}

impl<S: AsyncRead> AsyncRead for PrefixedIo<S> {
fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<()> {
impl<I: io::AsyncRead> io::AsyncRead for PrefixedIo<I> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut io::ReadBuf<'_>,
) -> io::Poll<()> {
let this = self.project();
// Check the length only once, since looking as the length
// of a Bytes isn't as cheap as the length of a &[u8].
Expand All @@ -58,45 +73,45 @@ impl<S: AsyncRead> AsyncRead for PrefixedIo<S> {
if peeked_len == len {
*this.prefix = Bytes::new();
}
Poll::Ready(Ok(()))
io::Poll::Ready(Ok(()))
}
}
}

impl<S: io::Write> io::Write for PrefixedIo<S> {
impl<I: io::Write> io::Write for PrefixedIo<I> {
#[inline]
fn write(&mut self, buf: &[u8]) -> Result<usize> {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.io.write(buf)
}

#[inline]
fn flush(&mut self) -> Result<()> {
fn flush(&mut self) -> io::Result<()> {
self.io.flush()
}
}

impl<S: AsyncWrite> AsyncWrite for PrefixedIo<S> {
impl<I: io::AsyncWrite> io::AsyncWrite for PrefixedIo<I> {
#[inline]
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> {
self.project().io.poll_shutdown(cx)
}

#[inline]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> io::Poll<()> {
self.project().io.poll_flush(cx)
}

#[inline]
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<usize> {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> io::Poll<usize> {
self.project().io.poll_write(cx, buf)
}

#[inline]
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<usize> {
bufs: &[io::IoSlice<'_>],
) -> io::Poll<usize> {
self.project().io.poll_write_vectored(cx, bufs)
}

Expand Down
Loading