Skip to content

Commit

Permalink
feat(p2p): Add a maximum number of connections per IP address
Browse files Browse the repository at this point in the history
  • Loading branch information
msbrogli committed Sep 23, 2023
1 parent e2d9278 commit 617b327
Showing 5 changed files with 73 additions and 12 deletions.
2 changes: 2 additions & 0 deletions hathor/p2p/factory.py
Original file line number Diff line number Diff line change
@@ -57,6 +57,7 @@ def buildProtocol(self, addr: IAddress) -> MyServerProtocol:
p2p_manager=self.p2p_manager,
use_ssl=self.use_ssl,
inbound=True,
remote_address=addr,
)
p.factory = self
return p
@@ -90,6 +91,7 @@ def buildProtocol(self, addr: IAddress) -> MyClientProtocol:
p2p_manager=self.p2p_manager,
use_ssl=self.use_ssl,
inbound=False,
remote_address=addr,
)
p.factory = self
return p
14 changes: 14 additions & 0 deletions hathor/p2p/manager.py
Original file line number Diff line number Diff line change
@@ -125,6 +125,7 @@ def __init__(self,

# Global maximum number of connections.
self.max_connections: int = settings.PEER_MAX_CONNECTIONS
self.max_connections_per_ip: int = 16

# Global rate limiter for all connections.
self.rate_limiter = RateLimiter(self.reactor)
@@ -314,6 +315,19 @@ def on_peer_connect(self, protocol: HathorProtocol) -> None:
self.log.warn('reached maximum number of connections', max_connections=self.max_connections)
protocol.disconnect(force=True)
return

ip_address = protocol.get_remote_ip_address()
if ip_address:
count = len([1 for conn in self.connections if conn.get_remote_ip_address() == ip_address])
if count >= self.max_connections_per_ip:
self.log.warn(
'reached maximum number of connections per ip address',
ip_address=ip_address,
max_connections_per_ip=self.max_connections_per_ip,
)
protocol.disconnect(force=True)
return

self.connections.add(protocol)
self.handshaking_peers.add(protocol)

14 changes: 10 additions & 4 deletions hathor/p2p/protocol.py
Original file line number Diff line number Diff line change
@@ -18,7 +18,7 @@

from structlog import get_logger
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IDelayedCall, ITCPTransport, ITransport
from twisted.internet.interfaces import IAddress, IDelayedCall, ITCPTransport, ITransport
from twisted.internet.protocol import connectionDone
from twisted.protocols.basic import LineReceiver
from twisted.python.failure import Failure
@@ -92,11 +92,12 @@ class WarningFlags(str, Enum):
capabilities: set[str] # capabilities received from the peer in HelloState

def __init__(self, network: str, my_peer: PeerId, p2p_manager: 'ConnectionsManager',
*, use_ssl: bool, inbound: bool) -> None:
*, use_ssl: bool, inbound: bool, remote_address: IAddress) -> None:
self._settings = get_settings()
self.network = network
self.my_peer = my_peer
self.connections = p2p_manager
self.remote_address = remote_address

assert p2p_manager.manager is not None
self.node = p2p_manager.manager
@@ -181,8 +182,11 @@ def is_state(self, state_enum: PeerState) -> bool:

def get_short_remote(self) -> str:
"""Get remote for logging."""
assert self.transport is not None
return format_address(self.transport.getPeer())
return format_address(self.remote_address)

def get_remote_ip_address(self) -> Optional[str]:
"""Return remote address (ipv4 or ipv6)."""
return getattr(self.remote_address, 'host', None)

def get_peer_id(self) -> Optional[str]:
"""Get peer id for logging."""
@@ -230,6 +234,8 @@ def on_connect(self) -> None:
""" Executed when the connection is established.
"""
assert not self.aborting
assert self.transport is not None
assert self.remote_address == self.transport.getPeer()
self.update_log_context()
self.log.debug('new connection')

29 changes: 21 additions & 8 deletions hathor/simulator/fake_connection.py
Original file line number Diff line number Diff line change
@@ -17,7 +17,8 @@

from OpenSSL.crypto import X509
from structlog import get_logger
from twisted.internet.address import HostnameAddress
from twisted.internet.address import IPv4Address
from twisted.internet.interfaces import IAddress
from twisted.internet.testing import StringTransport

if TYPE_CHECKING:
@@ -28,8 +29,8 @@


class HathorStringTransport(StringTransport):
def __init__(self, peer: 'PeerId'):
super().__init__()
def __init__(self, peer: 'PeerId', hostAddress: IAddress, peerAddress: IAddress) -> None:
super().__init__(hostAddress=hostAddress, peerAddress=peerAddress)
self.peer = peer

def getPeerCertificate(self) -> X509:
@@ -39,7 +40,8 @@ def getPeerCertificate(self) -> X509:

class FakeConnection:
def __init__(self, manager1: 'HathorManager', manager2: 'HathorManager', *, latency: float = 0,
autoreconnect: bool = False):
autoreconnect: bool = False, address1: Optional[IAddress] = None,
address2: Optional[IAddress] = None):
"""
:param: latency: Latency between nodes in seconds
"""
@@ -56,6 +58,9 @@ def __init__(self, manager1: 'HathorManager', manager2: 'HathorManager', *, late
self._buf1: deque[str] = deque()
self._buf2: deque[str] = deque()

self._address1: Optional[IAddress] = address1
self._address2: Optional[IAddress] = address2

self.reconnect()

@property
@@ -140,6 +145,10 @@ def can_step(self) -> bool:
return False

def run_one_step(self, debug=False, force=False):
if self.tr1.disconnected:
return
if self.tr2.disconnected:
return
assert self.is_connected, 'not connected'

if debug:
@@ -218,10 +227,14 @@ def reconnect(self) -> None:
self.disconnect(Failure(Exception('forced reconnection')))
self._buf1.clear()
self._buf2.clear()
self._proto1 = self.manager1.connections.server_factory.buildProtocol(HostnameAddress(b'fake', 0))
self._proto2 = self.manager2.connections.client_factory.buildProtocol(HostnameAddress(b'fake', 0))
self.tr1 = HathorStringTransport(self._proto2.my_peer)
self.tr2 = HathorStringTransport(self._proto1.my_peer)

address1 = self._address1 or IPv4Address('TCP', '192.168.0.14', 1234)
address2 = self._address2 or IPv4Address('TCP', '192.168.0.72', 5432)

self._proto1 = self.manager1.connections.server_factory.buildProtocol(address2)
self._proto2 = self.manager2.connections.client_factory.buildProtocol(address1)
self.tr1 = HathorStringTransport(self._proto2.my_peer, address1, address2)
self.tr2 = HathorStringTransport(self._proto1.my_peer, address2, address1)
self._proto1.makeConnection(self.tr1)
self._proto2.makeConnection(self.tr2)
self.is_connected = True
26 changes: 26 additions & 0 deletions tests/p2p/test_max_conn_per_ip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from twisted.internet.address import IPv4Address

from hathor.simulator import FakeConnection
from tests.simulation.base import SimulatorTestCase


class PeerRelayTestCase(SimulatorTestCase):
__test__ = True

def test_max_conn_per_ip(self) -> None:
m0 = self.create_peer(enable_sync_v1=False, enable_sync_v2=True)

max_connections_per_ip = m0.connections.max_connections_per_ip
for i in range(1, max_connections_per_ip + 8):
m1 = self.create_peer(enable_sync_v1=False, enable_sync_v2=True)

address = IPv4Address('TCP', '127.0.0.1', 1234 + i)
conn = FakeConnection(m0, m1, latency=0.05, address2=address)
self.simulator.add_connection(conn)

self.simulator.run(10)

if i <= max_connections_per_ip:
self.assertFalse(conn.tr1.disconnected)
else:
self.assertTrue(conn.tr1.disconnected)

0 comments on commit 617b327

Please sign in to comment.