FastChainSyncer now handles data requests from peers #993

Merged: 2 commits, Jul 10, 2018
69 changes: 39 additions & 30 deletions p2p/chain.py
@@ -40,6 +40,7 @@
from p2p.cancel_token import CancellableMixin, CancelToken
from p2p.constants import MAX_REORG_DEPTH
from p2p.exceptions import NoEligiblePeers, OperationCancelled
from p2p.p2p_proto import DisconnectReason
from p2p.peer import BasePeer, ETHPeer, LESPeer, PeerPool, PeerPoolSubscriber
from p2p.rlp import BlockBody
from p2p.service import BaseService
@@ -187,7 +188,7 @@ async def _sync(self, peer: HeaderRequestingPeer) -> None:
headers = await self._fetch_missing_headers(peer, start_at)
except TimeoutError:
self.logger.warn("Timeout waiting for header batch from %s, aborting sync", peer)
await peer.cancel()
await peer.disconnect(DisconnectReason.timeout)
break

if not headers:
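
Note on the hunk above: on a header-request timeout the syncer no longer just cancels the peer; it now sends a proper devp2p Disconnect with the `timeout` reason before dropping it. A minimal runnable sketch of the pattern, where `Peer` and `DisconnectReason` are simplified stand-ins rather than the real trinity classes:

```python
# Minimal sketch of the timeout handling above; Peer and DisconnectReason
# are simplified stand-ins, not the real trinity classes.
import asyncio
import enum
import logging

logger = logging.getLogger("sync")


class DisconnectReason(enum.Enum):
    timeout = 11  # devp2p "timeout on receiving a message" code


class Peer:
    async def request_headers(self, start_at: int) -> list:
        await asyncio.sleep(10)  # simulate a peer that never answers
        return []

    async def disconnect(self, reason: DisconnectReason) -> None:
        logger.debug("Disconnecting; reason: %s", reason.name)


async def sync_once(peer: Peer, start_at: int) -> None:
    try:
        # Bound how long we wait for the header batch.
        await asyncio.wait_for(peer.request_headers(start_at), timeout=0.1)
    except asyncio.TimeoutError:
        logger.warning("Timeout waiting for header batch from %s, aborting sync", peer)
        # New behaviour: tell the peer why before dropping it.
        await peer.disconnect(DisconnectReason.timeout)


asyncio.run(sync_once(Peer(), 0))
```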
@@ -289,6 +290,7 @@ async def _handle_get_block_headers(self, peer: LESPeer, msg: Dict[str, Any]) ->
query = msg['query']
headers = await self._handler.lookup_headers(
query.block_number_or_hash, query.max_headers, query.skip, query.reverse)
self.logger.trace("Replying to %s with %d headers", peer, len(headers))
peer.sub_proto.send_block_headers(headers, buffer_value=0, request_id=msg['request_id'])

async def _process_headers(
@@ -490,8 +492,28 @@ async def _handle_msg(self, peer: HeaderRequestingPeer, cmd: protocol.Command,
await self._handle_new_block(peer, cast(Dict[str, Any], msg))
elif isinstance(cmd, eth.GetBlockHeaders):
await self._handle_get_block_headers(peer, cast(Dict[str, Any], msg))
elif isinstance(cmd, eth.GetBlockBodies):
# Only serve up to eth.MAX_BODIES_FETCH items in every request.
block_hashes = cast(List[Hash32], msg)[:eth.MAX_BODIES_FETCH]
await self._handler.handle_get_block_bodies(peer, block_hashes)
elif isinstance(cmd, eth.GetReceipts):
# Only serve up to eth.MAX_RECEIPTS_FETCH items in every request.
block_hashes = cast(List[Hash32], msg)[:eth.MAX_RECEIPTS_FETCH]
await self._handler.handle_get_receipts(peer, block_hashes)
elif isinstance(cmd, eth.GetNodeData):
# Only serve up to eth.MAX_STATE_FETCH items in every request.
node_hashes = cast(List[Hash32], msg)[:eth.MAX_STATE_FETCH]
await self._handler.handle_get_node_data(peer, node_hashes)
elif isinstance(cmd, eth.Transactions):
# Transactions msgs are handled by our TxPool service.
pass
elif isinstance(cmd, eth.NodeData):
# When doing a chain sync we never send GetNodeData requests, so peers should not send
# us NodeData msgs.
self.logger.warn("Unexpected NodeData msg from %s, disconnecting", peer)
await peer.disconnect(DisconnectReason.bad_protocol)
else:
self.logger.debug("Ignoring %s message from %s", cmd, peer)
self.logger.debug("%s msg not handled yet, need to be implemented", cmd)

async def _handle_new_block(self, peer: ETHPeer, msg: Dict[str, Any]) -> None:
self._sync_requests.put_nowait(peer)
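
The enlarged `_handle_msg` above is the heart of this PR: `FastChainSyncer` now answers `GetBlockBodies`, `GetReceipts` and `GetNodeData`, and caps every reply by slicing the requested hashes to `eth.MAX_BODIES_FETCH` / `eth.MAX_RECEIPTS_FETCH` / `eth.MAX_STATE_FETCH`, so a peer cannot force the node into assembling an unbounded response. A sketch of that dispatch shape (the command classes and limit values below are illustrative stand-ins, not the real definitions in `p2p/eth.py`):

```python
# Sketch of the capped dispatch; the command classes and limit values are
# illustrative stand-ins for the real definitions in p2p/eth.py.
from typing import Any, List, NamedTuple

MAX_BODIES_FETCH = 128    # assumed values, for illustration only
MAX_RECEIPTS_FETCH = 256
MAX_STATE_FETCH = 384


class GetBlockBodies(NamedTuple):
    block_hashes: List[bytes]


class GetReceipts(NamedTuple):
    block_hashes: List[bytes]


class GetNodeData(NamedTuple):
    node_hashes: List[bytes]


async def handle_msg(handler: Any, peer: Any, cmd: Any) -> None:
    if isinstance(cmd, GetBlockBodies):
        # Only serve up to MAX_BODIES_FETCH items in every request.
        await handler.handle_get_block_bodies(
            peer, cmd.block_hashes[:MAX_BODIES_FETCH])
    elif isinstance(cmd, GetReceipts):
        await handler.handle_get_receipts(
            peer, cmd.block_hashes[:MAX_RECEIPTS_FETCH])
    elif isinstance(cmd, GetNodeData):
        await handler.handle_get_node_data(
            peer, cmd.node_hashes[:MAX_STATE_FETCH])
    else:
        pass  # remaining commands are handled elsewhere
```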
@@ -515,7 +537,7 @@ async def _handle_block_receipts(self,

async def _handle_block_bodies(self,
peer: ETHPeer,
bodies: List[eth.BlockBody]) -> None:
bodies: List[BlockBody]) -> None:
self.logger.debug("Got Bodies for %d blocks from %s", len(bodies), peer)
loop = asyncio.get_event_loop()
iterator = map(make_trie_root_and_nodes, [body.transactions for body in bodies])
@@ -542,6 +564,7 @@ async def _handle_get_block_headers(
headers = await self._handler.lookup_headers(
header_request['block_number_or_hash'], header_request['max_headers'],
header_request['skip'], header_request['reverse'])
self.logger.trace("Replying to %s with %d headers", peer, len(headers))
peer.sub_proto.send_block_headers(headers)
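
Both the LES and ETH sides reply to `GetBlockHeaders` through the shared `lookup_headers`, whose query carries a start block, a `max_headers` count, a `skip` between consecutive headers, and a `reverse` flag. A sketch of those query semantics (the async DB accessor is a hypothetical stand-in):

```python
# Sketch of the GetBlockHeaders query semantics; get_header_by_number is
# a hypothetical stand-in for the async chain DB accessor.
from typing import Any, List


async def lookup_headers(
        db: Any, start_block: int, max_headers: int,
        skip: int, reverse: bool) -> List[Any]:
    # `skip` headers are skipped between consecutive results, so the
    # stride is skip + 1; `reverse` walks back toward the genesis block.
    step = -(skip + 1) if reverse else (skip + 1)
    headers = []
    for i in range(max_headers):
        number = start_block + i * step
        if number < 0:
            break
        try:
            headers.append(await db.get_header_by_number(number))
        except KeyError:
            break  # past our head: reply with what we have
    return headers
```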


@@ -553,31 +576,11 @@ class RegularChainSyncer(FastChainSyncer):
"""
_exit_on_sync_complete = False

async def _handle_msg(self, peer: HeaderRequestingPeer, cmd: protocol.Command,
msg: protocol._DecodedMsgType) -> None:
peer = cast(ETHPeer, peer)
if isinstance(cmd, eth.BlockHeaders):
self._handle_block_headers(tuple(cast(Tuple[BlockHeader, ...], msg)))
elif isinstance(cmd, eth.BlockBodies):
await self._handle_block_bodies(peer, list(cast(Tuple[eth.BlockBody], msg)))
elif isinstance(cmd, eth.NewBlock):
await self._handle_new_block(peer, cast(Dict[str, Any], msg))
elif isinstance(cmd, eth.GetBlockHeaders):
await self._handle_get_block_headers(peer, cast(Dict[str, Any], msg))
elif isinstance(cmd, eth.GetBlockBodies):
# Only serve up to eth.MAX_BODIES_FETCH items in every request.
block_hashes = cast(List[Hash32], msg)[:eth.MAX_BODIES_FETCH]
await self._handler.handle_get_block_bodies(peer, cast(List[Hash32], msg))
elif isinstance(cmd, eth.GetReceipts):
# Only serve up to eth.MAX_RECEIPTS_FETCH items in every request.
block_hashes = cast(List[Hash32], msg)[:eth.MAX_RECEIPTS_FETCH]
await self._handler.handle_get_receipts(peer, block_hashes)
elif isinstance(cmd, eth.GetNodeData):
# Only serve up to eth.MAX_STATE_FETCH items in every request.
node_hashes = cast(List[Hash32], msg)[:eth.MAX_STATE_FETCH]
await self._handler.handle_get_node_data(peer, node_hashes)
else:
self.logger.debug("%s msg not handled yet, need to be implemented", cmd)
async def _handle_block_receipts(
self, peer: ETHPeer, receipts_by_block: List[List[eth.Receipt]]) -> None:
# When doing a regular sync we never request receipts.
self.logger.warn("Unexpected BlockReceipts msg from %s, disconnecting", peer)
await peer.disconnect(DisconnectReason.bad_protocol)

async def _process_headers(
self, peer: HeaderRequestingPeer, headers: Tuple[BlockHeader, ...]) -> int:
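
`RegularChainSyncer` replaces its old `_handle_msg` override with a defensive `_handle_block_receipts`: a regular sync never sends `GetReceipts`, so an unsolicited `BlockReceipts` message marks the peer as misbehaving and it is disconnected with `bad_protocol`. Compactly, with stand-in types:

```python
# Compact sketch of the bad_protocol rejection, with stand-in types.
import enum
import logging

logger = logging.getLogger("sync")


class DisconnectReason(enum.Enum):
    bad_protocol = 2  # devp2p "breach of protocol" code


class Peer:
    async def disconnect(self, reason: DisconnectReason) -> None:
        logger.debug("Disconnecting; reason: %s", reason.name)


async def handle_block_receipts(peer: Peer, receipts_by_block: list) -> None:
    # A regular sync never sends GetReceipts, so any BlockReceipts
    # message is unsolicited and the peer is breaking protocol.
    logger.warning("Unexpected BlockReceipts msg from %s, disconnecting", peer)
    await peer.disconnect(DisconnectReason.bad_protocol)
```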
@@ -599,7 +602,7 @@ async def _process_headers(
transactions: List[BaseTransaction] = []
uncles: List[BlockHeader] = []
else:
body = cast(eth.BlockBody, downloaded_parts[_body_key(header)])
body = cast(BlockBody, downloaded_parts[_body_key(header)])
tx_class = block_class.get_transaction_class()
transactions = [tx_class.from_base_transaction(tx)
for tx in body.transactions]
@@ -624,6 +627,7 @@ def __init__(self, db: 'AsyncHeaderDB', logger: TraceLogger, token: CancelToken)
self.cancel_token = token

async def handle_get_block_bodies(self, peer: ETHPeer, block_hashes: List[Hash32]) -> None:
self.logger.trace("%s requested bodies for %d blocks", peer, len(block_hashes))
chaindb = cast('AsyncChainDB', self.db)
bodies = []
for block_hash in block_hashes:
@@ -636,9 +640,11 @@ async def handle_get_block_bodies(self, peer: ETHPeer, block_hashes: List[Hash32
chaindb.coro_get_block_transactions(header, BaseTransactionFields))
uncles = await self.wait(chaindb.coro_get_block_uncles(header.uncles_hash))
bodies.append(BlockBody(transactions, uncles))
self.logger.trace("Replying to %s with %d block bodies", peer, len(bodies))
peer.sub_proto.send_block_bodies(bodies)

async def handle_get_receipts(self, peer: ETHPeer, block_hashes: List[Hash32]) -> None:
self.logger.trace("%s requested receipts for %d blocks", peer, len(block_hashes))
chaindb = cast('AsyncChainDB', self.db)
receipts = []
for block_hash in block_hashes:
@@ -650,9 +656,11 @@ async def handle_get_receipts(self, peer: ETHPeer, block_hashes: List[Hash32]) -
continue
block_receipts = await self.wait(chaindb.coro_get_receipts(header, Receipt))
receipts.append(block_receipts)
self.logger.trace("Replying to %s with receipts for %d blocks", peer, len(receipts))
peer.sub_proto.send_receipts(receipts)

async def handle_get_node_data(self, peer: ETHPeer, node_hashes: List[Hash32]) -> None:
self.logger.trace("%s requested %d trie nodes", peer, len(node_hashes))
chaindb = cast('AsyncChainDB', self.db)
nodes = []
for node_hash in node_hashes:
@@ -662,6 +670,7 @@ async def handle_get_node_data(self, peer: ETHPeer, node_hashes: List[Hash32]) -
self.logger.debug("%s asked for a trie node we don't have: %s", peer, node_hash)
continue
nodes.append(node)
self.logger.trace("Replying to %s with %d trie nodes", peer, len(nodes))
peer.sub_proto.send_node_data(nodes)
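
The three `PeerRequestHandler` methods above share one shape: look up each requested hash locally, log and skip anything we do not have, and reply with whatever was found, so a partially-known request still gets a useful answer. A sketch of that shape (the DB and sub-protocol calls are hypothetical stand-ins for the AsyncChainDB / sub-protocol APIs):

```python
# Sketch of the shared lookup-and-reply shape; db.get_body and
# peer.send_block_bodies are hypothetical stand-ins for the real APIs.
import logging
from typing import Any, List

logger = logging.getLogger("handler")


async def handle_get_block_bodies(db: Any, peer: Any,
                                  block_hashes: List[bytes]) -> None:
    logger.debug("%s requested bodies for %d blocks", peer, len(block_hashes))
    bodies = []
    for block_hash in block_hashes:
        try:
            body = await db.get_body(block_hash)
        except KeyError:
            # An unknown hash is not an error: serve what we do have.
            logger.debug("%s asked for a block we don't have: %s", peer, block_hash)
            continue
        bodies.append(body)
    logger.debug("Replying to %s with %d block bodies", peer, len(bodies))
    peer.send_block_bodies(bodies)
```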

async def lookup_headers(self, block_number_or_hash: Union[int, bytes], max_headers: int,
@@ -731,7 +740,7 @@ async def _generate_available_headers(


class DownloadedBlockPart(NamedTuple):
part: Union[eth.BlockBody, List[Receipt]]
part: Union[BlockBody, List[Receipt]]
unique_key: Union[bytes, Tuple[bytes, bytes]]


33 changes: 17 additions & 16 deletions p2p/peer.py
@@ -276,7 +276,7 @@ async def do_p2p_handshake(self) -> None:
# Peers sometimes send a disconnect msg before they send the initial P2P handshake.
raise HandshakeFailure("{} disconnected before completing handshake: {}".format(
self, msg['reason_name']))
self.process_p2p_handshake(cmd, msg)
await self.process_p2p_handshake(cmd, msg)

@property
async def genesis(self) -> BlockHeader:
@@ -393,16 +393,17 @@ def process_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> N
else:
self.handle_sub_proto_msg(cmd, msg)

def process_p2p_handshake(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
async def process_p2p_handshake(
self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
msg = cast(Dict[str, Any], msg)
if not isinstance(cmd, Hello):
self.disconnect(DisconnectReason.bad_protocol)
await self.disconnect(DisconnectReason.bad_protocol)
raise HandshakeFailure("Expected a Hello msg, got {}, disconnecting".format(cmd))
remote_capabilities = msg['capabilities']
try:
self.sub_proto = self.select_sub_protocol(remote_capabilities)
except NoMatchingPeerCapabilities:
self.disconnect(DisconnectReason.useless_peer)
await self.disconnect(DisconnectReason.useless_peer)
raise HandshakeFailure(
"No matching capabilities between us ({}) and {} ({}), disconnecting".format(
self.capabilities, self.remote, remote_capabilities))
@@ -474,9 +475,11 @@ def send(self, header: bytes, body: bytes) -> None:
self.logger.trace("Sending msg with cmd_id: %s", cmd_id)
self.writer.write(self.encrypt(header, body))

def disconnect(self, reason: DisconnectReason) -> None:
async def disconnect(self, reason: DisconnectReason) -> None:
"""Send a disconnect msg to the remote node and stop this Peer.

Also awaits for self.cancel() to ensure any pending tasks are cleaned up.

:param reason: An item from the DisconnectReason enum.
"""
if not isinstance(reason, DisconnectReason):
@@ -485,6 +488,8 @@ def disconnect(self, reason: DisconnectReason) -> None:
self.logger.debug("Disconnecting from remote peer; reason: %s", reason.name)
self.base_protocol.send_disconnect(reason.value)
self.close()
if self.is_running:
await self.cancel()
Inline review comment: cc @carver as you added a wait_disconnect method in your other PR and I think this is a better pattern.
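
The pattern under discussion: `disconnect()` is now a coroutine that sends the Disconnect message, closes the transport, and then awaits `self.cancel()` so the peer's pending tasks are torn down before the caller proceeds. In isolation (the transport and cancellation machinery are stand-ins for the real Peer internals):

```python
# Isolated sketch of the coroutine disconnect; transport and task
# cancellation are stand-ins for the real Peer internals.
import asyncio
import enum
import logging

logger = logging.getLogger("peer")


class DisconnectReason(enum.Enum):
    client_quitting = 8


class Peer:
    def __init__(self) -> None:
        self.is_running = True

    def send_disconnect(self, reason: DisconnectReason) -> None:
        logger.debug("-> Disconnect(%s)", reason.name)  # stand-in for the wire msg

    def close(self) -> None:
        logger.debug("closing transport")  # stand-in for closing reader/writer

    async def cancel(self) -> None:
        self.is_running = False
        await asyncio.sleep(0)  # stand-in for awaiting pending-task teardown

    async def disconnect(self, reason: DisconnectReason) -> None:
        if not isinstance(reason, DisconnectReason):
            raise ValueError(f"Reason must be a DisconnectReason, got {reason}")
        self.send_disconnect(reason)
        self.close()
        if self.is_running:
            # Awaiting cancel() here is what makes "await peer.disconnect(...)"
            # a complete teardown for callers.
            await self.cancel()


asyncio.run(Peer().disconnect(DisconnectReason.client_quitting))
```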


def select_sub_protocol(self, remote_capabilities: List[Tuple[bytes, int]]
) -> protocol.Protocol:
@@ -537,18 +542,18 @@ async def send_sub_proto_handshake(self) -> None:
async def process_sub_proto_handshake(
self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
if not isinstance(cmd, (les.Status, les.StatusV2)):
self.disconnect(DisconnectReason.subprotocol_error)
await self.disconnect(DisconnectReason.subprotocol_error)
raise HandshakeFailure(
"Expected a LES Status msg, got {}, disconnecting".format(cmd))
msg = cast(Dict[str, Any], msg)
if msg['networkId'] != self.network_id:
self.disconnect(DisconnectReason.useless_peer)
await self.disconnect(DisconnectReason.useless_peer)
raise HandshakeFailure(
"{} network ({}) does not match ours ({}), disconnecting".format(
self, msg['networkId'], self.network_id))
genesis = await self.genesis
if msg['genesisHash'] != genesis.hash:
self.disconnect(DisconnectReason.useless_peer)
await self.disconnect(DisconnectReason.useless_peer)
raise HandshakeFailure(
"{} genesis ({}) does not match ours ({}), disconnecting".format(
self, encode_hex(msg['genesisHash']), genesis.hex_hash))
@@ -628,18 +633,18 @@ async def send_sub_proto_handshake(self) -> None:
async def process_sub_proto_handshake(
self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
if not isinstance(cmd, eth.Status):
self.disconnect(DisconnectReason.subprotocol_error)
await self.disconnect(DisconnectReason.subprotocol_error)
raise HandshakeFailure(
"Expected a ETH Status msg, got {}, disconnecting".format(cmd))
msg = cast(Dict[str, Any], msg)
if msg['network_id'] != self.network_id:
self.disconnect(DisconnectReason.useless_peer)
await self.disconnect(DisconnectReason.useless_peer)
raise HandshakeFailure(
"{} network ({}) does not match ours ({}), disconnecting".format(
self, msg['network_id'], self.network_id))
genesis = await self.genesis
if msg['genesis_hash'] != genesis.hash:
self.disconnect(DisconnectReason.useless_peer)
await self.disconnect(DisconnectReason.useless_peer)
raise HandshakeFailure(
"{} genesis ({}) does not match ours ({}), disconnecting".format(
self, encode_hex(msg['genesis_hash']), genesis.hex_hash))
@@ -770,12 +775,8 @@ async def _run(self) -> None:

async def stop_all_peers(self) -> None:
self.logger.info("Stopping all peers ...")

peers = self.connected_nodes.values()
for peer in peers:
peer.disconnect(DisconnectReason.client_quitting)

await asyncio.gather(*[peer.cancel() for peer in peers])
await asyncio.gather(*[peer.disconnect(DisconnectReason.client_quitting) for peer in peers])

async def _cleanup(self) -> None:
await self.stop_all_peers()
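
Because `disconnect()` now awaits each peer's own `cancel()`, `stop_all_peers` collapses the old two-step loop (send disconnects, then gather the cancels) into a single `asyncio.gather`, notifying and tearing down all peers concurrently. A runnable sketch:

```python
# Runnable sketch of concurrent peer shutdown.
import asyncio
import enum


class DisconnectReason(enum.Enum):
    client_quitting = 8


class Peer:
    async def disconnect(self, reason: DisconnectReason) -> None:
        await asyncio.sleep(0)  # stand-in for send_disconnect + cancel()


async def stop_all_peers(peers: list) -> None:
    # disconnect() awaits each peer's own cancel(), so one gather() both
    # notifies and tears down every peer, concurrently.
    await asyncio.gather(
        *[peer.disconnect(DisconnectReason.client_quitting) for peer in peers])


asyncio.run(stop_all_peers([Peer(), Peer()]))
```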
2 changes: 1 addition & 1 deletion p2p/sharding_peer.py
@@ -67,7 +67,7 @@ async def process_sub_proto_handshake(self,
cmd: Command,
msg: protocol._DecodedMsgType) -> None:
if not isinstance(cmd, Status):
self.disconnect(DisconnectReason.subprotocol_error)
await self.disconnect(DisconnectReason.subprotocol_error)
raise HandshakeFailure("Expected status msg, got {}, disconnecting".format(cmd))

async def _get_headers_at_chain_split(
12 changes: 9 additions & 3 deletions p2p/state.py
@@ -129,11 +129,17 @@ async def _handle_msg(
elif isinstance(cmd, eth.GetBlockHeaders):
await self._handle_get_block_headers(peer, cast(Dict[str, Any], msg))
elif isinstance(cmd, eth.GetBlockBodies):
await self._handler.handle_get_block_bodies(peer, cast(List[Hash32], msg))
# Only serve up to eth.MAX_BODIES_FETCH items in every request.
block_hashes = cast(List[Hash32], msg)[:eth.MAX_BODIES_FETCH]
await self._handler.handle_get_block_bodies(peer, block_hashes)
elif isinstance(cmd, eth.GetReceipts):
await self._handler.handle_get_receipts(peer, cast(List[Hash32], msg))
# Only serve up to eth.MAX_RECEIPTS_FETCH items in every request.
block_hashes = cast(List[Hash32], msg)[:eth.MAX_RECEIPTS_FETCH]
await self._handler.handle_get_receipts(peer, block_hashes)
elif isinstance(cmd, eth.GetNodeData):
await self._handler.handle_get_node_data(peer, cast(List[Hash32], msg))
# Only serve up to eth.MAX_STATE_FETCH items in every request.
node_hashes = cast(List[Hash32], msg)[:eth.MAX_STATE_FETCH]
await self._handler.handle_get_node_data(peer, node_hashes)
else:
self.logger.warn("%s not handled during StateSync, must be implemented", cmd)

4 changes: 2 additions & 2 deletions trinity/plugins/builtin/tx_pool/pool.py
@@ -73,7 +73,7 @@ async def _run(self) -> None:

async def _handle_tx(self, peer: ETHPeer, txs: List[BaseTransactionFields]) -> None:

self.logger.debug('Received transactions from %r: %r', peer, txs)
self.logger.trace('Received transactions from %r: %r', peer, txs)

self._add_txs_to_bloom(peer, txs)

@@ -87,7 +87,7 @@ async def _handle_tx(self, peer: ETHPeer, txs: List[BaseTransactionFields]) -> N
if len(filtered_tx) == 0:
continue

self.logger.debug(
self.logger.trace(
'Sending transactions to %r: %r',
receiving_peer,
filtered_tx
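
The tx_pool hunks demote per-transaction chatter from `debug` to `trace`. `logger.trace` is not a stdlib logging level; trinity's `TraceLogger` provides one below `DEBUG` so this noise can be silenced independently of normal debug output. A hedged sketch of registering such a level (this mirrors the idea, not trinity's actual implementation):

```python
# Hedged sketch of a TRACE level below DEBUG; mirrors the idea behind
# trinity's TraceLogger, not its actual implementation.
import logging

TRACE = 5  # below logging.DEBUG (10)
logging.addLevelName(TRACE, "TRACE")


def _trace(self: logging.Logger, msg: str, *args: object, **kwargs: object) -> None:
    if self.isEnabledFor(TRACE):
        self._log(TRACE, msg, args, **kwargs)


logging.Logger.trace = _trace  # type: ignore[attr-defined]

logging.basicConfig(level=TRACE)
logger = logging.getLogger("txpool")
logger.trace("Received transactions from %r: %r", "peer", ["tx1"])
```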