Skip to content

Commit 3a7aba3

Browse files
authored
feat(s2n-quic-dc): implement dcQUIC client Tokio Builder (#2741)
1 parent cfff80e commit 3a7aba3

File tree

1 file changed

+318
-6
lines changed
  • dc/s2n-quic-dc/src/stream/client

1 file changed

+318
-6
lines changed

dc/s2n-quic-dc/src/stream/client/tokio.rs

Lines changed: 318 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,331 @@ use crate::{
88
path::secret,
99
stream::{
1010
application::Stream,
11+
client::{rpc as rpc_internal, tokio as client},
1112
endpoint,
1213
environment::{
1314
tokio::{self as env, Environment},
14-
Environment as _,
15+
udp as udp_pool, Environment as _,
1516
},
16-
recv,
17-
socket::Protocol,
17+
recv, socket,
1818
},
1919
};
2020
use s2n_quic_core::time::Clock;
2121
use std::{io, net::SocketAddr, time::Duration};
2222
use tokio::net::TcpStream;
2323

24+
pub mod rpc {
25+
pub use crate::stream::client::rpc::{InMemoryResponse, Request, Response};
26+
}
27+
28+
// This trait is a temporary solution to abstract handshake_with_entry
29+
// and local_addr until we implement the handshake provider
30+
#[allow(async_fn_in_trait)]
31+
pub trait Handshake: AsRef<secret::Map> + Clone {
32+
/// Handshake with the remote peer
33+
async fn handshake_with_entry(
34+
&self,
35+
remote_handshake_addr: SocketAddr,
36+
) -> std::io::Result<(secret::map::Peer, secret::HandshakeKind)>;
37+
38+
fn local_addr(&self) -> std::io::Result<SocketAddr>;
39+
}
40+
41+
#[derive(Clone)]
42+
pub struct Client<H: Handshake + Clone, S: event::Subscriber + Clone> {
43+
env: Environment<S>,
44+
handshake: H,
45+
default_protocol: socket::Protocol,
46+
linger: Option<Duration>,
47+
}
48+
49+
impl<H: Handshake + Clone, S: event::Subscriber + Clone> Client<H, S> {
50+
#[inline]
51+
pub fn new(handshake: H, subscriber: S) -> io::Result<Self> {
52+
Self::builder().build(handshake, subscriber)
53+
}
54+
55+
#[inline]
56+
pub fn builder() -> Builder {
57+
Builder::default()
58+
}
59+
60+
pub fn drop_state(&self) {
61+
self.handshake.as_ref().drop_state()
62+
}
63+
64+
pub fn handshake_state(&self) -> &H {
65+
&self.handshake
66+
}
67+
68+
#[inline]
69+
pub async fn handshake_with(
70+
&self,
71+
remote_handshake_addr: SocketAddr,
72+
) -> io::Result<secret::HandshakeKind> {
73+
let (_peer, kind) = self
74+
.handshake
75+
.handshake_with_entry(remote_handshake_addr)
76+
.await?;
77+
Ok(kind)
78+
}
79+
80+
#[inline]
81+
async fn handshake_for_connect(
82+
&self,
83+
remote_handshake_addr: SocketAddr,
84+
) -> io::Result<secret::map::Peer> {
85+
let (peer, _kind) = self
86+
.handshake
87+
.handshake_with_entry(remote_handshake_addr)
88+
.await?;
89+
Ok(peer)
90+
}
91+
92+
/// Connects using the preferred protocol
93+
#[inline]
94+
pub async fn connect(
95+
&self,
96+
handshake_addr: SocketAddr,
97+
acceptor_addr: SocketAddr,
98+
) -> io::Result<Stream<S>> {
99+
match self.default_protocol {
100+
socket::Protocol::Udp => self.connect_udp(handshake_addr, acceptor_addr).await,
101+
socket::Protocol::Tcp => self.connect_tcp(handshake_addr, acceptor_addr).await,
102+
protocol => Err(io::Error::new(
103+
io::ErrorKind::InvalidInput,
104+
format!("invalid default protocol {protocol:?}"),
105+
)),
106+
}
107+
}
108+
109+
/// Makes an RPC request using the preferred protocol
110+
pub async fn rpc<Req, Res>(
111+
&self,
112+
handshake_addr: SocketAddr,
113+
acceptor_addr: SocketAddr,
114+
request: Req,
115+
response: Res,
116+
) -> io::Result<Res::Output>
117+
where
118+
Req: rpc::Request,
119+
Res: rpc::Response,
120+
{
121+
match self.default_protocol {
122+
socket::Protocol::Udp => {
123+
self.rpc_udp(handshake_addr, acceptor_addr, request, response)
124+
.await
125+
}
126+
socket::Protocol::Tcp => {
127+
self.rpc_tcp(handshake_addr, acceptor_addr, request, response)
128+
.await
129+
}
130+
protocol => Err(io::Error::new(
131+
io::ErrorKind::InvalidInput,
132+
format!("invalid default protocol {protocol:?}"),
133+
)),
134+
}
135+
}
136+
137+
/// Connects using the UDP transport layer
138+
#[inline]
139+
pub async fn connect_udp(
140+
&self,
141+
handshake_addr: SocketAddr,
142+
acceptor_addr: SocketAddr,
143+
) -> io::Result<Stream<S>> {
144+
// ensure we have a secret for the peer
145+
let handshake = self.handshake_for_connect(handshake_addr);
146+
147+
let mut stream = client::connect_udp(handshake, acceptor_addr, &self.env).await?;
148+
Self::write_prelude(&mut stream).await?;
149+
Ok(stream)
150+
}
151+
152+
/// Makes an RPC request using the UDP transport layer
153+
#[inline]
154+
pub async fn rpc_udp<Req, Res>(
155+
&self,
156+
handshake_addr: SocketAddr,
157+
acceptor_addr: SocketAddr,
158+
request: Req,
159+
response: Res,
160+
) -> io::Result<Res::Output>
161+
where
162+
Req: rpc::Request,
163+
Res: rpc::Response,
164+
{
165+
// ensure we have a secret for the peer
166+
let handshake = self.handshake_for_connect(handshake_addr);
167+
168+
let stream = client::connect_udp(handshake, acceptor_addr, &self.env).await?;
169+
rpc_internal::from_stream(stream, request, response).await
170+
}
171+
172+
/// Connects using the TCP transport layer
173+
#[inline]
174+
pub async fn connect_tcp(
175+
&self,
176+
handshake_addr: SocketAddr,
177+
acceptor_addr: SocketAddr,
178+
) -> io::Result<Stream<S>> {
179+
// ensure we have a secret for the peer
180+
let handshake = self.handshake_for_connect(handshake_addr);
181+
182+
let mut stream =
183+
client::connect_tcp(handshake, acceptor_addr, &self.env, self.linger).await?;
184+
Self::write_prelude(&mut stream).await?;
185+
Ok(stream)
186+
}
187+
188+
/// Makes an RPC request using the TCP transport layer
189+
#[inline]
190+
pub async fn rpc_tcp<Req, Res>(
191+
&self,
192+
handshake_addr: SocketAddr,
193+
acceptor_addr: SocketAddr,
194+
request: Req,
195+
response: Res,
196+
) -> io::Result<Res::Output>
197+
where
198+
Req: rpc::Request,
199+
Res: rpc::Response,
200+
{
201+
// ensure we have a secret for the peer
202+
let handshake = self.handshake_for_connect(handshake_addr);
203+
204+
let stream = client::connect_tcp(handshake, acceptor_addr, &self.env, self.linger).await?;
205+
rpc_internal::from_stream(stream, request, response).await
206+
}
207+
208+
/// Connects with a pre-existing TCP stream
209+
#[inline]
210+
pub async fn connect_tcp_with(
211+
&self,
212+
handshake_addr: SocketAddr,
213+
stream: TcpStream,
214+
) -> io::Result<Stream<S>> {
215+
// ensure we have a secret for the peer
216+
let handshake = self.handshake_for_connect(handshake_addr).await?;
217+
218+
let mut stream = client::connect_tcp_with(handshake, stream, &self.env).await?;
219+
Self::write_prelude(&mut stream).await?;
220+
Ok(stream)
221+
}
222+
223+
#[inline]
224+
async fn write_prelude(stream: &mut Stream<S>) -> io::Result<()> {
225+
// TODO should we actually write the prelude here or should we do late sealer binding on
226+
// the first packet to reduce secret reordering on the peer
227+
228+
stream
229+
.write_from(&mut s2n_quic_core::buffer::reader::storage::Empty)
230+
.await
231+
.map(|_| ())
232+
}
233+
}
234+
235+
#[derive(Default)]
236+
pub struct Builder {
237+
default_protocol: Option<socket::Protocol>,
238+
background_threads: Option<usize>,
239+
linger: Option<Duration>,
240+
send_buffer: Option<usize>,
241+
recv_buffer: Option<usize>,
242+
}
243+
244+
impl Builder {
245+
pub fn with_tcp(self, enabled: bool) -> Self {
246+
self.with_default_protocol(if enabled {
247+
socket::Protocol::Tcp
248+
} else {
249+
socket::Protocol::Udp
250+
})
251+
}
252+
253+
pub fn with_udp(self, enabled: bool) -> Self {
254+
self.with_default_protocol(if enabled {
255+
socket::Protocol::Udp
256+
} else {
257+
socket::Protocol::Tcp
258+
})
259+
}
260+
261+
pub fn with_default_protocol(mut self, protocol: socket::Protocol) -> Self {
262+
self.default_protocol = Some(protocol);
263+
self
264+
}
265+
266+
pub fn with_background_threads(mut self, threads: usize) -> Self {
267+
self.background_threads = Some(threads);
268+
self
269+
}
270+
271+
pub fn with_linger(mut self, linger: Duration) -> Self {
272+
self.linger = Some(linger);
273+
self
274+
}
275+
276+
/// Sets the send buffer for the OS socket handle.
277+
///
278+
/// See `SO_SNDBUF` for more information.
279+
///
280+
/// Note that this only applies to sockets that are created by SaltyLib. Any sockets
281+
/// provided by the application will not inherit this value.
282+
pub fn with_send_buffer(mut self, bytes: usize) -> Self {
283+
self.send_buffer = Some(bytes);
284+
self
285+
}
286+
287+
/// Sets the recv buffer for the OS socket handle.
288+
///
289+
/// See `SO_RCVBUF` for more information.
290+
///
291+
/// Note that this only applies to sockets that are created by SaltyLib. Any sockets
292+
/// provided by the application will not inherit this value.
293+
pub fn with_recv_buffer(mut self, bytes: usize) -> Self {
294+
self.recv_buffer = Some(bytes);
295+
self
296+
}
297+
298+
#[inline]
299+
pub fn build<H: Handshake + Clone, S: event::Subscriber + Clone>(
300+
self,
301+
handshake: H,
302+
subscriber: S,
303+
) -> io::Result<Client<H, S>> {
304+
// bind the sockets to the same address family as the handshake
305+
let mut local_addr = handshake.local_addr()?;
306+
local_addr.set_port(0);
307+
let mut options = socket::Options::new(local_addr);
308+
309+
options.send_buffer = self.send_buffer;
310+
options.recv_buffer = self.recv_buffer;
311+
312+
let mut env = env::Builder::new(subscriber).with_socket_options(options);
313+
314+
let pool = udp_pool::Config::new((handshake.as_ref()).clone());
315+
env = env.with_pool(pool);
316+
317+
if let Some(threads) = self.background_threads {
318+
env = env.with_threads(threads);
319+
}
320+
let env = env.build()?;
321+
322+
// default to UDP
323+
let default_protocol = self.default_protocol.unwrap_or(socket::Protocol::Udp);
324+
325+
let linger = self.linger;
326+
327+
Ok(Client {
328+
env,
329+
handshake,
330+
default_protocol,
331+
linger,
332+
})
333+
}
334+
}
335+
24336
/// Connects using the UDP transport layer
25337
///
26338
/// Callers should send data immediately after calling this to ensure minimal
@@ -52,7 +364,7 @@ where
52364
// build the stream inside the application context
53365
let stream = stream.connect()?;
54366

55-
debug_assert_eq!(stream.protocol(), Protocol::Udp);
367+
debug_assert_eq!(stream.protocol(), socket::Protocol::Udp);
56368

57369
Ok(stream)
58370
}
@@ -190,7 +502,7 @@ where
190502
// build the stream inside the application context
191503
let stream = stream.connect()?;
192504

193-
debug_assert_eq!(stream.protocol(), Protocol::Tcp);
505+
debug_assert_eq!(stream.protocol(), socket::Protocol::Tcp);
194506

195507
Ok(stream)
196508
}
@@ -229,7 +541,7 @@ where
229541
// build the stream inside the application context
230542
let stream = stream.connect()?;
231543

232-
debug_assert_eq!(stream.protocol(), Protocol::Tcp);
544+
debug_assert_eq!(stream.protocol(), socket::Protocol::Tcp);
233545

234546
Ok(stream)
235547
}

0 commit comments

Comments
 (0)