Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(core): amortize many ready messages into fewer, larger buffers #1423

Merged
merged 2 commits into from
Sep 1, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 163 additions & 33 deletions tonic/src/codec/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ use std::{
};
use tokio_stream::{Stream, StreamExt};

use fuse::Fuse;

pub(super) const BUFFER_SIZE: usize = 8 * 1024;
const YIELD_THRESHOLD: usize = 32 * 1024;

pub(crate) fn encode_server<T, U>(
encoder: T,
Expand All @@ -24,7 +27,7 @@ where
T: Encoder<Error = Status>,
U: Stream<Item = Result<T::Item, Status>>,
{
let stream = encode(
let stream = EncodedBytes::new(
encoder,
source,
compression_encoding,
Expand All @@ -45,7 +48,7 @@ where
T: Encoder<Error = Status>,
U: Stream<Item = T::Item>,
{
let stream = encode(
let stream = EncodedBytes::new(
encoder,
source.map(Ok),
compression_encoding,
Expand All @@ -55,44 +58,115 @@ where
EncodeBody::new_client(stream)
}

fn encode<T, U>(
mut encoder: T,
source: U,
/// Combinator for efficient encoding of messages into reasonably sized buffers.
/// EncodedBytes encodes ready messages from its delegate stream into a BytesMut,
/// splitting off and yielding a buffer when either:
/// * The delegate stream polls as not ready, or
/// * The encoded buffer surpasses YIELD_THRESHOLD.
#[pin_project(project = EncodedBytesProj)]
#[derive(Debug)]
pub(crate) struct EncodedBytes<T, U>
where
T: Encoder<Error = Status>,
U: Stream<Item = Result<T::Item, Status>>,
{
#[pin]
source: Fuse<U>,
LucioFranco marked this conversation as resolved.
Show resolved Hide resolved
encoder: T,
compression_encoding: Option<CompressionEncoding>,
compression_override: SingleMessageCompressionOverride,
max_message_size: Option<usize>,
) -> impl Stream<Item = Result<Bytes, Status>>
buf: BytesMut,
uncompression_buf: BytesMut,
}

impl<T, U> EncodedBytes<T, U>
where
T: Encoder<Error = Status>,
U: Stream<Item = Result<T::Item, Status>>,
{
let mut buf = BytesMut::with_capacity(BUFFER_SIZE);
fn new(
encoder: T,
source: U,
compression_encoding: Option<CompressionEncoding>,
compression_override: SingleMessageCompressionOverride,
max_message_size: Option<usize>,
) -> Self {
let buf = BytesMut::with_capacity(BUFFER_SIZE);

let compression_encoding = if compression_override == SingleMessageCompressionOverride::Disable
{
None
} else {
compression_encoding
};
let compression_encoding =
if compression_override == SingleMessageCompressionOverride::Disable {
None
} else {
compression_encoding
};

let mut uncompression_buf = if compression_encoding.is_some() {
BytesMut::with_capacity(BUFFER_SIZE)
} else {
BytesMut::new()
};
let uncompression_buf = if compression_encoding.is_some() {
BytesMut::with_capacity(BUFFER_SIZE)
} else {
BytesMut::new()
};

source.map(move |result| {
let item = result?;
return EncodedBytes {
source: Fuse::new(source),
encoder,
compression_encoding,
max_message_size,
buf,
uncompression_buf,
};
}
}

encode_item(
&mut encoder,
&mut buf,
&mut uncompression_buf,
impl<T, U> Stream for EncodedBytes<T, U>
where
T: Encoder<Error = Status>,
U: Stream<Item = Result<T::Item, Status>>,
{
type Item = Result<Bytes, Status>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let EncodedBytesProj {
mut source,
encoder,
compression_encoding,
max_message_size,
item,
)
})
buf,
uncompression_buf,
} = self.project();

loop {
match source.as_mut().poll_next(cx) {
Poll::Pending if buf.is_empty() => {
return Poll::Pending;
}
Poll::Ready(None) if buf.is_empty() => {
return Poll::Ready(None);
}
Poll::Pending | Poll::Ready(None) => {
return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
}
Poll::Ready(Some(Ok(item))) => {
if let Err(status) = encode_item(
encoder,
buf,
uncompression_buf,
*compression_encoding,
*max_message_size,
item,
) {
return Poll::Ready(Some(Err(status)));
}

if buf.len() >= YIELD_THRESHOLD {
return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
}
}
Poll::Ready(Some(Err(status))) => {
return Poll::Ready(Some(Err(status)));
}
}
}
}
}

fn encode_item<T>(
Expand All @@ -102,10 +176,12 @@ fn encode_item<T>(
compression_encoding: Option<CompressionEncoding>,
max_message_size: Option<usize>,
item: T::Item,
) -> Result<Bytes, Status>
) -> Result<(), Status>
where
T: Encoder<Error = Status>,
{
let offset = buf.len();

buf.reserve(HEADER_SIZE);
unsafe {
buf.advance_mut(HEADER_SIZE);
Expand All @@ -129,14 +205,14 @@ where
}

// now that we know length, we can write the header
finish_encoding(compression_encoding, max_message_size, buf)
finish_encoding(compression_encoding, max_message_size, &mut buf[offset..])
}

fn finish_encoding(
compression_encoding: Option<CompressionEncoding>,
max_message_size: Option<usize>,
buf: &mut BytesMut,
) -> Result<Bytes, Status> {
buf: &mut [u8],
) -> Result<(), Status> {
let len = buf.len() - HEADER_SIZE;
let limit = max_message_size.unwrap_or(DEFAULT_MAX_SEND_MESSAGE_SIZE);
if len > limit {
Expand All @@ -160,7 +236,7 @@ fn finish_encoding(
buf.put_u32(len as u32);
}

Ok(buf.split_to(len + HEADER_SIZE).freeze())
Ok(())
}

#[derive(Debug)]
Expand Down Expand Up @@ -269,3 +345,57 @@ where
Poll::Ready(self.project().state.trailers())
}
}

mod fuse {
use std::{
pin::Pin,
task::{ready, Context, Poll},
};

use tokio_stream::Stream;

/// Stream for the [`fuse`](super::StreamExt::fuse) method.
#[derive(Debug)]
#[pin_project::pin_project]
#[must_use = "streams do nothing unless polled"]
pub(crate) struct Fuse<St> {
#[pin]
stream: St,
done: bool,
}

impl<St> Fuse<St> {
pub(crate) fn new(stream: St) -> Self {
Self {
stream,
done: false,
}
}
}

impl<S: Stream> Stream for Fuse<S> {
type Item = S::Item;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<S::Item>> {
let this = self.project();

if *this.done {
return Poll::Ready(None);
}

let item = ready!(this.stream.poll_next(cx));
if item.is_none() {
*this.done = true;
}
Poll::Ready(item)
}

fn size_hint(&self) -> (usize, Option<usize>) {
if self.done {
(0, Some(0))
} else {
self.stream.size_hint()
}
}
}
}