@@ -28,6 +28,7 @@ use spin_factors::{wasmtime::component::ResourceTable, RuntimeFactorsInstanceSta
2828use tokio:: {
2929 io:: { AsyncRead , AsyncWrite , ReadBuf } ,
3030 net:: TcpStream ,
31+ sync:: { OwnedSemaphorePermit , Semaphore } ,
3132 time:: timeout,
3233} ;
3334use tokio_rustls:: client:: TlsStream ;
@@ -90,6 +91,9 @@ impl p3::WasiHttpCtx for InstanceState {
9091 self_request_origin : self . self_request_origin . clone ( ) ,
9192 blocked_networks : self . blocked_networks . clone ( ) ,
9293 http_clients : self . wasi_http_clients . clone ( ) ,
94+ concurrent_outbound_connections_semaphore : self
95+ . concurrent_outbound_connections_semaphore
96+ . clone ( ) ,
9397 } ;
9498 let config = OutgoingRequestConfig {
9599 use_tls : request. uri ( ) . scheme ( ) == Some ( & Scheme :: HTTPS ) ,
@@ -282,6 +286,10 @@ impl WasiHttpView for WasiHttpImplInner<'_> {
282286 self_request_origin : self . state . self_request_origin . clone ( ) ,
283287 blocked_networks : self . state . blocked_networks . clone ( ) ,
284288 http_clients : self . state . wasi_http_clients . clone ( ) ,
289+ concurrent_outbound_connections_semaphore : self
290+ . state
291+ . concurrent_outbound_connections_semaphore
292+ . clone ( ) ,
285293 } ;
286294 Ok ( HostFutureIncomingResponse :: Pending (
287295 wasmtime_wasi:: runtime:: spawn (
@@ -307,6 +315,7 @@ struct RequestSender {
307315 self_request_origin : Option < SelfRequestOrigin > ,
308316 request_interceptor : Option < Arc < dyn OutboundHttpInterceptor > > ,
309317 http_clients : HttpClients ,
318+ concurrent_outbound_connections_semaphore : Option < Arc < Semaphore > > ,
310319}
311320
312321impl RequestSender {
@@ -454,6 +463,8 @@ impl RequestSender {
454463 connect_timeout,
455464 tls_client_config,
456465 override_connect_addr,
466+ concurrent_outbound_connections_semaphore : self
467+ . concurrent_outbound_connections_semaphore ,
457468 } ,
458469 async move {
459470 if use_tls {
@@ -520,24 +531,38 @@ impl HttpClients {
520531 }
521532}
522533
523- // We must use task-local variables for these config options when using
524- // `hyper_util::client::legacy::Client::request` because there's no way to plumb
525- // them through as parameters. Moreover, if there's already a pooled connection
526- // ready, we'll reuse that and ignore these options anyway.
527534tokio:: task_local! {
535+ /// The options used when establishing a new connection.
536+ ///
537+ /// We must use task-local variables for these config options when using
538+ /// `hyper_util::client::legacy::Client::request` because there's no way to plumb
539+ /// them through as parameters. Moreover, if there's already a pooled connection
540+ /// ready, we'll reuse that and ignore these options anyway. After each connection
541+ /// is established, the options are dropped.
528542 static CONNECT_OPTIONS : ConnectOptions ;
529543}
530544
531545#[ derive( Clone ) ]
532546struct ConnectOptions {
547+ /// The blocked networks configuration.
533548 blocked_networks : BlockedNetworks ,
549+ /// Timeout for establishing a TCP connection.
534550 connect_timeout : Duration ,
551+ /// TLS client configuration to use, if any.
535552 tls_client_config : Option < TlsClientConfig > ,
553+ /// If set, override the address to connect to instead of using the given `uri`'s authority.
536554 override_connect_addr : Option < SocketAddr > ,
555+ /// A semaphore to limit the number of concurrent outbound connections.
556+ concurrent_outbound_connections_semaphore : Option < Arc < Semaphore > > ,
537557}
538558
539559impl ConnectOptions {
540- async fn connect_tcp ( & self , uri : & Uri , default_port : u16 ) -> Result < TcpStream , ErrorCode > {
560+ /// Establish a TCP connection to the given URI and default port.
561+ async fn connect_tcp (
562+ & self ,
563+ uri : & Uri ,
564+ default_port : u16 ,
565+ ) -> Result < PermittedTcpStream , ErrorCode > {
541566 let mut socket_addrs = match self . override_connect_addr {
542567 Some ( override_connect_addr) => vec ! [ override_connect_addr] ,
543568 None => {
@@ -572,22 +597,33 @@ impl ConnectOptions {
572597 return Err ( ErrorCode :: DestinationIpProhibited ) ;
573598 }
574599
575- timeout ( self . connect_timeout , TcpStream :: connect ( & * socket_addrs) )
600+ // If we're limiting concurrent outbound requests, acquire a permit
601+ let permit = match & self . concurrent_outbound_connections_semaphore {
602+ Some ( s) => s. clone ( ) . acquire_owned ( ) . await . ok ( ) ,
603+ None => None ,
604+ } ;
605+
606+ let stream = timeout ( self . connect_timeout , TcpStream :: connect ( & * socket_addrs) )
576607 . await
577608 . map_err ( |_| ErrorCode :: ConnectionTimeout ) ?
578609 . map_err ( |err| match err. kind ( ) {
579610 std:: io:: ErrorKind :: AddrNotAvailable => {
580611 dns_error ( "address not available" . into ( ) , 0 )
581612 }
582613 _ => ErrorCode :: ConnectionRefused ,
583- } )
614+ } ) ?;
615+ Ok ( PermittedTcpStream {
616+ inner : stream,
617+ _permit : permit,
618+ } )
584619 }
585620
621+ /// Establish a TLS connection to the given URI and default port.
586622 async fn connect_tls (
587623 & self ,
588624 uri : & Uri ,
589625 default_port : u16 ,
590- ) -> Result < TlsStream < TcpStream > , ErrorCode > {
626+ ) -> Result < TlsStream < PermittedTcpStream > , ErrorCode > {
591627 let tcp_stream = self . connect_tcp ( uri, default_port) . await ?;
592628
593629 let mut tls_client_config = self . tls_client_config . as_deref ( ) . unwrap ( ) . clone ( ) ;
@@ -597,7 +633,7 @@ impl ConnectOptions {
597633 let domain = rustls:: pki_types:: ServerName :: try_from ( uri. host ( ) . unwrap ( ) )
598634 . map_err ( |e| {
599635 tracing:: warn!( "dns lookup error: {e:?}" ) ;
600- dns_error ( "invalid dns name" . to_string ( ) , 0 )
636+ dns_error ( "invalid dns name" . into ( ) , 0 )
601637 } ) ?
602638 . to_owned ( ) ;
603639 connector. connect ( domain, tcp_stream) . await . map_err ( |e| {
@@ -607,20 +643,22 @@ impl ConnectOptions {
607643 }
608644}
609645
646+ /// A connector the uses `ConnectOptions`
610647#[ derive( Clone ) ]
611648struct HttpConnector ;
612649
613650impl HttpConnector {
614- async fn connect ( uri : Uri ) -> Result < TokioIo < TcpStream > , ErrorCode > {
651+ async fn connect ( uri : Uri ) -> Result < TokioIo < PermittedTcpStream > , ErrorCode > {
615652 let stream = CONNECT_OPTIONS . get ( ) . connect_tcp ( & uri, 80 ) . await ?;
616653 Ok ( TokioIo :: new ( stream) )
617654 }
618655}
619656
620657impl Service < Uri > for HttpConnector {
621- type Response = TokioIo < TcpStream > ;
658+ type Response = TokioIo < PermittedTcpStream > ;
622659 type Error = ErrorCode ;
623- type Future = Pin < Box < dyn Future < Output = Result < TokioIo < TcpStream > , ErrorCode > > + Send > > ;
660+ type Future =
661+ Pin < Box < dyn Future < Output = Result < TokioIo < PermittedTcpStream > , ErrorCode > > + Send > > ;
624662
625663 fn poll_ready ( & mut self , _cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , Self :: Error > > {
626664 Poll :: Ready ( Ok ( ( ) ) )
@@ -631,6 +669,7 @@ impl Service<Uri> for HttpConnector {
631669 }
632670}
633671
672+ /// A connector that establishes TLS connections using `rustls` and `ConnectOptions`.
634673#[ derive( Clone ) ]
635674struct HttpsConnector ;
636675
@@ -655,7 +694,7 @@ impl Service<Uri> for HttpsConnector {
655694 }
656695}
657696
658- struct RustlsStream ( TlsStream < TcpStream > ) ;
697+ struct RustlsStream ( TlsStream < PermittedTcpStream > ) ;
659698
660699impl Connection for RustlsStream {
661700 fn connected ( & self ) -> Connected {
@@ -710,6 +749,54 @@ impl AsyncWrite for RustlsStream {
710749 }
711750}
712751
752+ /// A TCP stream that holds an optional permit indicating that it is allowed to exist.
753+ struct PermittedTcpStream {
754+ /// The wrapped TCP stream.
755+ inner : TcpStream ,
756+ /// A permit indicating that this stream is allowed to exist.
757+ ///
758+ /// When this stream is dropped, the permit is also dropped, allowing another
759+ /// connection to be established.
760+ _permit : Option < OwnedSemaphorePermit > ,
761+ }
762+
763+ impl Connection for PermittedTcpStream {
764+ fn connected ( & self ) -> Connected {
765+ self . inner . connected ( )
766+ }
767+ }
768+
769+ impl AsyncRead for PermittedTcpStream {
770+ fn poll_read (
771+ self : Pin < & mut Self > ,
772+ cx : & mut Context < ' _ > ,
773+ buf : & mut ReadBuf < ' _ > ,
774+ ) -> Poll < std:: io:: Result < ( ) > > {
775+ Pin :: new ( & mut self . get_mut ( ) . inner ) . poll_read ( cx, buf)
776+ }
777+ }
778+
779+ impl AsyncWrite for PermittedTcpStream {
780+ fn poll_write (
781+ self : Pin < & mut Self > ,
782+ cx : & mut Context < ' _ > ,
783+ buf : & [ u8 ] ,
784+ ) -> Poll < Result < usize , std:: io:: Error > > {
785+ Pin :: new ( & mut self . get_mut ( ) . inner ) . poll_write ( cx, buf)
786+ }
787+
788+ fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , std:: io:: Error > > {
789+ Pin :: new ( & mut self . get_mut ( ) . inner ) . poll_flush ( cx)
790+ }
791+
792+ fn poll_shutdown (
793+ self : Pin < & mut Self > ,
794+ cx : & mut Context < ' _ > ,
795+ ) -> Poll < Result < ( ) , std:: io:: Error > > {
796+ Pin :: new ( & mut self . get_mut ( ) . inner ) . poll_shutdown ( cx)
797+ }
798+ }
799+
713800/// Translate a [`hyper::Error`] to a wasi-http `ErrorCode` in the context of a request.
714801fn hyper_request_error ( err : hyper:: Error ) -> ErrorCode {
715802 // If there's a source, we might be able to extract a wasi-http error from it.
0 commit comments