Skip to content

Commit

Permalink
Peer.disconnect() now awaits for cancel()
Browse files Browse the repository at this point in the history
Also disconnect from remotes if we get unexpected NodeData or Receipts
msgs during a sync
  • Loading branch information
gsalgado committed Jul 10, 2018
1 parent b73bd9d commit a01347e
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 20 deletions.
9 changes: 6 additions & 3 deletions p2p/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from p2p.cancel_token import CancellableMixin, CancelToken
from p2p.constants import MAX_REORG_DEPTH
from p2p.exceptions import NoEligiblePeers, OperationCancelled
from p2p.p2p_proto import DisconnectReason
from p2p.peer import BasePeer, ETHPeer, LESPeer, PeerPool, PeerPoolSubscriber
from p2p.rlp import BlockBody
from p2p.service import BaseService
Expand Down Expand Up @@ -187,7 +188,7 @@ async def _sync(self, peer: HeaderRequestingPeer) -> None:
headers = await self._fetch_missing_headers(peer, start_at)
except TimeoutError:
self.logger.warn("Timeout waiting for header batch from %s, aborting sync", peer)
await peer.cancel()
await peer.disconnect(DisconnectReason.timeout)
break

if not headers:
Expand Down Expand Up @@ -509,7 +510,8 @@ async def _handle_msg(self, peer: HeaderRequestingPeer, cmd: protocol.Command,
elif isinstance(cmd, eth.NodeData):
# When doing a chain sync we never send GetNodeData requests, so peers should not send
# us NodeData msgs.
self.logger.warn("Unexpected NodeData msg from %s", peer)
self.logger.warn("Unexpected NodeData msg from %s, disconnecting", peer)
await peer.disconnect(DisconnectReason.bad_protocol)
else:
self.logger.debug("%s msg not handled yet, need to be implemented", cmd)

Expand Down Expand Up @@ -606,7 +608,8 @@ async def _handle_msg(self, peer: HeaderRequestingPeer, cmd: protocol.Command,
async def _handle_block_receipts(
self, peer: ETHPeer, receipts_by_block: List[List[eth.Receipt]]) -> None:
# When doing a regular sync we never request receipts.
self.logger.warn("Unexpected BlockReceipts msg from %s", peer)
self.logger.warn("Unexpected BlockReceipts msg from %s, disconnecting", peer)
await peer.disconnect(DisconnectReason.bad_protocol)

async def _process_headers(
self, peer: HeaderRequestingPeer, headers: Tuple[BlockHeader, ...]) -> int:
Expand Down
33 changes: 17 additions & 16 deletions p2p/peer.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ async def do_p2p_handshake(self) -> None:
# Peers sometimes send a disconnect msg before they send the initial P2P handshake.
raise HandshakeFailure("{} disconnected before completing handshake: {}".format(
self, msg['reason_name']))
self.process_p2p_handshake(cmd, msg)
await self.process_p2p_handshake(cmd, msg)

@property
async def genesis(self) -> BlockHeader:
Expand Down Expand Up @@ -393,16 +393,17 @@ def process_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> N
else:
self.handle_sub_proto_msg(cmd, msg)

def process_p2p_handshake(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
async def process_p2p_handshake(
self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
msg = cast(Dict[str, Any], msg)
if not isinstance(cmd, Hello):
self.disconnect(DisconnectReason.bad_protocol)
await self.disconnect(DisconnectReason.bad_protocol)
raise HandshakeFailure("Expected a Hello msg, got {}, disconnecting".format(cmd))
remote_capabilities = msg['capabilities']
try:
self.sub_proto = self.select_sub_protocol(remote_capabilities)
except NoMatchingPeerCapabilities:
self.disconnect(DisconnectReason.useless_peer)
await self.disconnect(DisconnectReason.useless_peer)
raise HandshakeFailure(
"No matching capabilities between us ({}) and {} ({}), disconnecting".format(
self.capabilities, self.remote, remote_capabilities))
Expand Down Expand Up @@ -474,9 +475,11 @@ def send(self, header: bytes, body: bytes) -> None:
self.logger.trace("Sending msg with cmd_id: %s", cmd_id)
self.writer.write(self.encrypt(header, body))

def disconnect(self, reason: DisconnectReason) -> None:
async def disconnect(self, reason: DisconnectReason) -> None:
"""Send a disconnect msg to the remote node and stop this Peer.
Also awaits for self.cancel() to ensure any pending tasks are cleaned up.
:param reason: An item from the DisconnectReason enum.
"""
if not isinstance(reason, DisconnectReason):
Expand All @@ -485,6 +488,8 @@ def disconnect(self, reason: DisconnectReason) -> None:
self.logger.debug("Disconnecting from remote peer; reason: %s", reason.name)
self.base_protocol.send_disconnect(reason.value)
self.close()
if self.is_running:
await self.cancel()

def select_sub_protocol(self, remote_capabilities: List[Tuple[bytes, int]]
) -> protocol.Protocol:
Expand Down Expand Up @@ -537,18 +542,18 @@ async def send_sub_proto_handshake(self) -> None:
async def process_sub_proto_handshake(
self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
if not isinstance(cmd, (les.Status, les.StatusV2)):
self.disconnect(DisconnectReason.subprotocol_error)
await self.disconnect(DisconnectReason.subprotocol_error)
raise HandshakeFailure(
"Expected a LES Status msg, got {}, disconnecting".format(cmd))
msg = cast(Dict[str, Any], msg)
if msg['networkId'] != self.network_id:
self.disconnect(DisconnectReason.useless_peer)
await self.disconnect(DisconnectReason.useless_peer)
raise HandshakeFailure(
"{} network ({}) does not match ours ({}), disconnecting".format(
self, msg['networkId'], self.network_id))
genesis = await self.genesis
if msg['genesisHash'] != genesis.hash:
self.disconnect(DisconnectReason.useless_peer)
await self.disconnect(DisconnectReason.useless_peer)
raise HandshakeFailure(
"{} genesis ({}) does not match ours ({}), disconnecting".format(
self, encode_hex(msg['genesisHash']), genesis.hex_hash))
Expand Down Expand Up @@ -628,18 +633,18 @@ async def send_sub_proto_handshake(self) -> None:
async def process_sub_proto_handshake(
self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
if not isinstance(cmd, eth.Status):
self.disconnect(DisconnectReason.subprotocol_error)
await self.disconnect(DisconnectReason.subprotocol_error)
raise HandshakeFailure(
"Expected a ETH Status msg, got {}, disconnecting".format(cmd))
msg = cast(Dict[str, Any], msg)
if msg['network_id'] != self.network_id:
self.disconnect(DisconnectReason.useless_peer)
await self.disconnect(DisconnectReason.useless_peer)
raise HandshakeFailure(
"{} network ({}) does not match ours ({}), disconnecting".format(
self, msg['network_id'], self.network_id))
genesis = await self.genesis
if msg['genesis_hash'] != genesis.hash:
self.disconnect(DisconnectReason.useless_peer)
await self.disconnect(DisconnectReason.useless_peer)
raise HandshakeFailure(
"{} genesis ({}) does not match ours ({}), disconnecting".format(
self, encode_hex(msg['genesis_hash']), genesis.hex_hash))
Expand Down Expand Up @@ -770,12 +775,8 @@ async def _run(self) -> None:

async def stop_all_peers(self) -> None:
self.logger.info("Stopping all peers ...")

peers = self.connected_nodes.values()
for peer in peers:
peer.disconnect(DisconnectReason.client_quitting)

await asyncio.gather(*[peer.cancel() for peer in peers])
await asyncio.gather(*[peer.disconnect(DisconnectReason.client_quitting) for peer in peers])

async def _cleanup(self) -> None:
await self.stop_all_peers()
Expand Down
2 changes: 1 addition & 1 deletion p2p/sharding_peer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def process_sub_proto_handshake(self,
cmd: Command,
msg: protocol._DecodedMsgType) -> None:
if not isinstance(cmd, Status):
self.disconnect(DisconnectReason.subprotocol_error)
await self.disconnect(DisconnectReason.subprotocol_error)
raise HandshakeFailure("Expected status msg, got {}, disconnecting".format(cmd))

async def _get_headers_at_chain_split(
Expand Down

0 comments on commit a01347e

Please sign in to comment.