diff --git a/src/transport/manager/limits.rs b/src/transport/manager/limits.rs index a6a1dfd9..0af49eb1 100644 --- a/src/transport/manager/limits.rs +++ b/src/transport/manager/limits.rs @@ -113,9 +113,10 @@ impl ConnectionLimits { } /// Called when a new connection is established. - pub fn on_connection_established( + /// + /// Returns an error if the connection cannot be accepted due to connection limits. + pub fn can_accept_connection( &mut self, - connection_id: ConnectionId, is_listener: bool, ) -> Result<(), ConnectionLimitsError> { // Check connection limits. @@ -131,7 +132,20 @@ impl ConnectionLimits { } } - // Keep track of the connection. + Ok(()) + } + + /// Accept an established connection. + /// + /// # Note + /// + /// This method should be called after the `Self::can_accept_connection` method + /// to ensure that the connection can be accepted. + pub fn accept_established_connection( + &mut self, + connection_id: ConnectionId, + is_listener: bool, + ) { if is_listener { if self.config.max_incoming_connections.is_some() { self.incoming_connections.insert(connection_id); @@ -139,8 +153,6 @@ impl ConnectionLimits { } else if self.config.max_outgoing_connections.is_some() { self.outgoing_connections.insert(connection_id); } - - Ok(()) } /// Called when a connection is closed. @@ -167,35 +179,39 @@ mod tests { let connection_id_out_1 = ConnectionId::random(); let connection_id_out_2 = ConnectionId::random(); let connection_id_in_3 = ConnectionId::random(); - let connection_id_out_3 = ConnectionId::random(); // Establish incoming connection. - assert!(limits.on_connection_established(connection_id_in_1, true).is_ok()); + assert!(limits.can_accept_connection(true).is_ok()); + limits.accept_established_connection(connection_id_in_1, true); assert_eq!(limits.incoming_connections.len(), 1); - assert!(limits.on_connection_established(connection_id_in_2, true).is_ok()); + assert!(limits.can_accept_connection(true).is_ok()); + limits.accept_established_connection(connection_id_in_2, true); assert_eq!(limits.incoming_connections.len(), 2); - assert!(limits.on_connection_established(connection_id_in_3, true).is_ok()); + assert!(limits.can_accept_connection(true).is_ok()); + limits.accept_established_connection(connection_id_in_3, true); assert_eq!(limits.incoming_connections.len(), 3); assert_eq!( - limits.on_connection_established(ConnectionId::random(), true).unwrap_err(), + limits.can_accept_connection(true).unwrap_err(), ConnectionLimitsError::MaxIncomingConnectionsExceeded ); assert_eq!(limits.incoming_connections.len(), 3); // Establish outgoing connection. - assert!(limits.on_connection_established(connection_id_out_1, false).is_ok()); + assert!(limits.can_accept_connection(false).is_ok()); + limits.accept_established_connection(connection_id_out_1, false); assert_eq!(limits.incoming_connections.len(), 3); assert_eq!(limits.outgoing_connections.len(), 1); - assert!(limits.on_connection_established(connection_id_out_2, false).is_ok()); + assert!(limits.can_accept_connection(false).is_ok()); + limits.accept_established_connection(connection_id_out_2, false); assert_eq!(limits.incoming_connections.len(), 3); assert_eq!(limits.outgoing_connections.len(), 2); assert_eq!( - limits.on_connection_established(connection_id_out_3, false).unwrap_err(), + limits.can_accept_connection(false).unwrap_err(), ConnectionLimitsError::MaxOutgoingConnectionsExceeded ); diff --git a/src/transport/manager/mod.rs b/src/transport/manager/mod.rs index 12c963d7..d7eba036 100644 --- a/src/transport/manager/mod.rs +++ b/src/transport/manager/mod.rs @@ -774,10 +774,7 @@ impl TransportManager { }; // Reject the connection if exceeded limits. - if let Err(error) = self - .connection_limits - .on_connection_established(endpoint.connection_id(), endpoint.is_listener()) - { + if let Err(error) = self.connection_limits.can_accept_connection(endpoint.is_listener()) { tracing::debug!( target: LOG_TARGET, ?peer, @@ -806,6 +803,9 @@ impl TransportManager { ); if connection_accepted { + self.connection_limits + .accept_established_connection(endpoint.connection_id(), endpoint.is_listener()); + // Cancel all pending dials if the connection was established. if let PeerState::Opening { connection_id,