performance: make copy buffer sizes dynamic (#1024)
Today, we give each buffer 16k. There are 3 per connection. This leads to a lot
of per-connection memory usage.

This PR changes each buffer to start at only 1k and scale up dynamically.

Currently, the algorithm to scale up is very simple: once we have sent 128k
total on the connection, we resize the buffer from 1k to 16k. This could very
likely be improved.
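
In rough terms, the trigger looks like this (a sketch; the names here are
illustrative, not the exact ones in the diff below):

    // Grow the buffer once, the first time cumulative bytes cross the threshold.
    if copied_before < RESIZE_THRESHOLD && RESIZE_THRESHOLD <= copied_after {
        reader.resize(); // 1k -> 16k
    }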

On a test load opening 10k connections:
Before: 90MB
After: 38MB
howardjohn authored May 6, 2024
1 parent f3ad7d8 commit 89fa6f2
Showing 7 changed files with 279 additions and 146 deletions.
263 changes: 263 additions & 0 deletions src/copy.rs
@@ -0,0 +1,263 @@
// Copyright Istio Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::proxy::ConnectionResult;
use pin_project_lite::pin_project;
use std::cmp;
use std::future::Future;
use std::io::IoSlice;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use tokio::io;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::trace;

// BufferedSplitter is a trait to expose splitting an IO object into a buffered reader and a writer
pub trait BufferedSplitter: Unpin {
    type R: ResizeBufRead + Unpin;
    type W: AsyncWrite + Unpin;
    fn split_into_buffered_reader(self) -> (Self::R, Self::W);
}

// Generic BufferedSplitter for anything that can Read/Write.
impl<I> BufferedSplitter for I
where
    I: AsyncRead + AsyncWrite + Unpin,
{
    type R = BufReader<io::ReadHalf<I>>;
    type W = io::WriteHalf<I>;
    fn split_into_buffered_reader(self) -> (Self::R, Self::W) {
        let (rh, wh) = tokio::io::split(self);
        let rb = BufReader::new(rh);
        (rb, wh)
    }
}

// ResizeBufRead is like AsyncBufRead, but allows triggering a resize.
pub trait ResizeBufRead {
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>>;
    fn consume(self: Pin<&mut Self>, amt: usize);
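    // Called at most once per connection, when cumulative traffic crosses RESIZE_THRESHOLD.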
    fn resize(self: Pin<&mut Self>);
}

// Initially we create a 1k buffer for each connection. Note there are currently 3 buffers per connection:
// Outbound: downstream to app (the upstream HBONE side is optimized to avoid a buffer).
// Inbound: downstream HBONE, and upstream to app. Downstream HBONE could be optimized, but is not yet.
const INITIAL_BUFFER_SIZE: usize = 1024;
// We increase up to 16k for high traffic connections.
// TLS record size max is 16k. But we also have an H2 frame header, so leave a bit of room for that.
const LARGE_BUFFER_SIZE: usize = 16_384 - 64;
// After 128k of data we will trigger a resize from INITIAL to LARGE
// Loosely inspired by https://github.com/golang/go/blame/5122a6796ef98e3453c994c95abd640596540bea/src/crypto/tls/conn.go#L873
const RESIZE_THRESHOLD: u64 = 128 * 1024;

pub async fn copy_bidirectional<A, B>(
    downstream: A,
    upstream: B,
    stats: &ConnectionResult,
) -> Result<(), crate::proxy::Error>
where
    A: BufferedSplitter,
    B: BufferedSplitter,
{
    use tokio::io::AsyncWriteExt;
    let (mut rd, mut wd) = downstream.split_into_buffered_reader();
    let (mut ru, mut wu) = upstream.split_into_buffered_reader();

    let (mut sent, mut received): (u64, u64) = (0, 0);

    let downstream_to_upstream = async {
        let res = copy_buf(&mut rd, &mut wu, stats, false).await;
        trace!(?res, "send");
        sent = res?;
        wu.shutdown().await
    };

    let upstream_to_downstream = async {
        let res = copy_buf(&mut ru, &mut wd, stats, true).await;
        trace!(?res, "receive");
        received = res?;
        wd.shutdown().await
    };

    tokio::try_join!(downstream_to_upstream, upstream_to_downstream)?;

    trace!(sent, received, "copy complete");
    Ok(())
}
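
// Example usage (illustrative only, not part of this commit): any pair of AsyncRead + AsyncWrite
// streams can be passed directly, via the generic BufferedSplitter impl above:
//     copy_bidirectional(downstream_tcp, upstream_hbone, &connection_result).await?;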

// CopyBuf is a fork of Tokio's struct of the same name, with additional support for resizing and metrics reporting.
#[must_use = "futures do nothing unless you `.await` or poll them"]
struct CopyBuf<'a, R: ?Sized, W: ?Sized> {
    send: bool,
    reader: &'a mut R,
    writer: &'a mut W,
    metrics: &'a ConnectionResult,
    amt: u64,
}

async fn copy_buf<'a, R, W>(
    reader: &'a mut R,
    writer: &'a mut W,
    metrics: &ConnectionResult,
    is_send: bool,
) -> std::io::Result<u64>
where
    R: ResizeBufRead + Unpin + ?Sized,
    W: tokio::io::AsyncWrite + Unpin + ?Sized,
{
    CopyBuf {
        send: is_send,
        reader,
        writer,
        metrics,
        amt: 0,
    }
    .await
}

impl<R, W> Future for CopyBuf<'_, R, W>
where
    R: ResizeBufRead + Unpin + ?Sized,
    W: AsyncWrite + Unpin + ?Sized,
{
    type Output = std::io::Result<u64>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        loop {
            let me = &mut *self;
            let buffer = ready!(Pin::new(&mut *me.reader).poll_fill_buf(cx))?;
            if buffer.is_empty() {
                ready!(Pin::new(&mut self.writer).poll_flush(cx))?;
                return Poll::Ready(Ok(self.amt));
            }

            let i = ready!(Pin::new(&mut *me.writer).poll_write(cx, buffer))?;
            if i == 0 {
                return Poll::Ready(Err(std::io::ErrorKind::WriteZero.into()));
            }
            if me.send {
                me.metrics.increment_send(i as u64);
            } else {
                me.metrics.increment_recv(i as u64);
            }
            let old = self.amt;
            self.amt += i as u64;

            // If we were below the resize threshold before but are now above it, trigger the buffer to resize
            if old < RESIZE_THRESHOLD && RESIZE_THRESHOLD <= self.amt {
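                // resize() runs before consume(), so any bytes still buffered from this read are carried into the larger buffer.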
                Pin::new(&mut *self.reader).resize();
            }
            Pin::new(&mut *self.reader).consume(i);
        }
    }
}

// BufReader is a fork of Tokio's type with resize support
pin_project! {
    pub struct BufReader<R> {
        #[pin]
        inner: R,
        buf: Box<[u8]>,
        pos: usize,
        cap: usize,
    }
}

impl<R: AsyncRead> BufReader<R> {
    /// Creates a new `BufReader` with a default buffer capacity. The default is currently INITIAL_BUFFER_SIZE.
    pub fn new(inner: R) -> Self {
        let buffer = vec![0; INITIAL_BUFFER_SIZE];
        Self {
            inner,
            buf: buffer.into_boxed_slice(),
            pos: 0,
            cap: 0,
        }
    }

    fn get_ref(&self) -> &R {
        &self.inner
    }

    fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
        self.project().inner
    }
}

impl<R: AsyncRead> ResizeBufRead for BufReader<R> {
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
        let me = self.project();

        // If we've reached the end of our internal buffer then we need to fetch
        // some more data from the underlying reader.
        // Branch using `>=` instead of the more correct `==`
        // to tell the compiler that the pos..cap slice is always valid.
        if *me.pos >= *me.cap {
            debug_assert!(*me.pos == *me.cap);
            let mut buf = tokio::io::ReadBuf::new(me.buf);
            ready!(me.inner.poll_read(cx, &mut buf))?;
            *me.cap = buf.filled().len();
            *me.pos = 0;
        }
        Poll::Ready(Ok(&me.buf[*me.pos..*me.cap]))
    }

    fn consume(self: Pin<&mut Self>, amt: usize) {
        let me = self.project();
        *me.pos = cmp::min(*me.pos + amt, *me.cap);
    }

    fn resize(self: Pin<&mut Self>) {
        let me = self.project();
        // If this assert fails, resize() was somehow called twice.
        debug_assert_eq!(me.buf.len(), INITIAL_BUFFER_SIZE);
        // Make a new buffer of the large size, and swap it into place
        let mut now = vec![0u8; LARGE_BUFFER_SIZE].into_boxed_slice();
        std::mem::swap(me.buf, &mut now);
        // Now copy over any data from the old buffer.
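        // The unread region (pos..cap) lands at the same offsets, so pos and cap remain valid.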
        me.buf[0..now.len()].copy_from_slice(&now);
        trace!("resized buffer to {}", LARGE_BUFFER_SIZE)
    }
}

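// Writes bypass the read buffer and go straight to the inner stream, so a BufReader wrapping a
// full duplex stream (R: AsyncRead + AsyncWrite) can still be written to directly.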
impl<R: AsyncRead + AsyncWrite> AsyncWrite for BufReader<R> {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        self.get_pin_mut().poll_write(cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.get_pin_mut().poll_flush(cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.get_pin_mut().poll_shutdown(cx)
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        self.get_pin_mut().poll_write_vectored(cx, bufs)
    }

    fn is_write_vectored(&self) -> bool {
        self.get_ref().is_write_vectored()
    }
}
1 change: 1 addition & 0 deletions src/lib.rs
@@ -18,6 +18,7 @@ pub mod assertions;
pub mod baggage;
pub mod cert_fetcher;
pub mod config;
pub mod copy;
pub mod dns;
pub mod hyper_util;
pub mod identity;
18 changes: 6 additions & 12 deletions src/proxy/h2_client.rs
@@ -16,8 +16,8 @@
// * https://github.com/cloudflare/pingora/blob/main/pingora-core/src/protocols/http/v2/client.rs
// * https://github.com/hyperium/hyper/blob/master/src/proto/h2/client.rs

use crate::config;
use crate::proxy::Error;
use crate::{config, copy};

use bytes::Buf;
use bytes::Bytes;
@@ -32,7 +32,7 @@ use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;

use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio::sync::oneshot;
use tokio::sync::watch::Receiver;
@@ -59,7 +59,7 @@ pub struct H2StreamWriteHalf {
    active_count: Arc<AtomicU16>,
}

impl crate::socket::BufferedSplitter for H2Stream {
impl crate::copy::BufferedSplitter for H2Stream {
    type R = H2StreamReadHalf;
    type W = H2StreamWriteHalf;
    fn split_into_buffered_reader(self) -> (H2StreamReadHalf, H2StreamWriteHalf) {
@@ -111,7 +111,7 @@ impl Drop for H2StreamWriteHalf {
}
}

impl AsyncBufRead for H2StreamReadHalf {
impl copy::ResizeBufRead for H2StreamReadHalf {
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
        const EOF: Poll<std::io::Result<&[u8]>> = Poll::Ready(Ok(&[]));
        let this = self.get_mut();
@@ -150,15 +150,9 @@ impl AsyncBufRead for H2StreamReadHalf {
    fn consume(mut self: Pin<&mut Self>, amt: usize) {
        self.as_mut().buf.advance(amt)
    }
}

impl AsyncRead for H2StreamReadHalf {
    fn poll_read(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        _read_buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        panic!("H2StreamReadHalf should never be read directly; use poll_fill_buf");
    fn resize(self: Pin<&mut Self>) {
        // NOP, we don't need to resize as we are abstracting the h2 buffer
    }
}

6 changes: 3 additions & 3 deletions src/proxy/inbound.rs
@@ -47,7 +47,7 @@ use crate::socket::to_canonical;
use crate::state::service::Service;
use crate::state::workload::address::Address;
use crate::state::workload::application_tunnel::Protocol as AppProtocol;
use crate::{assertions, proxy, socket, strng, tls};
use crate::{assertions, copy, proxy, strng, tls};

use crate::state::workload::{self, NetworkAddress, Workload};
use crate::state::DemandProxyState;
@@ -372,7 +372,7 @@ impl Inbound {
hyper::upgrade::on(req)
.map_err(Error::NoUpgrade)
.and_then(|upgraded| async move {
socket::copy_bidirectional(
copy::copy_bidirectional(
&mut ::hyper_util::rt::TokioIo::new(upgraded),
&mut stream,
&result_tracker,
@@ -388,7 +388,7 @@ impl Inbound {
super::write_proxy_protocol(&mut stream, (src, dst), src_id)
.instrument(trace_span!("proxy protocol"))
.await?;
socket::copy_bidirectional(
copy::copy_bidirectional(
&mut ::hyper_util::rt::TokioIo::new(upgraded),
&mut stream,
&result_tracker,
4 changes: 2 additions & 2 deletions src/proxy/inbound_passthrough.rs
@@ -28,7 +28,7 @@ use crate::proxy::metrics::Reporter;
use crate::proxy::Error;
use crate::proxy::{metrics, util, ProxyInputs};
use crate::state::workload::NetworkAddress;
use crate::{assertions, rbac, strng};
use crate::{assertions, copy, rbac, strng};
use crate::{proxy, socket};

pub(super) struct InboundPassthrough {
@@ -237,7 +237,7 @@ impl InboundPassthrough {
.map_err(Error::ConnectionFailed)?;

trace!(%source_addr, destination=%dest_addr, component="inbound plaintext", "connected");
socket::copy_bidirectional(&mut inbound_stream, &mut outbound, &result_tracker).await
copy::copy_bidirectional(&mut inbound_stream, &mut outbound, &result_tracker).await
};

let res = conn_guard.handle_connection(send).await;
6 changes: 3 additions & 3 deletions src/proxy/outbound.rs
@@ -38,7 +38,7 @@ use crate::state::service::ServiceDescription;
use crate::state::workload::gatewayaddress::Destination;
use crate::state::workload::{address::Address, NetworkAddress, Protocol, Workload};
use crate::strng::Strng;
use crate::{assertions, proxy, socket, strng};
use crate::{assertions, copy, proxy, socket, strng};

pub struct Outbound {
pi: ProxyInputs,
@@ -279,7 +279,7 @@ impl OutboundConnection {

let upgraded = Box::pin(self.build_hbone_request(remote_addr, &req)).await?;

socket::copy_bidirectional(stream, upgraded, connection_stats).await
copy::copy_bidirectional(stream, upgraded, connection_stats).await
}

async fn build_hbone_request(
@@ -353,7 +353,7 @@
super::freebind_connect(local, req.gateway, self.pi.socket_factory.as_ref()).await?;

// Proxying data between downstream and upstream
socket::copy_bidirectional(stream, &mut outbound, connection_stats).await
copy::copy_bidirectional(stream, &mut outbound, connection_stats).await
}

fn conn_metrics_from_request(req: &Request) -> ConnectionOpen {