Skip to content

Commit

Permalink
Implement outbound drain
Browse files Browse the repository at this point in the history
Signed-off-by: Benjamin Leggett <benjamin.leggett@solo.io>
  • Loading branch information
bleggett committed Apr 9, 2024
1 parent 2511586 commit dd34c38
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 26 deletions.
2 changes: 1 addition & 1 deletion src/proxy/inbound.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use std::fmt;
use std::fmt::{Display, Formatter};
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::time::Instant;
use std::time::{Instant};

use bytes::Bytes;
use drain::Watch;
Expand Down
7 changes: 6 additions & 1 deletion src/proxy/inbound_passthrough.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,13 @@ impl InboundPassthrough {
}

pub(super) async fn run(self) {
let (out_drain_signal, out_drain) = drain::channel();
let accept = async move {
loop {
// Asynchronously wait for an inbound socket.
let socket = self.listener.accept().await;
let pi = self.pi.clone();
let inner_drain = out_drain.clone();

let connection_manager = self.pi.connection_manager.clone();
match socket {
Expand All @@ -77,6 +79,7 @@ impl InboundPassthrough {
socket::to_canonical(remote),
stream,
connection_manager,
inner_drain,
)
.await
{
Expand All @@ -99,6 +102,7 @@ impl InboundPassthrough {
tokio::select! {
res = accept => { res }
_ = self.drain.signaled() => {
out_drain_signal.drain().await;
info!("inbound passthrough drained");
}
}
Expand All @@ -109,6 +113,7 @@ impl InboundPassthrough {
source: SocketAddr,
mut inbound: TcpStream,
connection_manager: ConnectionManager,
outbound_conn_drain: Watch,
) -> Result<(), Error> {
let orig = socket::orig_dst_addr_or_default(&inbound);
// Check if it is a recursive call when proxy mode is Node.
Expand Down Expand Up @@ -138,7 +143,7 @@ impl InboundPassthrough {
// Spoofing the source IP only works when either the destination or the source is on our node.
// In this case, the source and the destination might both be remote, so we need to disable it.
oc.pi.cfg.enable_original_source = Some(false);
return oc.proxy_to(inbound, source, orig, false).await;
return oc.proxy_to(inbound, source, orig, false, outbound_conn_drain).await;
}

// We enforce RBAC only for non-hairpin cases. This is because we may not be able to properly
Expand Down
52 changes: 42 additions & 10 deletions src/proxy/outbound.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use drain::Watch;
use http_body_util::Empty;
use hyper::header::FORWARDED;
use hyper::StatusCode;
use hyper::client::conn::http2;

use tokio::net::{TcpListener, TcpStream};

Expand Down Expand Up @@ -74,11 +75,19 @@ impl Outbound {
}

pub(super) async fn run(self) {
// Since we are spawning autonomous tasks to handle outbound connections for a single workload,
// we can have situations where the workload is deleted, but a task is still "stuck"
// waiting for a server response stream on an HTTP/2 connection (or a similar pending operation).
//
// So use a drain to terminate tasks that might be stuck sending.
let (sub_drain_signal, sub_drain) = drain::channel();
let accept = async move {
loop {
// Asynchronously wait for an inbound socket.
let socket = self.listener.accept().await;
let start_outbound_instant = Instant::now();
let outbound_drain = sub_drain.clone();
let outer_conn_drain = sub_drain.clone();
match socket {
Ok((stream, _remote)) => {
let mut oc = OutboundConnection {
Expand All @@ -88,11 +97,20 @@ impl Outbound {
let span = info_span!("outbound", id=%oc.id);
tokio::spawn(
(async move {
let res = oc.proxy(stream).await;
match res {
Ok(_) => info!(dur=?start_outbound_instant.elapsed(), "complete"),
Err(e) => warn!(dur=?start_outbound_instant.elapsed(), err=%e, "failed")
};
debug!(dur=?start_outbound_instant.elapsed(), id=%oc.id, "BML: outbound spawn START");
// Since this task is spawned, race it against the drain signal so that it is guaranteed to terminate
tokio::select! {
_ = outbound_drain.signaled() => {
debug!("outbound drain signaled");
}
res = oc.proxy(stream, outer_conn_drain.clone()) => {
match res {
Ok(_) => info!(dur=?start_outbound_instant.elapsed(), "complete"),
Err(e) => warn!(dur=?start_outbound_instant.elapsed(), err=%e, "failed")
};
}
}
debug!(dur=?start_outbound_instant.elapsed(), id=%oc.id, "BML: outbound spawn DONE");
})
.instrument(span),
);
Expand All @@ -113,6 +131,8 @@ impl Outbound {
tokio::select! {
res = accept => { res }
_ = self.drain.signaled() => {
debug!("outbound drained, dropping any outbound connections");
sub_drain_signal.drain().await;
info!("outbound drained");
}
}
Expand All @@ -125,10 +145,10 @@ pub(super) struct OutboundConnection {
}

impl OutboundConnection {
async fn proxy(&mut self, stream: TcpStream) -> Result<(), Error> {
async fn proxy(&mut self, stream: TcpStream, outer_conn_drain: Watch) -> Result<(), Error> {
let peer = socket::to_canonical(stream.peer_addr().expect("must receive peer addr"));
let orig_dst_addr = socket::orig_dst_addr_or_default(&stream);
self.proxy_to(stream, peer, orig_dst_addr, false).await
self.proxy_to(stream, peer, orig_dst_addr, false, outer_conn_drain).await
}

pub async fn proxy_to(
Expand All @@ -137,6 +157,7 @@ impl OutboundConnection {
remote_addr: SocketAddr,
orig_dst_addr: SocketAddr,
block_passthrough: bool,
outer_conn_drain: Watch,
) -> Result<(), Error> {
if self.pi.cfg.proxy_mode == ProxyMode::Shared
&& Some(orig_dst_addr.ip()) == self.pi.cfg.local_ip
Expand Down Expand Up @@ -273,7 +294,7 @@ impl OutboundConnection {
// in the pool.
let connect = async {
let mut builder =
hyper::client::conn::http2::Builder::new(hyper_util::TokioExecutor);
http2::Builder::new(hyper_util::TokioExecutor);
let builder = builder
.initial_stream_window_size(self.pi.cfg.window_size)
.max_frame_size(self.pi.cfg.frame_size)
Expand Down Expand Up @@ -302,8 +323,14 @@ impl OutboundConnection {
.map_err(Error::HttpHandshake)?;
// spawn a task to poll the connection and drive the HTTP state
tokio::spawn(async move {
if let Err(e) = connection.await {
error!("Error in HBONE connection handshake: {:?}", e);
tokio::select! {
_ = outer_conn_drain.signaled() => {
debug!("draining outer connection");
return
}
res = connection=> {
debug!("done with HBONE connection handshake: {:?}", res);
}
}
});
Ok(request_sender)
Expand All @@ -326,7 +353,12 @@ impl OutboundConnection {
.body(Empty::<Bytes>::new())
.expect("builder with known status code should not fail");

debug!("outbound - connection send START");
// There are scenarios (upstream hangup, etc.) where this "send" will simply get stuck.
// That is, stream processing deadlocks and `send_request` never resolves.
// Probably related to https://github.com/hyperium/hyper/issues/3623
let response = connection.send_request(request).await?;
debug!("outbound - connection send END");

let code = response.status();
if code != 200 {
Expand Down
20 changes: 13 additions & 7 deletions src/proxy/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ impl Connection {
&mut self,
req: Request<Empty<Bytes>>,
) -> impl Future<Output = hyper::Result<Response<Incoming>>> {
self.0 .0.send_request(req)
self.0.0.send_request(req)
}
}

Expand Down Expand Up @@ -146,12 +146,18 @@ impl Pool {
request_sender
}
// Connect won, checkout can just be dropped.
Either::Right((Err(err), checkout)) => match err {
// Connect won but we already had an in-flight connection, so use that.
Error::PoolAlreadyConnecting => checkout.await?,
// Some other connection error
err => return Err(err),
},
Either::Right((Err(err), checkout)) => {
debug!(
?key,
"connect won, but wait for existing pooled connection to establish"
);
match err {
// Connect won but we already had an in-flight connection, so use that.
Error::PoolAlreadyConnecting => checkout.await?,
// Some other connection error
err => return Err(err),
}
}
};

Ok(Connection(request_sender))
Expand Down
24 changes: 17 additions & 7 deletions src/proxy/socks5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,12 @@ impl Socks5 {
}

pub async fn run(self) {
let (out_drain_signal, out_drain) = drain::channel();
let accept = async move {
loop {
// Asynchronously wait for an inbound socket.
let socket = self.listener.accept().await;
let inner_drain = out_drain.clone();
match socket {
Ok((stream, remote)) => {
info!("accepted outbound connection from {}", remote);
Expand All @@ -69,7 +71,7 @@ impl Socks5 {
id: TraceParent::new(),
};
tokio::spawn(async move {
if let Err(err) = handle(oc, stream).await {
if let Err(err) = handle(oc, stream, inner_drain).await {
log::error!("handshake error: {}", err);
}
});
Expand All @@ -87,6 +89,7 @@ impl Socks5 {
tokio::select! {
res = accept => { res }
_ = self.drain.signaled() => {
out_drain_signal.drain().await;
info!("socks5 drained");
}
}
Expand All @@ -97,7 +100,7 @@ impl Socks5 {
// sufficient to integrate with common clients:
// - only unauthenticated requests
// - only CONNECT, with IPv4 or IPv6
async fn handle(mut oc: OutboundConnection, mut stream: TcpStream) -> Result<(), anyhow::Error> {
async fn handle(mut oc: OutboundConnection, mut stream: TcpStream, out_drain: Watch) -> Result<(), anyhow::Error> {
// Version(5), Number of auth methods
let mut version = [0u8; 2];
stream.read_exact(&mut version).await?;
Expand Down Expand Up @@ -190,11 +193,18 @@ async fn handle(mut oc: OutboundConnection, mut stream: TcpStream) -> Result<(),

info!("accepted connection from {remote_addr} to {host}");
tokio::spawn(async move {
let res = oc.proxy_to(stream, remote_addr, host, true).await;
match res {
Ok(_) => {}
Err(ref e) => warn!("outbound proxy failed: {}", e),
};
let outer_conn_drain = out_drain.clone();
tokio::select! {
_ = out_drain.signaled() => {
info!("socks drain signaled");
}
res = oc.proxy_to(stream, remote_addr, host, true, outer_conn_drain) => {
match res {
Ok(_) => {}
Err(ref e) => warn!("outbound proxy failed: {}", e),
};
}
}
});
Ok(())
}

0 comments on commit dd34c38

Please sign in to comment.