Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(p2p): Add a maximum number of connections per IP address
Browse files Browse the repository at this point in the history
msbrogli committed Sep 23, 2023

Verified

This commit was signed with the committer’s verified signature.
chenrui333 Rui Chen
1 parent e2d9278 commit 617b327
Showing 5 changed files with 73 additions and 12 deletions.
2 changes: 2 additions & 0 deletions hathor/p2p/factory.py
Original file line number Diff line number Diff line change
@@ -57,6 +57,7 @@ def buildProtocol(self, addr: IAddress) -> MyServerProtocol:
p2p_manager=self.p2p_manager,
use_ssl=self.use_ssl,
inbound=True,
remote_address=addr,
)
p.factory = self
return p
@@ -90,6 +91,7 @@ def buildProtocol(self, addr: IAddress) -> MyClientProtocol:
p2p_manager=self.p2p_manager,
use_ssl=self.use_ssl,
inbound=False,
remote_address=addr,
)
p.factory = self
return p
14 changes: 14 additions & 0 deletions hathor/p2p/manager.py
Original file line number Diff line number Diff line change
@@ -125,6 +125,7 @@ def __init__(self,

# Global maximum number of connections.
self.max_connections: int = settings.PEER_MAX_CONNECTIONS
self.max_connections_per_ip: int = 16

# Global rate limiter for all connections.
self.rate_limiter = RateLimiter(self.reactor)
@@ -314,6 +315,19 @@ def on_peer_connect(self, protocol: HathorProtocol) -> None:
self.log.warn('reached maximum number of connections', max_connections=self.max_connections)
protocol.disconnect(force=True)
return

ip_address = protocol.get_remote_ip_address()
if ip_address:
count = len([1 for conn in self.connections if conn.get_remote_ip_address() == ip_address])
if count >= self.max_connections_per_ip:
self.log.warn(
'reached maximum number of connections per ip address',
ip_address=ip_address,
max_connections_per_ip=self.max_connections_per_ip,
)
protocol.disconnect(force=True)
return

self.connections.add(protocol)
self.handshaking_peers.add(protocol)

14 changes: 10 additions & 4 deletions hathor/p2p/protocol.py
Original file line number Diff line number Diff line change
@@ -18,7 +18,7 @@

from structlog import get_logger
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IDelayedCall, ITCPTransport, ITransport
from twisted.internet.interfaces import IAddress, IDelayedCall, ITCPTransport, ITransport
from twisted.internet.protocol import connectionDone
from twisted.protocols.basic import LineReceiver
from twisted.python.failure import Failure
@@ -92,11 +92,12 @@ class WarningFlags(str, Enum):
capabilities: set[str] # capabilities received from the peer in HelloState

def __init__(self, network: str, my_peer: PeerId, p2p_manager: 'ConnectionsManager',
*, use_ssl: bool, inbound: bool) -> None:
*, use_ssl: bool, inbound: bool, remote_address: IAddress) -> None:
self._settings = get_settings()
self.network = network
self.my_peer = my_peer
self.connections = p2p_manager
self.remote_address = remote_address

assert p2p_manager.manager is not None
self.node = p2p_manager.manager
@@ -181,8 +182,11 @@ def is_state(self, state_enum: PeerState) -> bool:

def get_short_remote(self) -> str:
"""Get remote for logging."""
assert self.transport is not None
return format_address(self.transport.getPeer())
return format_address(self.remote_address)

def get_remote_ip_address(self) -> Optional[str]:
"""Return remote address (ipv4 or ipv6)."""
return getattr(self.remote_address, 'host', None)

def get_peer_id(self) -> Optional[str]:
"""Get peer id for logging."""
@@ -230,6 +234,8 @@ def on_connect(self) -> None:
""" Executed when the connection is established.
"""
assert not self.aborting
assert self.transport is not None
assert self.remote_address == self.transport.getPeer()
self.update_log_context()
self.log.debug('new connection')

29 changes: 21 additions & 8 deletions hathor/simulator/fake_connection.py
Original file line number Diff line number Diff line change
@@ -17,7 +17,8 @@

from OpenSSL.crypto import X509
from structlog import get_logger
from twisted.internet.address import HostnameAddress
from twisted.internet.address import IPv4Address
from twisted.internet.interfaces import IAddress
from twisted.internet.testing import StringTransport

if TYPE_CHECKING:
@@ -28,8 +29,8 @@


class HathorStringTransport(StringTransport):
def __init__(self, peer: 'PeerId'):
super().__init__()
def __init__(self, peer: 'PeerId', hostAddress: IAddress, peerAddress: IAddress) -> None:
super().__init__(hostAddress=hostAddress, peerAddress=peerAddress)
self.peer = peer

def getPeerCertificate(self) -> X509:
@@ -39,7 +40,8 @@ def getPeerCertificate(self) -> X509:

class FakeConnection:
def __init__(self, manager1: 'HathorManager', manager2: 'HathorManager', *, latency: float = 0,
autoreconnect: bool = False):
autoreconnect: bool = False, address1: Optional[IAddress] = None,
address2: Optional[IAddress] = None):
"""
:param: latency: Latency between nodes in seconds
"""
@@ -56,6 +58,9 @@ def __init__(self, manager1: 'HathorManager', manager2: 'HathorManager', *, late
self._buf1: deque[str] = deque()
self._buf2: deque[str] = deque()

self._address1: Optional[IAddress] = address1
self._address2: Optional[IAddress] = address2

self.reconnect()

@property
@@ -140,6 +145,10 @@ def can_step(self) -> bool:
return False

def run_one_step(self, debug=False, force=False):
if self.tr1.disconnected:
return
if self.tr2.disconnected:
return
assert self.is_connected, 'not connected'

if debug:
@@ -218,10 +227,14 @@ def reconnect(self) -> None:
self.disconnect(Failure(Exception('forced reconnection')))
self._buf1.clear()
self._buf2.clear()
self._proto1 = self.manager1.connections.server_factory.buildProtocol(HostnameAddress(b'fake', 0))
self._proto2 = self.manager2.connections.client_factory.buildProtocol(HostnameAddress(b'fake', 0))
self.tr1 = HathorStringTransport(self._proto2.my_peer)
self.tr2 = HathorStringTransport(self._proto1.my_peer)

address1 = self._address1 or IPv4Address('TCP', '192.168.0.14', 1234)
address2 = self._address2 or IPv4Address('TCP', '192.168.0.72', 5432)

self._proto1 = self.manager1.connections.server_factory.buildProtocol(address2)
self._proto2 = self.manager2.connections.client_factory.buildProtocol(address1)
self.tr1 = HathorStringTransport(self._proto2.my_peer, address1, address2)
self.tr2 = HathorStringTransport(self._proto1.my_peer, address2, address1)
self._proto1.makeConnection(self.tr1)
self._proto2.makeConnection(self.tr2)
self.is_connected = True
26 changes: 26 additions & 0 deletions tests/p2p/test_max_conn_per_ip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from twisted.internet.address import IPv4Address

from hathor.simulator import FakeConnection
from tests.simulation.base import SimulatorTestCase


class PeerRelayTestCase(SimulatorTestCase):
__test__ = True

def test_max_conn_per_ip(self) -> None:
m0 = self.create_peer(enable_sync_v1=False, enable_sync_v2=True)

max_connections_per_ip = m0.connections.max_connections_per_ip
for i in range(1, max_connections_per_ip + 8):
m1 = self.create_peer(enable_sync_v1=False, enable_sync_v2=True)

address = IPv4Address('TCP', '127.0.0.1', 1234 + i)
conn = FakeConnection(m0, m1, latency=0.05, address2=address)
self.simulator.add_connection(conn)

self.simulator.run(10)

if i <= max_connections_per_ip:
self.assertFalse(conn.tr1.disconnected)
else:
self.assertTrue(conn.tr1.disconnected)

0 comments on commit 617b327

Please sign in to comment.