Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(core): amortize many ready messages into fewer, larger buffers #1423

Merged
merged 2 commits into from
Sep 1, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 110 additions & 39 deletions tonic/src/codec/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::compression::{compress, CompressionEncoding, SingleMessageCompression
use super::{EncodeBuf, Encoder, DEFAULT_MAX_SEND_MESSAGE_SIZE, HEADER_SIZE};
use crate::{Code, Status};
use bytes::{BufMut, Bytes, BytesMut};
use futures_util::{ready, StreamExt, TryStream, TryStreamExt};
use futures_util::{ready, stream::Fuse, StreamExt, TryStreamExt};
use http::HeaderMap;
use http_body::Body;
use pin_project::pin_project;
Expand All @@ -13,6 +13,7 @@ use std::{
use tokio_stream::Stream;

pub(super) const BUFFER_SIZE: usize = 8 * 1024;
const YIELD_THRESHOLD: usize = 32 * 1024;

pub(crate) fn encode_server<T, U>(
encoder: T,
Expand All @@ -25,15 +26,13 @@ where
T: Encoder<Error = Status>,
U: Stream<Item = Result<T::Item, Status>>,
{
let stream = encode(
let stream = EncodedBytes::new(
encoder,
source,
compression_encoding,
compression_override,
max_message_size,
)
.into_stream();

);
EncodeBody::new_server(stream)
}

Expand All @@ -47,55 +46,125 @@ where
T: Encoder<Error = Status>,
U: Stream<Item = T::Item>,
{
let stream = encode(
let stream = EncodedBytes::new(
encoder,
source.map(Ok),
compression_encoding,
SingleMessageCompressionOverride::default(),
max_message_size,
)
.into_stream();
);
EncodeBody::new_client(stream)
}

fn encode<T, U>(
mut encoder: T,
source: U,
/// Combinator for efficient encoding of messages into reasonably sized buffers.
/// EncodedBytes encodes ready messages from its delegate stream into a BytesMut,
/// splitting off and yielding a buffer when either:
/// * The delegate stream polls as not ready, or
/// * The encoded buffer surpasses YIELD_THRESHOLD.
#[pin_project(project = EncodedBytesProj)]
#[derive(Debug)]
pub(crate) struct EncodedBytes<T, U>
where
T: Encoder<Error = Status>,
U: Stream<Item = Result<T::Item, Status>>,
{
#[pin]
source: Fuse<U>,
LucioFranco marked this conversation as resolved.
Show resolved Hide resolved
encoder: T,
compression_encoding: Option<CompressionEncoding>,
compression_override: SingleMessageCompressionOverride,
max_message_size: Option<usize>,
) -> impl TryStream<Ok = Bytes, Error = Status>
buf: BytesMut,
uncompression_buf: BytesMut,
}

impl<T, U> EncodedBytes<T, U>
where
T: Encoder<Error = Status>,
U: Stream<Item = Result<T::Item, Status>>,
{
let mut buf = BytesMut::with_capacity(BUFFER_SIZE);
fn new(
encoder: T,
source: U,
compression_encoding: Option<CompressionEncoding>,
compression_override: SingleMessageCompressionOverride,
max_message_size: Option<usize>,
) -> Self {
let buf = BytesMut::with_capacity(BUFFER_SIZE);

let compression_encoding = if compression_override == SingleMessageCompressionOverride::Disable
{
None
} else {
compression_encoding
};
let compression_encoding =
if compression_override == SingleMessageCompressionOverride::Disable {
None
} else {
compression_encoding
};

let mut uncompression_buf = if compression_encoding.is_some() {
BytesMut::with_capacity(BUFFER_SIZE)
} else {
BytesMut::new()
};
let uncompression_buf = if compression_encoding.is_some() {
BytesMut::with_capacity(BUFFER_SIZE)
} else {
BytesMut::new()
};

return EncodedBytes {
source: source.fuse(),
encoder,
compression_encoding,
max_message_size,
buf,
uncompression_buf,
};
}
}

source.map(move |result| {
let item = result?;
impl<T, U> Stream for EncodedBytes<T, U>
where
T: Encoder<Error = Status>,
U: Stream<Item = Result<T::Item, Status>>,
{
type Item = Result<Bytes, Status>;

encode_item(
&mut encoder,
&mut buf,
&mut uncompression_buf,
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let EncodedBytesProj {
mut source,
encoder,
compression_encoding,
max_message_size,
item,
)
})
buf,
uncompression_buf,
} = self.project();

loop {
match source.as_mut().poll_next(cx) {
Poll::Pending if buf.is_empty() => {
return Poll::Pending;
}
Poll::Ready(None) if buf.is_empty() => {
return Poll::Ready(None);
}
Poll::Pending | Poll::Ready(None) => {
return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
}
Poll::Ready(Some(Ok(item))) => {
if let Err(status) = encode_item(
encoder,
buf,
uncompression_buf,
*compression_encoding,
*max_message_size,
item,
) {
return Poll::Ready(Some(Err(status)));
}

if buf.len() >= YIELD_THRESHOLD {
return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
}
}
Poll::Ready(Some(Err(status))) => {
return Poll::Ready(Some(Err(status)));
}
}
}
}
}

fn encode_item<T>(
Expand All @@ -105,10 +174,12 @@ fn encode_item<T>(
compression_encoding: Option<CompressionEncoding>,
max_message_size: Option<usize>,
item: T::Item,
) -> Result<Bytes, Status>
) -> Result<(), Status>
where
T: Encoder<Error = Status>,
{
let offset = buf.len();

buf.reserve(HEADER_SIZE);
unsafe {
buf.advance_mut(HEADER_SIZE);
Expand All @@ -132,14 +203,14 @@ where
}

// now that we know length, we can write the header
finish_encoding(compression_encoding, max_message_size, buf)
finish_encoding(compression_encoding, max_message_size, &mut buf[offset..])
}

fn finish_encoding(
compression_encoding: Option<CompressionEncoding>,
max_message_size: Option<usize>,
buf: &mut BytesMut,
) -> Result<Bytes, Status> {
buf: &mut [u8],
) -> Result<(), Status> {
let len = buf.len() - HEADER_SIZE;
let limit = max_message_size.unwrap_or(DEFAULT_MAX_SEND_MESSAGE_SIZE);
if len > limit {
Expand All @@ -163,7 +234,7 @@ fn finish_encoding(
buf.put_u32(len as u32);
}

Ok(buf.split_to(len + HEADER_SIZE).freeze())
Ok(())
}

#[derive(Debug)]
Expand Down