@@ -26,7 +26,7 @@ use spin_factors::{wasmtime::component::ResourceTable, RuntimeFactorsInstanceSta
2626use tokio:: {
2727 io:: { AsyncRead , AsyncWrite , ReadBuf } ,
2828 net:: TcpStream ,
29- sync:: Semaphore ,
29+ sync:: { OwnedSemaphorePermit , Semaphore } ,
3030 time:: timeout,
3131} ;
3232use tokio_rustls:: client:: TlsStream ;
@@ -129,9 +129,9 @@ impl WasiHttpView for WasiHttpImplInner<'_> {
129129 self_request_origin : self . state . self_request_origin . clone ( ) ,
130130 blocked_networks : self . state . blocked_networks . clone ( ) ,
131131 http_clients : self . state . wasi_http_clients . clone ( ) ,
132- concurrent_outbound_requests_semaphore : self
132+ concurrent_outbound_connections_semaphore : self
133133 . state
134- . concurrent_outbound_requests_semaphore
134+ . concurrent_outbound_connections_semaphore
135135 . clone ( ) ,
136136 } ;
137137 Ok ( HostFutureIncomingResponse :: Pending (
@@ -158,7 +158,7 @@ struct RequestSender {
158158 self_request_origin : Option < SelfRequestOrigin > ,
159159 request_interceptor : Option < Arc < dyn OutboundHttpInterceptor > > ,
160160 http_clients : HttpClients ,
161- concurrent_outbound_requests_semaphore : Option < Arc < Semaphore > > ,
161+ concurrent_outbound_connections_semaphore : Option < Arc < Semaphore > > ,
162162}
163163
164164impl RequestSender {
@@ -300,12 +300,18 @@ impl RequestSender {
300300 None
301301 } ;
302302
303+ // If we're limiting concurrent outbound requests, acquire a permit
304+ let permit = match & self . concurrent_outbound_connections_semaphore {
305+ Some ( s) => s. clone ( ) . acquire_owned ( ) . await . ok ( ) . map ( Arc :: new) ,
306+ None => None ,
307+ } ;
303308 let resp = CONNECT_OPTIONS . scope (
304309 ConnectOptions {
305310 blocked_networks : self . blocked_networks ,
306311 connect_timeout,
307312 tls_client_config,
308313 override_connect_addr,
314+ permit,
309315 } ,
310316 async move {
311317 if use_tls {
@@ -326,17 +332,11 @@ impl RequestSender {
326332 } ,
327333 ) ;
328334
329- // If we're limiting concurrent outbound requests, acquire a permit
330- let permit = match & self . concurrent_outbound_requests_semaphore {
331- Some ( s) => s. acquire ( ) . await . ok ( ) ,
332- None => None ,
333- } ;
334335 let resp = timeout ( first_byte_timeout, resp)
335336 . await
336337 . map_err ( |_| ErrorCode :: ConnectionReadTimeout ) ?
337338 . map_err ( hyper_legacy_request_error) ?
338339 . map ( |body| body. map_err ( hyper_request_error) . boxed ( ) ) ;
339- drop ( permit) ;
340340
341341 tracing:: Span :: current ( ) . record ( "http.response.status_code" , resp. status ( ) . as_u16 ( ) ) ;
342342
@@ -378,24 +378,40 @@ impl HttpClients {
378378 }
379379}
380380
381- // We must use task-local variables for these config options when using
382- // `hyper_util::client::legacy::Client::request` because there's no way to plumb
383- // them through as parameters. Moreover, if there's already a pooled connection
384- // ready, we'll reuse that and ignore these options anyway.
385381tokio:: task_local! {
382+ /// The options used when establishing a new connection.
383+ ///
384+ /// We must use task-local variables for these config options when using
385+ /// `hyper_util::client::legacy::Client::request` because there's no way to plumb
386+ /// them through as parameters. Moreover, if there's already a pooled connection
387+ /// ready, we'll reuse that and ignore these options anyway. After each connection
388+ /// is established, the options are dropped.
386389 static CONNECT_OPTIONS : ConnectOptions ;
387390}
388391
389392#[ derive( Clone ) ]
390393struct ConnectOptions {
394+ /// The blocked networks configuration.
391395 blocked_networks : BlockedNetworks ,
396+ /// Timeout for establishing a TCP connection.
392397 connect_timeout : Duration ,
398+ /// TLS client configuration to use, if any.
393399 tls_client_config : Option < TlsClientConfig > ,
400+ /// If set, override the address to connect to instead of using the given `uri`'s authority.
394401 override_connect_addr : Option < SocketAddr > ,
402+ /// A permit for this connection
403+ ///
404+ /// If there is a permit, it should be dropped when the connection is closed.
405+ permit : Option < Arc < OwnedSemaphorePermit > > ,
395406}
396407
397408impl ConnectOptions {
398- async fn connect_tcp ( & self , uri : & Uri , default_port : u16 ) -> Result < TcpStream , ErrorCode > {
409+ /// Establish a TCP connection to the given URI and default port.
410+ async fn connect_tcp (
411+ & self ,
412+ uri : & Uri ,
413+ default_port : u16 ,
414+ ) -> Result < PermittedTcpStream , ErrorCode > {
399415 let mut socket_addrs = match self . override_connect_addr {
400416 Some ( override_connect_addr) => vec ! [ override_connect_addr] ,
401417 None => {
@@ -430,22 +446,27 @@ impl ConnectOptions {
430446 return Err ( ErrorCode :: DestinationIpProhibited ) ;
431447 }
432448
433- timeout ( self . connect_timeout , TcpStream :: connect ( & * socket_addrs) )
449+ let stream = timeout ( self . connect_timeout , TcpStream :: connect ( & * socket_addrs) )
434450 . await
435451 . map_err ( |_| ErrorCode :: ConnectionTimeout ) ?
436452 . map_err ( |err| match err. kind ( ) {
437453 std:: io:: ErrorKind :: AddrNotAvailable => {
438454 dns_error ( "address not available" . into ( ) , 0 )
439455 }
440456 _ => ErrorCode :: ConnectionRefused ,
441- } )
457+ } ) ?;
458+ Ok ( PermittedTcpStream {
459+ inner : stream,
460+ _permit : self . permit . clone ( ) ,
461+ } )
442462 }
443463
464+ /// Establish a TLS connection to the given URI and default port.
444465 async fn connect_tls (
445466 & self ,
446467 uri : & Uri ,
447468 default_port : u16 ,
448- ) -> Result < TlsStream < TcpStream > , ErrorCode > {
469+ ) -> Result < TlsStream < PermittedTcpStream > , ErrorCode > {
449470 let tcp_stream = self . connect_tcp ( uri, default_port) . await ?;
450471
451472 let mut tls_client_config = self . tls_client_config . as_deref ( ) . unwrap ( ) . clone ( ) ;
@@ -465,20 +486,22 @@ impl ConnectOptions {
465486 }
466487}
467488
489+ /// A connector the uses `ConnectOptions`
468490#[ derive( Clone ) ]
469491struct HttpConnector ;
470492
471493impl HttpConnector {
472- async fn connect ( uri : Uri ) -> Result < TokioIo < TcpStream > , ErrorCode > {
494+ async fn connect ( uri : Uri ) -> Result < TokioIo < PermittedTcpStream > , ErrorCode > {
473495 let stream = CONNECT_OPTIONS . get ( ) . connect_tcp ( & uri, 80 ) . await ?;
474496 Ok ( TokioIo :: new ( stream) )
475497 }
476498}
477499
478500impl Service < Uri > for HttpConnector {
479- type Response = TokioIo < TcpStream > ;
501+ type Response = TokioIo < PermittedTcpStream > ;
480502 type Error = ErrorCode ;
481- type Future = Pin < Box < dyn Future < Output = Result < TokioIo < TcpStream > , ErrorCode > > + Send > > ;
503+ type Future =
504+ Pin < Box < dyn Future < Output = Result < TokioIo < PermittedTcpStream > , ErrorCode > > + Send > > ;
482505
483506 fn poll_ready ( & mut self , _cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , Self :: Error > > {
484507 Poll :: Ready ( Ok ( ( ) ) )
@@ -489,6 +512,7 @@ impl Service<Uri> for HttpConnector {
489512 }
490513}
491514
515+ /// A connector that establishes TLS connections using `rustls` and `ConnectOptions`.
492516#[ derive( Clone ) ]
493517struct HttpsConnector ;
494518
@@ -513,7 +537,7 @@ impl Service<Uri> for HttpsConnector {
513537 }
514538}
515539
516- struct RustlsStream ( TlsStream < TcpStream > ) ;
540+ struct RustlsStream ( TlsStream < PermittedTcpStream > ) ;
517541
518542impl Connection for RustlsStream {
519543 fn connected ( & self ) -> Connected {
@@ -568,6 +592,55 @@ impl AsyncWrite for RustlsStream {
568592 }
569593}
570594
595+ /// A TCP stream that holds an optional permit indicating that it is allowed to exist.
596+ struct PermittedTcpStream {
597+ inner : TcpStream ,
598+ _permit : Option < Arc < OwnedSemaphorePermit > > ,
599+ }
600+
601+ impl PermittedTcpStream {
602+ fn connected ( & self ) -> Connected {
603+ self . inner . connected ( )
604+ }
605+ }
606+
607+ impl Connection for PermittedTcpStream {
608+ fn connected ( & self ) -> Connected {
609+ self . inner . connected ( )
610+ }
611+ }
612+
613+ impl AsyncRead for PermittedTcpStream {
614+ fn poll_read (
615+ self : Pin < & mut Self > ,
616+ cx : & mut Context < ' _ > ,
617+ buf : & mut ReadBuf < ' _ > ,
618+ ) -> Poll < std:: io:: Result < ( ) > > {
619+ Pin :: new ( & mut self . get_mut ( ) . inner ) . poll_read ( cx, buf)
620+ }
621+ }
622+
623+ impl AsyncWrite for PermittedTcpStream {
624+ fn poll_write (
625+ self : Pin < & mut Self > ,
626+ cx : & mut Context < ' _ > ,
627+ buf : & [ u8 ] ,
628+ ) -> Poll < Result < usize , std:: io:: Error > > {
629+ Pin :: new ( & mut self . get_mut ( ) . inner ) . poll_write ( cx, buf)
630+ }
631+
632+ fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , std:: io:: Error > > {
633+ Pin :: new ( & mut self . get_mut ( ) . inner ) . poll_flush ( cx)
634+ }
635+
636+ fn poll_shutdown (
637+ self : Pin < & mut Self > ,
638+ cx : & mut Context < ' _ > ,
639+ ) -> Poll < Result < ( ) , std:: io:: Error > > {
640+ Pin :: new ( & mut self . get_mut ( ) . inner ) . poll_shutdown ( cx)
641+ }
642+ }
643+
571644/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
572645fn hyper_request_error ( err : hyper:: Error ) -> ErrorCode {
573646 // If there's a source, we might be able to extract a wasi-http error from it.
0 commit comments