Skip to content

Commit

Permalink
Merge commit from fork
Browse files Browse the repository at this point in the history
  • Loading branch information
finnbear authored Sep 2, 2024
1 parent c292a3c commit e01609c
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 9 deletions.
30 changes: 21 additions & 9 deletions quinn-proto/src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ impl Endpoint {

if self.cids_exhausted() {
debug!("refusing connection");
self.index.remove_initial(incoming.orig_dst_cid);
self.index.remove_initial(dst_cid);
return Err(AcceptError {
cause: ConnectionError::CidsExhausted,
response: Some(self.initial_close(
Expand Down Expand Up @@ -602,7 +602,7 @@ impl Endpoint {
.is_err()
{
debug!(packet_number, "failed to authenticate initial packet");
self.index.remove_initial(incoming.orig_dst_cid);
self.index.remove_initial(dst_cid);
return Err(AcceptError {
cause: TransportError::PROTOCOL_VIOLATION("authentication failed").into(),
response: None,
Expand Down Expand Up @@ -651,9 +651,7 @@ impl Endpoint {
transport_config,
remote_address_validated,
);
if dst_cid.len() != 0 {
self.index.insert_initial(dst_cid, ch);
}
self.index.insert_initial(dst_cid, ch);

match conn.handle_first_packet(
now,
Expand Down Expand Up @@ -802,7 +800,7 @@ impl Endpoint {

/// Clean up endpoint data structures associated with an `Incoming`.
fn clean_up_incoming(&mut self, incoming: &Incoming) {
self.index.remove_initial(incoming.orig_dst_cid);
self.index.remove_initial(incoming.packet.header.dst_cid);
let incoming_buffer = self.incoming_buffers.remove(incoming.incoming_idx);
self.all_incoming_buffers_total_bytes -= incoming_buffer.total_bytes;
}
Expand Down Expand Up @@ -864,6 +862,7 @@ impl Endpoint {
cids_issued,
loc_cids,
addresses,
side,
reset_token: None,
});
debug_assert_eq!(id, ch.0, "connection handle allocation out of sync");
Expand Down Expand Up @@ -994,6 +993,8 @@ struct ConnectionIndex {
/// Identifies connections based on the initial DCID the peer utilized
///
/// Uses a standard `HashMap` to protect against hash collision attacks.
///
/// Used by the server, not the client.
connection_ids_initial: HashMap<ConnectionId, RouteDatagramTo>,
/// Identifies connections based on locally created CIDs
///
Expand Down Expand Up @@ -1022,17 +1023,27 @@ struct ConnectionIndex {
impl ConnectionIndex {
/// Associate an incoming connection with its initial destination CID
fn insert_initial_incoming(&mut self, dst_cid: ConnectionId, incoming_key: usize) {
if dst_cid.len() == 0 {
return;
}
self.connection_ids_initial
.insert(dst_cid, RouteDatagramTo::Incoming(incoming_key));
}

/// Remove an association with an initial destination CID
fn remove_initial(&mut self, dst_cid: ConnectionId) {
self.connection_ids_initial.remove(&dst_cid);
if dst_cid.len() == 0 {
return;
}
let removed = self.connection_ids_initial.remove(&dst_cid);
debug_assert!(removed.is_some());
}

/// Associate a connection with its initial destination CID
fn insert_initial(&mut self, dst_cid: ConnectionId, connection: ConnectionHandle) {
if dst_cid.len() == 0 {
return;
}
self.connection_ids_initial
.insert(dst_cid, RouteDatagramTo::Connection(connection));
}
Expand Down Expand Up @@ -1070,8 +1081,8 @@ impl ConnectionIndex {

/// Remove all references to a connection
fn remove(&mut self, conn: &ConnectionMeta) {
if conn.init_cid.len() > 0 {
self.connection_ids_initial.remove(&conn.init_cid);
if conn.side.is_server() {
self.remove_initial(conn.init_cid);
}
for cid in conn.loc_cids.values() {
self.connection_ids.remove(cid);
Expand Down Expand Up @@ -1126,6 +1137,7 @@ pub(crate) struct ConnectionMeta {
/// Only needed to support connections with zero-length CIDs, which cannot migrate, so we don't
/// bother keeping it up to date.
addresses: FourTuple,
side: Side,
/// Reset token provided by the peer for the CID we're currently sending to, and the address
/// being sent to
reset_token: Option<(SocketAddr, ResetToken)>,
Expand Down
26 changes: 26 additions & 0 deletions quinn-proto/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2994,6 +2994,32 @@ fn reject_manually() {
));
}

#[test]
fn validate_then_reject_manually() {
let _guard = subscribe();
let mut pair = Pair::default();
pair.server.incoming_connection_behavior = IncomingConnectionBehavior::ValidateThenReject;

// The server should now retry and reject incoming connections.
let client_ch = pair.begin_connect(client_config());
pair.drive();
pair.server.assert_no_accept();
let client = pair.client.connections.get_mut(&client_ch).unwrap();
assert!(client.is_closed());
assert!(matches!(
client.poll(),
Some(Event::ConnectionLost {
reason: ConnectionError::ConnectionClosed(close)
}) if close.error_code == TransportErrorCode::CONNECTION_REFUSED
));
pair.drive();
assert_matches!(pair.client_conn_mut(client_ch).poll(), None);
assert_eq!(pair.client.known_connections(), 0);
assert_eq!(pair.client.known_cids(), 0);
assert_eq!(pair.server.known_connections(), 0);
assert_eq!(pair.server.known_cids(), 0);
}

#[test]
fn endpoint_and_connection_impl_send_sync() {
const fn is_send_sync<T: Send + Sync>() {}
Expand Down
8 changes: 8 additions & 0 deletions quinn-proto/src/tests/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ pub(super) enum IncomingConnectionBehavior {
AcceptAll,
RejectAll,
Validate,
ValidateThenReject,
Wait,
}

Expand Down Expand Up @@ -377,6 +378,13 @@ impl TestEndpoint {
self.retry(incoming);
}
}
IncomingConnectionBehavior::ValidateThenReject => {
if incoming.remote_address_validated() {
self.reject(incoming);
} else {
self.retry(incoming);
}
}
IncomingConnectionBehavior::Wait => {
self.waiting_incoming.push(incoming);
}
Expand Down

0 comments on commit e01609c

Please sign in to comment.