Skip to content

Commit 6539908

Browse files
authored
refactor(s2n-quic-dc): configurable server_name option (#2775)
1 parent fa30e8a commit 6539908

File tree

4 files changed

+92
-27
lines changed

4 files changed

+92
-27
lines changed

dc/s2n-quic-dc/src/psk/client.rs

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use super::io::{self, HandshakeFailed};
55
use crate::path::secret;
66
use s2n_quic::{
77
provider::{event::Subscriber as Sub, tls::Provider as Prov},
8+
server::Name,
89
Connection,
910
};
1011
use std::{net::SocketAddr, sync::Arc, time::Duration};
@@ -101,6 +102,7 @@ impl Provider {
101102
subscriber: Subscriber,
102103
query_event_callback: fn(&mut Connection, Duration),
103104
builder: Builder<Event>,
105+
server_name: Name,
104106
) -> io::Result<Self> {
105107
let state = State::new_runtime(
106108
addr,
@@ -117,11 +119,13 @@ impl Provider {
117119
if let Some(state) = weak.upgrade() {
118120
let runtime = state.runtime.as_ref().map(|v| &v.0).unwrap();
119121
let client = state.client.clone();
122+
let server_name = server_name.clone();
120123
// Drop the JoinHandle -- we're not actually going to block on the join handle's
121124
// result. The future will keep running in the background.
122125
runtime.spawn(async move {
123-
if let Err(HandshakeFailed { .. }) =
124-
client.connect(peer, query_event_callback).await
126+
if let Err(HandshakeFailed { .. }) = client
127+
.connect(peer, query_event_callback, server_name)
128+
.await
125129
{
126130
// failure has already been logged, no further action required.
127131
}
@@ -140,9 +144,10 @@ impl Provider {
140144
&self,
141145
peer: SocketAddr,
142146
query_event_callback: fn(&mut Connection, Duration),
147+
server_name: Name,
143148
) -> std::io::Result<HandshakeKind> {
144149
let (_peer, kind) = self
145-
.handshake_with_entry(peer, query_event_callback)
150+
.handshake_with_entry(peer, query_event_callback, server_name)
146151
.await?;
147152
Ok(kind)
148153
}
@@ -156,11 +161,12 @@ impl Provider {
156161
&self,
157162
peer: SocketAddr,
158163
query_event_callback: fn(&mut Connection, Duration),
164+
server_name: Name,
159165
) -> std::io::Result<(secret::map::Peer, HandshakeKind)> {
160166
// Unconditionally request a background handshake. This schedules any re-handshaking
161167
// needed.
162168
if self.state.runtime.is_some() {
163-
let _ = self.background_handshake_with(peer, query_event_callback);
169+
let _ = self.background_handshake_with(peer, query_event_callback, server_name.clone());
164170
}
165171

166172
if let Some(peer) = self.state.map.get_tracked(peer) {
@@ -170,10 +176,18 @@ impl Provider {
170176
let state = self.state.clone();
171177
if let Some((runtime, _)) = self.state.runtime.as_ref() {
172178
runtime
173-
.spawn(async move { state.client.connect(peer, query_event_callback).await })
179+
.spawn(async move {
180+
state
181+
.client
182+
.connect(peer, query_event_callback, server_name)
183+
.await
184+
})
174185
.await??;
175186
} else {
176-
state.client.connect(peer, query_event_callback).await?;
187+
state
188+
.client
189+
.connect(peer, query_event_callback, server_name)
190+
.await?;
177191
}
178192

179193
// already recorded a metric above in get_tracked.
@@ -193,18 +207,21 @@ impl Provider {
193207
&self,
194208
peer: SocketAddr,
195209
query_event_callback: fn(&mut Connection, Duration),
210+
server_name: Name,
196211
) -> std::io::Result<HandshakeKind> {
197212
if self.state.map.contains(&peer) {
198213
return Ok(HandshakeKind::Cached);
199214
}
200215

201216
let client = self.state.client.clone();
202217
if let Some((runtime, _)) = self.state.runtime.as_ref() {
218+
let server_name = server_name.clone();
203219
// Drop the JoinHandle -- we're not actually going to block on the join handle's
204220
// result. The future will keep running in the background.
205221
runtime.spawn(async move {
206-
if let Err(HandshakeFailed { .. }) =
207-
client.connect(peer, query_event_callback).await
222+
if let Err(HandshakeFailed { .. }) = client
223+
.connect(peer, query_event_callback, server_name)
224+
.await
208225
{
209226
// error already logged
210227
}
@@ -229,18 +246,22 @@ impl Provider {
229246
&self,
230247
peer: SocketAddr,
231248
query_event_callback: fn(&mut Connection, Duration),
249+
server_name: Name,
232250
) -> std::io::Result<HandshakeKind> {
233251
// Unconditionally request a background handshake. This schedules any re-handshaking
234252
// needed.
235253
if self.state.runtime.is_some() {
236-
let _ = self.background_handshake_with(peer, query_event_callback);
254+
let _ = self.background_handshake_with(peer, query_event_callback, server_name.clone());
237255
}
238256

239257
if self.state.map.contains(&peer) {
240258
return Ok(HandshakeKind::Cached);
241259
}
242260

243-
let fut = self.state.client.connect(peer, query_event_callback);
261+
let fut = self
262+
.state
263+
.client
264+
.connect(peer, query_event_callback, server_name);
244265
if let Some((runtime, _)) = self.state.runtime.as_ref() {
245266
runtime.block_on(fut)?
246267
} else {
@@ -260,11 +281,17 @@ impl Provider {
260281
&self,
261282
peer: SocketAddr,
262283
query_event_callback: fn(&mut Connection, Duration),
284+
server_name: Name,
263285
) -> std::io::Result<secret::map::Peer> {
264286
let state = self.state.clone();
265287
if let Some((runtime, _)) = self.state.runtime.as_ref() {
266288
runtime
267-
.spawn(async move { state.client.connect(peer, query_event_callback).await })
289+
.spawn(async move {
290+
state
291+
.client
292+
.connect(peer, query_event_callback, server_name)
293+
.await
294+
})
268295
.await??;
269296
} else {
270297
return Err(std::io::Error::new(

dc/s2n-quic-dc/src/psk/client/builder.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use crate::{
99
};
1010
use s2n_quic::{
1111
provider::{event::Subscriber as Sub, tls::Provider as Prov},
12+
server::Name,
1213
Connection,
1314
};
1415
use std::{net::SocketAddr, time::Duration};
@@ -102,6 +103,7 @@ impl<Event: s2n_quic::provider::event::Subscriber> Builder<Event> {
102103
tls_materials_provider: TlsProvider,
103104
subscriber: Subscriber,
104105
query_event_callback: fn(&mut Connection, Duration),
106+
server_name: Name,
105107
) -> Result<Provider> {
106108
Provider::new::<TlsProvider, Subscriber, Event>(
107109
addr,
@@ -110,6 +112,7 @@ impl<Event: s2n_quic::provider::event::Subscriber> Builder<Event> {
110112
subscriber,
111113
query_event_callback,
112114
self,
115+
server_name,
113116
)
114117
}
115118
}

dc/s2n-quic-dc/src/psk/io.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use s2n_quic::{
1010
event::Subscriber as Sub,
1111
tls::Provider as Prov,
1212
},
13+
server::Name,
1314
Connection,
1415
};
1516
use std::{
@@ -209,10 +210,11 @@ impl Client {
209210
&self,
210211
peer: SocketAddr,
211212
query_event_callback: fn(&mut Connection, Duration),
213+
server_name: Name,
212214
) -> Result<(), HandshakeFailed> {
213215
self.queue
214216
.clone()
215-
.handshake(&self.client, peer, query_event_callback)
217+
.handshake(&self.client, peer, query_event_callback, server_name)
216218
.await
217219
}
218220
}
@@ -314,6 +316,7 @@ impl HandshakeQueue {
314316
client: &s2n_quic::Client,
315317
peer: SocketAddr,
316318
query_event_callback: fn(&mut Connection, Duration),
319+
server_name: Name,
317320
) -> Result<(), HandshakeFailed> {
318321
let entry = self.allocate_entry(peer);
319322
let entry2 = entry.clone();
@@ -328,7 +331,7 @@ impl HandshakeQueue {
328331
let limiter_duration = start.elapsed();
329332

330333
let mut connection = client
331-
.connect(s2n_quic::client::Connect::new(peer).with_server_name("anyhostname"))
334+
.connect(s2n_quic::client::Connect::new(peer).with_server_name(server_name))
332335
.await?;
333336

334337
query_event_callback(&mut connection, limiter_duration);

dc/s2n-quic-dc/src/stream/client/tokio.rs

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ use crate::{
1717
recv, socket,
1818
},
1919
};
20+
use s2n_quic::server::Name;
2021
use s2n_quic_core::time::Clock;
2122
use std::{io, net::SocketAddr, time::Duration};
2223
use tokio::net::TcpStream;
@@ -33,6 +34,7 @@ pub trait Handshake: Clone {
3334
async fn handshake_with_entry(
3435
&self,
3536
remote_handshake_addr: SocketAddr,
37+
server_name: Name,
3638
) -> std::io::Result<(secret::map::Peer, secret::HandshakeKind)>;
3739

3840
fn local_addr(&self) -> std::io::Result<SocketAddr>;
@@ -44,8 +46,9 @@ impl Handshake for crate::psk::client::Provider {
4446
async fn handshake_with_entry(
4547
&self,
4648
remote_handshake_addr: SocketAddr,
49+
server_name: Name,
4750
) -> std::io::Result<(secret::map::Peer, secret::HandshakeKind)> {
48-
self.handshake_with_entry(remote_handshake_addr, |_conn, _duration| {})
51+
self.handshake_with_entry(remote_handshake_addr, |_conn, _duration| {}, server_name)
4952
.await
5053
}
5154

@@ -89,10 +92,11 @@ impl<H: Handshake + Clone, S: event::Subscriber + Clone> Client<H, S> {
8992
pub async fn handshake_with(
9093
&self,
9194
remote_handshake_addr: SocketAddr,
95+
server_name: Name,
9296
) -> io::Result<secret::HandshakeKind> {
9397
let (_peer, kind) = self
9498
.handshake
95-
.handshake_with_entry(remote_handshake_addr)
99+
.handshake_with_entry(remote_handshake_addr, server_name)
96100
.await?;
97101
Ok(kind)
98102
}
@@ -101,10 +105,11 @@ impl<H: Handshake + Clone, S: event::Subscriber + Clone> Client<H, S> {
101105
async fn handshake_for_connect(
102106
&self,
103107
remote_handshake_addr: SocketAddr,
108+
server_name: Name,
104109
) -> io::Result<secret::map::Peer> {
105110
let (peer, _kind) = self
106111
.handshake
107-
.handshake_with_entry(remote_handshake_addr)
112+
.handshake_with_entry(remote_handshake_addr, server_name)
108113
.await?;
109114
Ok(peer)
110115
}
@@ -115,10 +120,17 @@ impl<H: Handshake + Clone, S: event::Subscriber + Clone> Client<H, S> {
115120
&self,
116121
handshake_addr: SocketAddr,
117122
acceptor_addr: SocketAddr,
123+
server_name: Name,
118124
) -> io::Result<Stream<S>> {
119125
match self.default_protocol {
120-
socket::Protocol::Udp => self.connect_udp(handshake_addr, acceptor_addr).await,
121-
socket::Protocol::Tcp => self.connect_tcp(handshake_addr, acceptor_addr).await,
126+
socket::Protocol::Udp => {
127+
self.connect_udp(handshake_addr, acceptor_addr, server_name)
128+
.await
129+
}
130+
socket::Protocol::Tcp => {
131+
self.connect_tcp(handshake_addr, acceptor_addr, server_name)
132+
.await
133+
}
122134
protocol => Err(io::Error::new(
123135
io::ErrorKind::InvalidInput,
124136
format!("invalid default protocol {protocol:?}"),
@@ -133,19 +145,32 @@ impl<H: Handshake + Clone, S: event::Subscriber + Clone> Client<H, S> {
133145
acceptor_addr: SocketAddr,
134146
request: Req,
135147
response: Res,
148+
server_name: Name,
136149
) -> io::Result<Res::Output>
137150
where
138151
Req: rpc::Request,
139152
Res: rpc::Response,
140153
{
141154
match self.default_protocol {
142155
socket::Protocol::Udp => {
143-
self.rpc_udp(handshake_addr, acceptor_addr, request, response)
144-
.await
156+
self.rpc_udp(
157+
handshake_addr,
158+
acceptor_addr,
159+
request,
160+
response,
161+
server_name,
162+
)
163+
.await
145164
}
146165
socket::Protocol::Tcp => {
147-
self.rpc_tcp(handshake_addr, acceptor_addr, request, response)
148-
.await
166+
self.rpc_tcp(
167+
handshake_addr,
168+
acceptor_addr,
169+
request,
170+
response,
171+
server_name,
172+
)
173+
.await
149174
}
150175
protocol => Err(io::Error::new(
151176
io::ErrorKind::InvalidInput,
@@ -160,9 +185,10 @@ impl<H: Handshake + Clone, S: event::Subscriber + Clone> Client<H, S> {
160185
&self,
161186
handshake_addr: SocketAddr,
162187
acceptor_addr: SocketAddr,
188+
server_name: Name,
163189
) -> io::Result<Stream<S>> {
164190
// ensure we have a secret for the peer
165-
let handshake = self.handshake_for_connect(handshake_addr);
191+
let handshake = self.handshake_for_connect(handshake_addr, server_name);
166192

167193
let mut stream = client::connect_udp(handshake, acceptor_addr, &self.env).await?;
168194
Self::write_prelude(&mut stream).await?;
@@ -177,13 +203,14 @@ impl<H: Handshake + Clone, S: event::Subscriber + Clone> Client<H, S> {
177203
acceptor_addr: SocketAddr,
178204
request: Req,
179205
response: Res,
206+
server_name: Name,
180207
) -> io::Result<Res::Output>
181208
where
182209
Req: rpc::Request,
183210
Res: rpc::Response,
184211
{
185212
// ensure we have a secret for the peer
186-
let handshake = self.handshake_for_connect(handshake_addr);
213+
let handshake = self.handshake_for_connect(handshake_addr, server_name);
187214

188215
let stream = client::connect_udp(handshake, acceptor_addr, &self.env).await?;
189216
rpc_internal::from_stream(stream, request, response).await
@@ -195,9 +222,10 @@ impl<H: Handshake + Clone, S: event::Subscriber + Clone> Client<H, S> {
195222
&self,
196223
handshake_addr: SocketAddr,
197224
acceptor_addr: SocketAddr,
225+
server_name: Name,
198226
) -> io::Result<Stream<S>> {
199227
// ensure we have a secret for the peer
200-
let handshake = self.handshake_for_connect(handshake_addr);
228+
let handshake = self.handshake_for_connect(handshake_addr, server_name);
201229

202230
let mut stream =
203231
client::connect_tcp(handshake, acceptor_addr, &self.env, self.linger).await?;
@@ -213,13 +241,14 @@ impl<H: Handshake + Clone, S: event::Subscriber + Clone> Client<H, S> {
213241
acceptor_addr: SocketAddr,
214242
request: Req,
215243
response: Res,
244+
server_name: Name,
216245
) -> io::Result<Res::Output>
217246
where
218247
Req: rpc::Request,
219248
Res: rpc::Response,
220249
{
221250
// ensure we have a secret for the peer
222-
let handshake = self.handshake_for_connect(handshake_addr);
251+
let handshake = self.handshake_for_connect(handshake_addr, server_name);
223252

224253
let stream = client::connect_tcp(handshake, acceptor_addr, &self.env, self.linger).await?;
225254
rpc_internal::from_stream(stream, request, response).await
@@ -231,9 +260,12 @@ impl<H: Handshake + Clone, S: event::Subscriber + Clone> Client<H, S> {
231260
&self,
232261
handshake_addr: SocketAddr,
233262
stream: TcpStream,
263+
server_name: Name,
234264
) -> io::Result<Stream<S>> {
235265
// ensure we have a secret for the peer
236-
let handshake = self.handshake_for_connect(handshake_addr).await?;
266+
let handshake = self
267+
.handshake_for_connect(handshake_addr, server_name)
268+
.await?;
237269

238270
let mut stream = client::connect_tcp_with(handshake, stream, &self.env).await?;
239271
Self::write_prelude(&mut stream).await?;

0 commit comments

Comments
 (0)