Skip to content

Commit

Permalink
Implement round trip awaitable API for retrieving block headers
Browse files Browse the repository at this point in the history
  • Loading branch information
pipermerriam committed Jul 27, 2018
1 parent ab94167 commit 5f6bb64
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 36 deletions.
31 changes: 5 additions & 26 deletions p2p/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from p2p import les
from p2p.cancellable import CancellableMixin
from p2p.constants import MAX_REORG_DEPTH, SEAL_CHECK_RANDOM_SAMPLE_RATE
from p2p.exceptions import NoEligiblePeers, OperationCancelled
from p2p.exceptions import NoEligiblePeers, OperationCancelled, ValidationError
from p2p.p2p_proto import DisconnectReason
from p2p.peer import BasePeer, ETHPeer, LESPeer, HeaderRequest, PeerPool, PeerSubscriber
from p2p.rlp import BlockBody
Expand Down Expand Up @@ -91,7 +91,6 @@ def __init__(self,
self._syncing = False
self._sync_complete = asyncio.Event()
self._sync_requests: asyncio.Queue[HeaderRequestingPeer] = asyncio.Queue()
self._new_headers: asyncio.Queue[Tuple[BlockHeader, ...]] = asyncio.Queue()
self._executor = get_asyncio_executor()

@property
Expand Down Expand Up @@ -207,7 +206,7 @@ async def _sync(self, peer: HeaderRequestingPeer) -> None:
self.logger.warn("Timeout waiting for header batch from %s, aborting sync", peer)
await peer.disconnect(DisconnectReason.timeout)
break
except ValueError as err:
except ValidationError as err:
self.logger.warn(
"Invalid header response sent by peer %s disconnecting: %s",
peer, err,
Expand Down Expand Up @@ -253,24 +252,14 @@ async def _fetch_missing_headers(
self, peer: HeaderRequestingPeer, start_at: int) -> Tuple[BlockHeader, ...]:
"""Fetch a batch of headers starting at start_at and return the ones we're missing."""
self.logger.debug("Fetching chain segment starting at #%d", start_at)
request = peer.request_block_headers(

headers = peer.get_block_headers(
start_at,
peer.max_headers_fetch,
skip=0,
reverse=False,
)

# Pass the peer's token to self.wait() because we want to abort if either we
# or the peer terminates.
headers = tuple(await self.wait(
self._new_headers.get(),
token=peer.cancel_token,
timeout=self._reply_timeout))

# check that the response headers are a valid match for our
# requested headers.
request.validate_headers(headers)

# the inner list comprehension is required to get python to evaluate
# the asynchronous comprehension
missing_headers = tuple([
Expand All @@ -287,14 +276,6 @@ async def _fetch_missing_headers(
)
return headers

def _handle_block_headers(self, headers: Tuple[BlockHeader, ...]) -> None:
if not headers:
self.logger.warn("Got an empty BlockHeaders msg")
return
self.logger.debug(
"Got BlockHeaders from %d to %d", headers[0].block_number, headers[-1].block_number)
self._new_headers.put_nowait(headers)

@abstractmethod
async def _handle_msg(self, peer: HeaderRequestingPeer, cmd: protocol.Command,
msg: protocol._DecodedMsgType) -> None:
Expand Down Expand Up @@ -538,9 +519,7 @@ def request_receipts(self, target_td: int, headers: List[BlockHeader]) -> int:
async def _handle_msg(self, peer: HeaderRequestingPeer, cmd: protocol.Command,
msg: protocol._DecodedMsgType) -> None:
peer = cast(ETHPeer, peer)
if isinstance(cmd, eth.BlockHeaders):
self._handle_block_headers(tuple(cast(Tuple[BlockHeader, ...], msg)))
elif isinstance(cmd, eth.BlockBodies):
if isinstance(cmd, eth.BlockBodies):
await self._handle_block_bodies(peer, list(cast(Tuple[BlockBody], msg)))
elif isinstance(cmd, eth.Receipts):
await self._handle_block_receipts(peer, cast(List[List[Receipt]], msg))
Expand Down
7 changes: 7 additions & 0 deletions p2p/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,10 @@ class NoInternalAddressMatchesDevice(BaseP2PError):
def __init__(self, *args: Any, device_hostname: str=None, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.device_hostname = device_hostname


class ValidationError(BaseP2PError):
"""
Raised when something does not pass a validation check.
"""
pass
80 changes: 70 additions & 10 deletions p2p/peer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
Dict,
Iterator,
List,
NamedTuple,
TYPE_CHECKING,
Tuple,
Type,
Expand Down Expand Up @@ -60,8 +59,10 @@

from p2p import auth
from p2p import ecies
from p2p.kademlia import Address, Node
from p2p import eth
from p2p import les
from p2p import protocol
from p2p.kademlia import Address, Node
from p2p.exceptions import (
BadAckMessage,
DAOForkCheckFailure,
Expand All @@ -76,6 +77,7 @@
UnexpectedMessage,
UnknownProtocolCommand,
UnreachablePeer,
ValidationError,
)
from p2p.service import BaseService
from p2p.utils import (
Expand All @@ -85,8 +87,6 @@
sxor,
time_since,
)
from p2p import eth
from p2p import les
from p2p.p2p_proto import (
Disconnect,
DisconnectReason,
Expand Down Expand Up @@ -144,12 +144,39 @@ async def handshake(remote: Node,
return peer


class HeaderRequest(NamedTuple):
class BaseRequest(ABC):
@abstractmethod
def validate_response(self, response: Any) -> None:
pass


class HeaderRequest(BaseRequest):
block_number_or_hash: BlockIdentifier
max_headers: int
skip: int
reverse: bool

def __init__(self,
block_number_or_hash: BlockIdentifier,
max_headers: int,
skip: int,
reverse: bool) -> None:
self.block_number_or_hash = block_number_or_hash
self.max_headers = max_headers
self.skip = skip
self.reverse = reverse

def validate_response(self, response: Any) -> None:
"""
Core `Request` API used for validation.
"""
if not isinstance(response, tuple):
raise ValidationError("Response to `HeaderRequest` must be a tuple")
elif not all(isinstance(item, BlockHeader) for item in response):
raise ValidationError("Response must be a tuple of `BlockHeader` objects")

return self.validate_headers(cast(Tuple[BlockHeader, ...], response))

def generate_block_numbers(self,
block_number: BlockNumber=None) -> Tuple[BlockNumber, ...]:
if block_number is None and not self.is_numbered:
Expand Down Expand Up @@ -188,7 +215,7 @@ def validate_headers(self,
elif not self.is_numbered:
first_header = headers[0]
if first_header.hash != self.block_number_or_hash:
raise ValueError(
raise ValidationError(
"Returned headers cannot be matched to header request. "
"Expected first header to have hash of {0} but instead got "
"{1}.".format(
Expand All @@ -213,7 +240,7 @@ def validate_sequence(self, block_numbers: Tuple[BlockNumber, ...]) -> None:
# check for numbers that should not be present.
unexpected_numbers = set(block_numbers).difference(expected_numbers)
if unexpected_numbers:
raise ValueError(
raise ValidationError(
'Unexpected numbers: {0}'.format(unexpected_numbers))

# check that the numbers are correctly ordered.
Expand All @@ -222,7 +249,7 @@ def validate_sequence(self, block_numbers: Tuple[BlockNumber, ...]) -> None:
reverse=self.reverse,
))
if block_numbers != expected_order:
raise ValueError(
raise ValidationError(
'Returned headers are not correctly ordered.\n'
'Expected: {0}\n'
'Got : {1}\n'.format(expected_order, block_numbers)
Expand All @@ -236,7 +263,7 @@ def validate_sequence(self, block_numbers: Tuple[BlockNumber, ...]) -> None:
if value == number:
break
else:
raise ValueError(
raise ValidationError(
'Returned headers contain an unexpected block number.\n'
'Unexpected Number: {0}\n'
'Expected Numbers : {1}'.format(number, expected_numbers)
Expand All @@ -255,6 +282,11 @@ class BasePeer(BaseService):
head_td: int = None
head_hash: Hash32 = None

# TODO: Instead of a fixed timeout, we should instead monitor response
# times for the peer and adjust our timeout accordingly
_response_timeout = 60
pending_requests: Dict[Type[protocol.Command], asyncio.Future]

def __init__(self,
remote: Node,
privkey: datatypes.PrivateKey,
Expand All @@ -281,6 +313,8 @@ def __init__(self,
self.start_time = datetime.datetime.now()
self.received_msgs: Dict[protocol.Command, int] = collections.defaultdict(int)

self.pending_requests = {}

self.egress_mac = egress_mac
self.ingress_mac = ingress_mac
# FIXME: Yes, the encryption is insecure, see: https://github.com/ethereum/devp2p/issues/32
Expand Down Expand Up @@ -686,6 +720,8 @@ def max_headers_fetch(self) -> int:
return eth.MAX_HEADERS_FETCH

def handle_sub_proto_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgType) -> None:
cmd_type = type(cmd)

if isinstance(cmd, eth.NewBlock):
msg = cast(Dict[str, Any], msg)
header, _, _ = msg['block']
Expand All @@ -694,6 +730,16 @@ def handle_sub_proto_msg(self, cmd: protocol.Command, msg: protocol._DecodedMsgT
if actual_td > self.head_td:
self.head_hash = actual_head
self.head_td = actual_td
elif cmd_type in self.pending_requests:
request, future = self.pending_requests[cmd_type]
try:
request.validate_response(msg)
except ValidationError:
pass
else:
future.set_result(msg)
self.pending_requests.pop(cmd_type)

super().handle_sub_proto_msg(cmd, msg)

async def send_sub_proto_handshake(self) -> None:
Expand Down Expand Up @@ -741,6 +787,20 @@ def request_block_headers(self,
)
return request

async def wait_for_block_headers(self, request: HeaderRequest) -> Tuple[BlockHeader, ...]:
future = asyncio.Future()
self.pending_requests[eth.BlockHeaders] = (request, future)
response = self.wait(future, timeout=self._response_timeout)
return response

async def get_block_headers(self,
block_number_or_hash: BlockIdentifier,
max_headers: int = None,
skip: int = 0,
reverse: bool = True) -> Tuple[BlockHeader, ...]:
request = self.request_block_headers(block_number_or_hash, max_headers, skip, reverse)
return await self.wait_for_block_headers(request)


class PeerSubscriber(ABC):
_msg_queue: 'asyncio.Queue[PEER_MSG_TYPE]' = None
Expand Down Expand Up @@ -981,7 +1041,7 @@ async def ensure_same_side_on_dao_fork(

try:
request.validate_headers(headers)
except ValueError as err:
except ValidationError as err:
raise DAOForkCheckFailure(
"Invalid header response during DAO fork check: {}".format(err)
)
Expand Down
26 changes: 26 additions & 0 deletions tests/p2p/test_peer_block_header_request_and_response_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import asyncio

import pytest

from eth.rlp.headers import BlockHeader

from p2p.peer import ETHPeer
from peer_helpers import (
get_directly_linked_peers,
)


@pytest.mark.asyncio
async def test_eth_peer_get_headers(request, event_loop):
peer, remote = await get_directly_linked_peers(request, event_loop, peer1_class=ETHPeer, peer2_class=ETHPeer)
header = BlockHeader(difficulty=100, block_number=0, gas_limit=3000000)

async def send_headers():
remote.sub_proto.send_block_headers((header,))
await asyncio.sleep(0)

asyncio.ensure_future(send_headers())
response = await peer.get_block_headers(0, 1, 0, False)

assert len(response) == 1
assert response[0] == header

0 comments on commit 5f6bb64

Please sign in to comment.