@@ -8,19 +8,331 @@ use crate::{
88 path:: secret,
99 stream:: {
1010 application:: Stream ,
11+ client:: { rpc as rpc_internal, tokio as client} ,
1112 endpoint,
1213 environment:: {
1314 tokio:: { self as env, Environment } ,
14- Environment as _,
15+ udp as udp_pool , Environment as _,
1516 } ,
16- recv,
17- socket:: Protocol ,
17+ recv, socket,
1818 } ,
1919} ;
2020use s2n_quic_core:: time:: Clock ;
2121use std:: { io, net:: SocketAddr , time:: Duration } ;
2222use tokio:: net:: TcpStream ;
2323
24+ pub mod rpc {
25+ pub use crate :: stream:: client:: rpc:: { InMemoryResponse , Request , Response } ;
26+ }
27+
28+ // This trait is a temporary solution to abstract handshake_with_entry
29+ // and local_addr until we implement the handshake provider
30+ #[ allow( async_fn_in_trait) ]
31+ pub trait Handshake : AsRef < secret:: Map > + Clone {
32+ /// Handshake with the remote peer
33+ async fn handshake_with_entry (
34+ & self ,
35+ remote_handshake_addr : SocketAddr ,
36+ ) -> std:: io:: Result < ( secret:: map:: Peer , secret:: HandshakeKind ) > ;
37+
38+ fn local_addr ( & self ) -> std:: io:: Result < SocketAddr > ;
39+ }
40+
41+ #[ derive( Clone ) ]
42+ pub struct Client < H : Handshake + Clone , S : event:: Subscriber + Clone > {
43+ env : Environment < S > ,
44+ handshake : H ,
45+ default_protocol : socket:: Protocol ,
46+ linger : Option < Duration > ,
47+ }
48+
49+ impl < H : Handshake + Clone , S : event:: Subscriber + Clone > Client < H , S > {
50+ #[ inline]
51+ pub fn new ( handshake : H , subscriber : S ) -> io:: Result < Self > {
52+ Self :: builder ( ) . build ( handshake, subscriber)
53+ }
54+
55+ #[ inline]
56+ pub fn builder ( ) -> Builder {
57+ Builder :: default ( )
58+ }
59+
60+ pub fn drop_state ( & self ) {
61+ self . handshake . as_ref ( ) . drop_state ( )
62+ }
63+
64+ pub fn handshake_state ( & self ) -> & H {
65+ & self . handshake
66+ }
67+
68+ #[ inline]
69+ pub async fn handshake_with (
70+ & self ,
71+ remote_handshake_addr : SocketAddr ,
72+ ) -> io:: Result < secret:: HandshakeKind > {
73+ let ( _peer, kind) = self
74+ . handshake
75+ . handshake_with_entry ( remote_handshake_addr)
76+ . await ?;
77+ Ok ( kind)
78+ }
79+
80+ #[ inline]
81+ async fn handshake_for_connect (
82+ & self ,
83+ remote_handshake_addr : SocketAddr ,
84+ ) -> io:: Result < secret:: map:: Peer > {
85+ let ( peer, _kind) = self
86+ . handshake
87+ . handshake_with_entry ( remote_handshake_addr)
88+ . await ?;
89+ Ok ( peer)
90+ }
91+
92+ /// Connects using the preferred protocol
93+ #[ inline]
94+ pub async fn connect (
95+ & self ,
96+ handshake_addr : SocketAddr ,
97+ acceptor_addr : SocketAddr ,
98+ ) -> io:: Result < Stream < S > > {
99+ match self . default_protocol {
100+ socket:: Protocol :: Udp => self . connect_udp ( handshake_addr, acceptor_addr) . await ,
101+ socket:: Protocol :: Tcp => self . connect_tcp ( handshake_addr, acceptor_addr) . await ,
102+ protocol => Err ( io:: Error :: new (
103+ io:: ErrorKind :: InvalidInput ,
104+ format ! ( "invalid default protocol {protocol:?}" ) ,
105+ ) ) ,
106+ }
107+ }
108+
109+ /// Makes an RPC request using the preferred protocol
110+ pub async fn rpc < Req , Res > (
111+ & self ,
112+ handshake_addr : SocketAddr ,
113+ acceptor_addr : SocketAddr ,
114+ request : Req ,
115+ response : Res ,
116+ ) -> io:: Result < Res :: Output >
117+ where
118+ Req : rpc:: Request ,
119+ Res : rpc:: Response ,
120+ {
121+ match self . default_protocol {
122+ socket:: Protocol :: Udp => {
123+ self . rpc_udp ( handshake_addr, acceptor_addr, request, response)
124+ . await
125+ }
126+ socket:: Protocol :: Tcp => {
127+ self . rpc_tcp ( handshake_addr, acceptor_addr, request, response)
128+ . await
129+ }
130+ protocol => Err ( io:: Error :: new (
131+ io:: ErrorKind :: InvalidInput ,
132+ format ! ( "invalid default protocol {protocol:?}" ) ,
133+ ) ) ,
134+ }
135+ }
136+
137+ /// Connects using the UDP transport layer
138+ #[ inline]
139+ pub async fn connect_udp (
140+ & self ,
141+ handshake_addr : SocketAddr ,
142+ acceptor_addr : SocketAddr ,
143+ ) -> io:: Result < Stream < S > > {
144+ // ensure we have a secret for the peer
145+ let handshake = self . handshake_for_connect ( handshake_addr) ;
146+
147+ let mut stream = client:: connect_udp ( handshake, acceptor_addr, & self . env ) . await ?;
148+ Self :: write_prelude ( & mut stream) . await ?;
149+ Ok ( stream)
150+ }
151+
152+ /// Makes an RPC request using the UDP transport layer
153+ #[ inline]
154+ pub async fn rpc_udp < Req , Res > (
155+ & self ,
156+ handshake_addr : SocketAddr ,
157+ acceptor_addr : SocketAddr ,
158+ request : Req ,
159+ response : Res ,
160+ ) -> io:: Result < Res :: Output >
161+ where
162+ Req : rpc:: Request ,
163+ Res : rpc:: Response ,
164+ {
165+ // ensure we have a secret for the peer
166+ let handshake = self . handshake_for_connect ( handshake_addr) ;
167+
168+ let stream = client:: connect_udp ( handshake, acceptor_addr, & self . env ) . await ?;
169+ rpc_internal:: from_stream ( stream, request, response) . await
170+ }
171+
172+ /// Connects using the TCP transport layer
173+ #[ inline]
174+ pub async fn connect_tcp (
175+ & self ,
176+ handshake_addr : SocketAddr ,
177+ acceptor_addr : SocketAddr ,
178+ ) -> io:: Result < Stream < S > > {
179+ // ensure we have a secret for the peer
180+ let handshake = self . handshake_for_connect ( handshake_addr) ;
181+
182+ let mut stream =
183+ client:: connect_tcp ( handshake, acceptor_addr, & self . env , self . linger ) . await ?;
184+ Self :: write_prelude ( & mut stream) . await ?;
185+ Ok ( stream)
186+ }
187+
188+ /// Makes an RPC request using the TCP transport layer
189+ #[ inline]
190+ pub async fn rpc_tcp < Req , Res > (
191+ & self ,
192+ handshake_addr : SocketAddr ,
193+ acceptor_addr : SocketAddr ,
194+ request : Req ,
195+ response : Res ,
196+ ) -> io:: Result < Res :: Output >
197+ where
198+ Req : rpc:: Request ,
199+ Res : rpc:: Response ,
200+ {
201+ // ensure we have a secret for the peer
202+ let handshake = self . handshake_for_connect ( handshake_addr) ;
203+
204+ let stream = client:: connect_tcp ( handshake, acceptor_addr, & self . env , self . linger ) . await ?;
205+ rpc_internal:: from_stream ( stream, request, response) . await
206+ }
207+
208+ /// Connects with a pre-existing TCP stream
209+ #[ inline]
210+ pub async fn connect_tcp_with (
211+ & self ,
212+ handshake_addr : SocketAddr ,
213+ stream : TcpStream ,
214+ ) -> io:: Result < Stream < S > > {
215+ // ensure we have a secret for the peer
216+ let handshake = self . handshake_for_connect ( handshake_addr) . await ?;
217+
218+ let mut stream = client:: connect_tcp_with ( handshake, stream, & self . env ) . await ?;
219+ Self :: write_prelude ( & mut stream) . await ?;
220+ Ok ( stream)
221+ }
222+
223+ #[ inline]
224+ async fn write_prelude ( stream : & mut Stream < S > ) -> io:: Result < ( ) > {
225+ // TODO should we actually write the prelude here or should we do late sealer binding on
226+ // the first packet to reduce secret reordering on the peer
227+
228+ stream
229+ . write_from ( & mut s2n_quic_core:: buffer:: reader:: storage:: Empty )
230+ . await
231+ . map ( |_| ( ) )
232+ }
233+ }
234+
235+ #[ derive( Default ) ]
236+ pub struct Builder {
237+ default_protocol : Option < socket:: Protocol > ,
238+ background_threads : Option < usize > ,
239+ linger : Option < Duration > ,
240+ send_buffer : Option < usize > ,
241+ recv_buffer : Option < usize > ,
242+ }
243+
244+ impl Builder {
245+ pub fn with_tcp ( self , enabled : bool ) -> Self {
246+ self . with_default_protocol ( if enabled {
247+ socket:: Protocol :: Tcp
248+ } else {
249+ socket:: Protocol :: Udp
250+ } )
251+ }
252+
253+ pub fn with_udp ( self , enabled : bool ) -> Self {
254+ self . with_default_protocol ( if enabled {
255+ socket:: Protocol :: Udp
256+ } else {
257+ socket:: Protocol :: Tcp
258+ } )
259+ }
260+
261+ pub fn with_default_protocol ( mut self , protocol : socket:: Protocol ) -> Self {
262+ self . default_protocol = Some ( protocol) ;
263+ self
264+ }
265+
266+ pub fn with_background_threads ( mut self , threads : usize ) -> Self {
267+ self . background_threads = Some ( threads) ;
268+ self
269+ }
270+
271+ pub fn with_linger ( mut self , linger : Duration ) -> Self {
272+ self . linger = Some ( linger) ;
273+ self
274+ }
275+
276+ /// Sets the send buffer for the OS socket handle.
277+ ///
278+ /// See `SO_SNDBUF` for more information.
279+ ///
280+ /// Note that this only applies to sockets that are created by SaltyLib. Any sockets
281+ /// provided by the application will not inherit this value.
282+ pub fn with_send_buffer ( mut self , bytes : usize ) -> Self {
283+ self . send_buffer = Some ( bytes) ;
284+ self
285+ }
286+
287+ /// Sets the recv buffer for the OS socket handle.
288+ ///
289+ /// See `SO_RCVBUF` for more information.
290+ ///
291+ /// Note that this only applies to sockets that are created by SaltyLib. Any sockets
292+ /// provided by the application will not inherit this value.
293+ pub fn with_recv_buffer ( mut self , bytes : usize ) -> Self {
294+ self . recv_buffer = Some ( bytes) ;
295+ self
296+ }
297+
298+ #[ inline]
299+ pub fn build < H : Handshake + Clone , S : event:: Subscriber + Clone > (
300+ self ,
301+ handshake : H ,
302+ subscriber : S ,
303+ ) -> io:: Result < Client < H , S > > {
304+ // bind the sockets to the same address family as the handshake
305+ let mut local_addr = handshake. local_addr ( ) ?;
306+ local_addr. set_port ( 0 ) ;
307+ let mut options = socket:: Options :: new ( local_addr) ;
308+
309+ options. send_buffer = self . send_buffer ;
310+ options. recv_buffer = self . recv_buffer ;
311+
312+ let mut env = env:: Builder :: new ( subscriber) . with_socket_options ( options) ;
313+
314+ let pool = udp_pool:: Config :: new ( ( handshake. as_ref ( ) ) . clone ( ) ) ;
315+ env = env. with_pool ( pool) ;
316+
317+ if let Some ( threads) = self . background_threads {
318+ env = env. with_threads ( threads) ;
319+ }
320+ let env = env. build ( ) ?;
321+
322+ // default to UDP
323+ let default_protocol = self . default_protocol . unwrap_or ( socket:: Protocol :: Udp ) ;
324+
325+ let linger = self . linger ;
326+
327+ Ok ( Client {
328+ env,
329+ handshake,
330+ default_protocol,
331+ linger,
332+ } )
333+ }
334+ }
335+
24336/// Connects using the UDP transport layer
25337///
26338/// Callers should send data immediately after calling this to ensure minimal
52364 // build the stream inside the application context
53365 let stream = stream. connect ( ) ?;
54366
55- debug_assert_eq ! ( stream. protocol( ) , Protocol :: Udp ) ;
367+ debug_assert_eq ! ( stream. protocol( ) , socket :: Protocol :: Udp ) ;
56368
57369 Ok ( stream)
58370}
@@ -190,7 +502,7 @@ where
190502 // build the stream inside the application context
191503 let stream = stream. connect ( ) ?;
192504
193- debug_assert_eq ! ( stream. protocol( ) , Protocol :: Tcp ) ;
505+ debug_assert_eq ! ( stream. protocol( ) , socket :: Protocol :: Tcp ) ;
194506
195507 Ok ( stream)
196508}
@@ -229,7 +541,7 @@ where
229541 // build the stream inside the application context
230542 let stream = stream. connect ( ) ?;
231543
232- debug_assert_eq ! ( stream. protocol( ) , Protocol :: Tcp ) ;
544+ debug_assert_eq ! ( stream. protocol( ) , socket :: Protocol :: Tcp ) ;
233545
234546 Ok ( stream)
235547}
0 commit comments