Skip to content

Commit 463a3ff

Browse files
authored
Merge pull request #3285 from spinframework/max-number-requests
Set a limit on max number of concurrent outbound http requests
2 parents 153fd34 + a472018 commit 463a3ff

File tree

5 files changed

+144
-32
lines changed

5 files changed

+144
-32
lines changed

crates/factor-outbound-http/src/lib.rs

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use spin_factors::{
2222
anyhow, ConfigureAppContext, Factor, FactorData, PrepareContext, RuntimeFactors,
2323
SelfInstanceBuilder,
2424
};
25+
use tokio::sync::Semaphore;
2526
use wasmtime_wasi_http::WasiHttpCtx;
2627

2728
pub use wasmtime_wasi_http::{
@@ -53,13 +54,15 @@ impl Factor for OutboundHttpFactor {
5354
&self,
5455
mut ctx: ConfigureAppContext<T, Self>,
5556
) -> anyhow::Result<Self::AppState> {
56-
let connection_pooling = ctx
57-
.take_runtime_config()
58-
.unwrap_or_default()
59-
.connection_pooling;
57+
let config = ctx.take_runtime_config().unwrap_or_default();
6058
Ok(AppState {
61-
wasi_http_clients: wasi::HttpClients::new(connection_pooling),
62-
connection_pooling,
59+
wasi_http_clients: wasi::HttpClients::new(config.connection_pooling_enabled),
60+
connection_pooling_enabled: config.connection_pooling_enabled,
61+
concurrent_outbound_connections_semaphore: config
62+
.max_concurrent_connections
63+
// Permit count is the max concurrent connections + 1.
64+
// i.e., 0 concurrent connections means 1 total connection.
65+
.map(|n| Arc::new(Semaphore::new(n + 1))),
6366
})
6467
}
6568

@@ -80,7 +83,11 @@ impl Factor for OutboundHttpFactor {
8083
request_interceptor: None,
8184
spin_http_client: None,
8285
wasi_http_clients: ctx.app_state().wasi_http_clients.clone(),
83-
connection_pooling: ctx.app_state().connection_pooling,
86+
connection_pooling_enabled: ctx.app_state().connection_pooling_enabled,
87+
concurrent_outbound_connections_semaphore: ctx
88+
.app_state()
89+
.concurrent_outbound_connections_semaphore
90+
.clone(),
8491
})
8592
}
8693
}
@@ -94,7 +101,7 @@ pub struct InstanceState {
94101
request_interceptor: Option<Arc<dyn OutboundHttpInterceptor>>,
95102
// Connection-pooling client for 'fermyon:spin/http' interface
96103
//
97-
// TODO: We could move this to `AppState` to like the
104+
// TODO: We could move this to `AppState` like the
98105
// `wasi:http/outgoing-handler` pool for consistency, although it's probably
99106
// not a high priority given that `fermyon:spin/http` is deprecated anyway.
100107
spin_http_client: Option<reqwest::Client>,
@@ -103,7 +110,10 @@ pub struct InstanceState {
103110
// This is a clone of `AppState::wasi_http_clients`, meaning it is shared
104111
// among all instances of the app.
105112
wasi_http_clients: wasi::HttpClients,
106-
connection_pooling: bool,
113+
/// Whether connection pooling is enabled for this instance.
114+
connection_pooling_enabled: bool,
115+
/// A semaphore to limit the number of concurrent outbound connections.
116+
concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
107117
}
108118

109119
impl InstanceState {
@@ -185,5 +195,8 @@ impl std::fmt::Display for SelfRequestOrigin {
185195
pub struct AppState {
186196
// Connection pooling clients for `wasi:http/outgoing-handler` interface
187197
wasi_http_clients: wasi::HttpClients,
188-
connection_pooling: bool,
198+
/// Whether connection pooling is enabled for this app.
199+
connection_pooling_enabled: bool,
200+
/// A semaphore to limit the number of concurrent outbound connections.
201+
concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
189202
}

crates/factor-outbound-http/src/runtime_config.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@ pub mod spin;
55
#[derive(Debug)]
66
pub struct RuntimeConfig {
77
/// If true, enable connection pooling and reuse.
8-
pub connection_pooling: bool,
8+
pub connection_pooling_enabled: bool,
9+
/// If set, limits the number of concurrent outbound connections.
10+
pub max_concurrent_connections: Option<usize>,
911
}
1012

1113
impl Default for RuntimeConfig {
1214
fn default() -> Self {
1315
Self {
14-
connection_pooling: true,
16+
connection_pooling_enabled: true,
17+
max_concurrent_connections: None,
1518
}
1619
}
1720
}

crates/factor-outbound-http/src/runtime_config/spin.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,17 @@ use spin_factors::runtime_config::toml::GetTomlValue;
66
/// Expects table to be in the format:
77
/// ```toml
88
/// [outbound_http]
9-
/// connection_pooling = true
9+
/// connection_pooling = true # optional, defaults to true
10+
/// max_concurrent_requests = 10 # optional, defaults to unlimited
1011
/// ```
1112
pub fn config_from_table(
1213
table: &impl GetTomlValue,
1314
) -> anyhow::Result<Option<super::RuntimeConfig>> {
1415
if let Some(outbound_http) = table.get("outbound_http") {
16+
let outbound_http_toml = outbound_http.clone().try_into::<OutboundHttpToml>()?;
1517
Ok(Some(super::RuntimeConfig {
16-
connection_pooling: outbound_http
17-
.clone()
18-
.try_into::<OutboundHttpToml>()?
19-
.connection_pooling,
18+
connection_pooling_enabled: outbound_http_toml.connection_pooling,
19+
max_concurrent_connections: outbound_http_toml.max_concurrent_requests,
2020
}))
2121
} else {
2222
Ok(None)
@@ -28,4 +28,6 @@ pub fn config_from_table(
2828
struct OutboundHttpToml {
2929
#[serde(default)]
3030
connection_pooling: bool,
31+
#[serde(default)]
32+
max_concurrent_requests: Option<usize>,
3133
}

crates/factor-outbound-http/src/spin.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ impl spin_http::Host for crate::InstanceState {
2121
if !req.params.is_empty() {
2222
tracing::warn!("HTTP params field is deprecated");
2323
}
24-
2524
let req_url = if !uri.starts_with('/') {
2625
// Absolute URI
2726
let is_allowed = self
@@ -92,13 +91,21 @@ impl spin_http::Host for crate::InstanceState {
9291
// in a single component execution
9392
let client = self.spin_http_client.get_or_insert_with(|| {
9493
let mut builder = reqwest::Client::builder();
95-
if !self.connection_pooling {
94+
if !self.connection_pooling_enabled {
9695
builder = builder.pool_max_idle_per_host(0);
9796
}
9897
builder.build().unwrap()
9998
});
10099

100+
// If we're limiting concurrent outbound requests, acquire a permit
101+
// Note: since we don't have access to the underlying connection, we can only
102+
// limit the number of concurrent requests, not connections.
103+
let permit = match &self.concurrent_outbound_connections_semaphore {
104+
Some(s) => s.acquire().await.ok(),
105+
None => None,
106+
};
101107
let resp = client.execute(req).await.map_err(log_reqwest_error)?;
108+
drop(permit);
102109

103110
tracing::trace!("Returning response from outbound request to {req_url}");
104111
span.record("http.response.status_code", resp.status().as_u16());

crates/factor-outbound-http/src/wasi.rs

Lines changed: 100 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ use spin_factors::{wasmtime::component::ResourceTable, RuntimeFactorsInstanceSta
2828
use tokio::{
2929
io::{AsyncRead, AsyncWrite, ReadBuf},
3030
net::TcpStream,
31+
sync::{OwnedSemaphorePermit, Semaphore},
3132
time::timeout,
3233
};
3334
use tokio_rustls::client::TlsStream;
@@ -90,6 +91,9 @@ impl p3::WasiHttpCtx for InstanceState {
9091
self_request_origin: self.self_request_origin.clone(),
9192
blocked_networks: self.blocked_networks.clone(),
9293
http_clients: self.wasi_http_clients.clone(),
94+
concurrent_outbound_connections_semaphore: self
95+
.concurrent_outbound_connections_semaphore
96+
.clone(),
9397
};
9498
let config = OutgoingRequestConfig {
9599
use_tls: request.uri().scheme() == Some(&Scheme::HTTPS),
@@ -282,6 +286,10 @@ impl WasiHttpView for WasiHttpImplInner<'_> {
282286
self_request_origin: self.state.self_request_origin.clone(),
283287
blocked_networks: self.state.blocked_networks.clone(),
284288
http_clients: self.state.wasi_http_clients.clone(),
289+
concurrent_outbound_connections_semaphore: self
290+
.state
291+
.concurrent_outbound_connections_semaphore
292+
.clone(),
285293
};
286294
Ok(HostFutureIncomingResponse::Pending(
287295
wasmtime_wasi::runtime::spawn(
@@ -307,6 +315,7 @@ struct RequestSender {
307315
self_request_origin: Option<SelfRequestOrigin>,
308316
request_interceptor: Option<Arc<dyn OutboundHttpInterceptor>>,
309317
http_clients: HttpClients,
318+
concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
310319
}
311320

312321
impl RequestSender {
@@ -454,6 +463,8 @@ impl RequestSender {
454463
connect_timeout,
455464
tls_client_config,
456465
override_connect_addr,
466+
concurrent_outbound_connections_semaphore: self
467+
.concurrent_outbound_connections_semaphore,
457468
},
458469
async move {
459470
if use_tls {
@@ -520,24 +531,38 @@ impl HttpClients {
520531
}
521532
}
522533

523-
// We must use task-local variables for these config options when using
524-
// `hyper_util::client::legacy::Client::request` because there's no way to plumb
525-
// them through as parameters. Moreover, if there's already a pooled connection
526-
// ready, we'll reuse that and ignore these options anyway.
527534
tokio::task_local! {
535+
/// The options used when establishing a new connection.
536+
///
537+
/// We must use task-local variables for these config options when using
538+
/// `hyper_util::client::legacy::Client::request` because there's no way to plumb
539+
/// them through as parameters. Moreover, if there's already a pooled connection
540+
/// ready, we'll reuse that and ignore these options anyway. After each connection
541+
/// is established, the options are dropped.
528542
static CONNECT_OPTIONS: ConnectOptions;
529543
}
530544

531545
#[derive(Clone)]
532546
struct ConnectOptions {
547+
/// The blocked networks configuration.
533548
blocked_networks: BlockedNetworks,
549+
/// Timeout for establishing a TCP connection.
534550
connect_timeout: Duration,
551+
/// TLS client configuration to use, if any.
535552
tls_client_config: Option<TlsClientConfig>,
553+
/// If set, override the address to connect to instead of using the given `uri`'s authority.
536554
override_connect_addr: Option<SocketAddr>,
555+
/// A semaphore to limit the number of concurrent outbound connections.
556+
concurrent_outbound_connections_semaphore: Option<Arc<Semaphore>>,
537557
}
538558

539559
impl ConnectOptions {
540-
async fn connect_tcp(&self, uri: &Uri, default_port: u16) -> Result<TcpStream, ErrorCode> {
560+
/// Establish a TCP connection to the given URI and default port.
561+
async fn connect_tcp(
562+
&self,
563+
uri: &Uri,
564+
default_port: u16,
565+
) -> Result<PermittedTcpStream, ErrorCode> {
541566
let mut socket_addrs = match self.override_connect_addr {
542567
Some(override_connect_addr) => vec![override_connect_addr],
543568
None => {
@@ -572,22 +597,33 @@ impl ConnectOptions {
572597
return Err(ErrorCode::DestinationIpProhibited);
573598
}
574599

575-
timeout(self.connect_timeout, TcpStream::connect(&*socket_addrs))
600+
// If we're limiting concurrent outbound requests, acquire a permit
601+
let permit = match &self.concurrent_outbound_connections_semaphore {
602+
Some(s) => s.clone().acquire_owned().await.ok(),
603+
None => None,
604+
};
605+
606+
let stream = timeout(self.connect_timeout, TcpStream::connect(&*socket_addrs))
576607
.await
577608
.map_err(|_| ErrorCode::ConnectionTimeout)?
578609
.map_err(|err| match err.kind() {
579610
std::io::ErrorKind::AddrNotAvailable => {
580611
dns_error("address not available".into(), 0)
581612
}
582613
_ => ErrorCode::ConnectionRefused,
583-
})
614+
})?;
615+
Ok(PermittedTcpStream {
616+
inner: stream,
617+
_permit: permit,
618+
})
584619
}
585620

621+
/// Establish a TLS connection to the given URI and default port.
586622
async fn connect_tls(
587623
&self,
588624
uri: &Uri,
589625
default_port: u16,
590-
) -> Result<TlsStream<TcpStream>, ErrorCode> {
626+
) -> Result<TlsStream<PermittedTcpStream>, ErrorCode> {
591627
let tcp_stream = self.connect_tcp(uri, default_port).await?;
592628

593629
let mut tls_client_config = self.tls_client_config.as_deref().unwrap().clone();
@@ -597,7 +633,7 @@ impl ConnectOptions {
597633
let domain = rustls::pki_types::ServerName::try_from(uri.host().unwrap())
598634
.map_err(|e| {
599635
tracing::warn!("dns lookup error: {e:?}");
600-
dns_error("invalid dns name".to_string(), 0)
636+
dns_error("invalid dns name".into(), 0)
601637
})?
602638
.to_owned();
603639
connector.connect(domain, tcp_stream).await.map_err(|e| {
@@ -607,20 +643,22 @@ impl ConnectOptions {
607643
}
608644
}
609645

646+
/// A connector the uses `ConnectOptions`
610647
#[derive(Clone)]
611648
struct HttpConnector;
612649

613650
impl HttpConnector {
614-
async fn connect(uri: Uri) -> Result<TokioIo<TcpStream>, ErrorCode> {
651+
async fn connect(uri: Uri) -> Result<TokioIo<PermittedTcpStream>, ErrorCode> {
615652
let stream = CONNECT_OPTIONS.get().connect_tcp(&uri, 80).await?;
616653
Ok(TokioIo::new(stream))
617654
}
618655
}
619656

620657
impl Service<Uri> for HttpConnector {
621-
type Response = TokioIo<TcpStream>;
658+
type Response = TokioIo<PermittedTcpStream>;
622659
type Error = ErrorCode;
623-
type Future = Pin<Box<dyn Future<Output = Result<TokioIo<TcpStream>, ErrorCode>> + Send>>;
660+
type Future =
661+
Pin<Box<dyn Future<Output = Result<TokioIo<PermittedTcpStream>, ErrorCode>> + Send>>;
624662

625663
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
626664
Poll::Ready(Ok(()))
@@ -631,6 +669,7 @@ impl Service<Uri> for HttpConnector {
631669
}
632670
}
633671

672+
/// A connector that establishes TLS connections using `rustls` and `ConnectOptions`.
634673
#[derive(Clone)]
635674
struct HttpsConnector;
636675

@@ -655,7 +694,7 @@ impl Service<Uri> for HttpsConnector {
655694
}
656695
}
657696

658-
struct RustlsStream(TlsStream<TcpStream>);
697+
struct RustlsStream(TlsStream<PermittedTcpStream>);
659698

660699
impl Connection for RustlsStream {
661700
fn connected(&self) -> Connected {
@@ -710,6 +749,54 @@ impl AsyncWrite for RustlsStream {
710749
}
711750
}
712751

752+
/// A TCP stream that holds an optional permit indicating that it is allowed to exist.
753+
struct PermittedTcpStream {
754+
/// The wrapped TCP stream.
755+
inner: TcpStream,
756+
/// A permit indicating that this stream is allowed to exist.
757+
///
758+
/// When this stream is dropped, the permit is also dropped, allowing another
759+
/// connection to be established.
760+
_permit: Option<OwnedSemaphorePermit>,
761+
}
762+
763+
impl Connection for PermittedTcpStream {
764+
fn connected(&self) -> Connected {
765+
self.inner.connected()
766+
}
767+
}
768+
769+
impl AsyncRead for PermittedTcpStream {
770+
fn poll_read(
771+
self: Pin<&mut Self>,
772+
cx: &mut Context<'_>,
773+
buf: &mut ReadBuf<'_>,
774+
) -> Poll<std::io::Result<()>> {
775+
Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
776+
}
777+
}
778+
779+
impl AsyncWrite for PermittedTcpStream {
780+
fn poll_write(
781+
self: Pin<&mut Self>,
782+
cx: &mut Context<'_>,
783+
buf: &[u8],
784+
) -> Poll<Result<usize, std::io::Error>> {
785+
Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
786+
}
787+
788+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
789+
Pin::new(&mut self.get_mut().inner).poll_flush(cx)
790+
}
791+
792+
fn poll_shutdown(
793+
self: Pin<&mut Self>,
794+
cx: &mut Context<'_>,
795+
) -> Poll<Result<(), std::io::Error>> {
796+
Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
797+
}
798+
}
799+
713800
/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
714801
fn hyper_request_error(err: hyper::Error) -> ErrorCode {
715802
// If there's a source, we might be able to extract a wasi-http error from it.

0 commit comments

Comments
 (0)