Skip to content

Commit

Permalink
Merge pull request #94 from pipermerriam/piper/implement-JSON-RPC-pin…
Browse files Browse the repository at this point in the history
…g-endpoint

Piper/implement json rpc ping endpoint
  • Loading branch information
pipermerriam authored Sep 30, 2020
2 parents 4e841e5 + 76432dc commit 64485c3
Show file tree
Hide file tree
Showing 9 changed files with 342 additions and 70 deletions.
13 changes: 13 additions & 0 deletions ddht/endpoint.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,23 @@
from socket import inet_ntoa
from typing import NamedTuple

from eth_enr import ENRAPI
from eth_enr.constants import IP_V4_ADDRESS_ENR_KEY, UDP_PORT_ENR_KEY


class Endpoint(NamedTuple):
ip_address: bytes
port: int

def __str__(self) -> str:
return f"{inet_ntoa(self.ip_address)}:{self.port}"

@classmethod
def from_enr(self, enr: ENRAPI) -> "Endpoint":
try:
ip_address = enr[IP_V4_ADDRESS_ENR_KEY]
port = enr[UDP_PORT_ENR_KEY]
except KeyError:
raise Exception("Missing endpoint address information: ")

return Endpoint(ip_address, port)
4 changes: 2 additions & 2 deletions ddht/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,8 @@ async def execute_rpc(self, request: RPCRequest) -> str:
self.logger.error("Error handling request: %s error: %s", request, err)
self.logger.debug("Error handling request: %s", request, exc_info=True)
response = generate_error_response(request, f"Unexpected Error: {err}")
finally:
return json.dumps(response)

return json.dumps(response)

async def serve(self, ipc_path: pathlib.Path) -> None:
self.logger.info("Starting RPC server over IPC socket: %s", ipc_path)
Expand Down
71 changes: 69 additions & 2 deletions ddht/tools/web3.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Callable, Mapping, NamedTuple, Tuple
import ipaddress
from typing import Any, Callable, List, Mapping, NamedTuple, Tuple, Union

from eth_utils import add_0x_prefix, encode_hex, remove_0x_prefix

try:
import web3 # noqa: F401
Expand All @@ -7,14 +10,16 @@


from eth_enr import ENR, ENRAPI
from eth_typing import NodeID
from eth_typing import HexStr, NodeID
from eth_utils import decode_hex
from web3.method import Method
from web3.module import ModuleV2
from web3.types import RPCEndpoint

from ddht.rpc_handlers import BucketInfo as BucketInfoDict
from ddht.rpc_handlers import NodeInfoResponse, TableInfoResponse
from ddht.typing import AnyIPAddress
from ddht.v5_1.rpc_handlers import PongResponse


class NodeInfo(NamedTuple):
Expand Down Expand Up @@ -69,17 +74,79 @@ def from_rpc_response(cls, response: TableInfoResponse) -> "TableInfo":
)


class PongPayload(NamedTuple):
enr_seq: int
packet_ip: AnyIPAddress
packet_port: int

@classmethod
def from_rpc_response(cls, response: PongResponse) -> "PongPayload":
return cls(
enr_seq=response["enr_seq"],
packet_ip=ipaddress.ip_address(response["packet_ip"]),
packet_port=response["packet_port"],
)


class RPC:
nodeInfo = RPCEndpoint("discv5_nodeInfo")
routingTableInfo = RPCEndpoint("discv5_routingTableInfo")

ping = RPCEndpoint("discv5_ping")


def ping_munger(
module: Any, identifier: Union[ENRAPI, str, bytes, NodeID, HexStr]
) -> List[str]:
"""
See: https://github.com/ethereum/web3.py/blob/002151020cecd826a694ded2fdc10cc70e73e636/web3/method.py#L77 # noqa: E501
Normalizes the any of the following inputs into the appropriate payload for
the ``discv5_ping` JSON-RPC API endpoint.
- An ENR object
- The string representation of an ENR
- A NodeID in the form of a bytestring
- A NodeID in the form of a hex string
- An ENode URI
Throws a ``ValueError`` if the input cannot be matched to one of these
formats.
"""
if isinstance(identifier, ENRAPI):
return [repr(identifier)]
elif isinstance(identifier, bytes):
if len(identifier) == 32:
return [encode_hex(identifier)]
raise ValueError(f"Unrecognized node identifier: {identifier!r}")
elif isinstance(identifier, str):
if identifier.startswith("enode://") or identifier.startswith("enr:"):
return [identifier]
elif len(remove_0x_prefix(HexStr(identifier))) == 64:
return [add_0x_prefix(HexStr(identifier))]
else:
raise ValueError(f"Unrecognized node identifier: {identifier}")
else:
raise ValueError(f"Unrecognized node identifier: {identifier}")


# TODO: why does mypy think ModuleV2 is of `Any` type?
class DiscoveryV5Module(ModuleV2): # type: ignore
"""
A web3.py module that exposes high level APIs for interacting with the
discovery v5 network.
"""

get_node_info: Method[Callable[[], NodeInfo]] = Method(
RPC.nodeInfo, result_formatters=lambda method: NodeInfo.from_rpc_response,
)
get_routing_table_info: Method[Callable[[], TableInfo]] = Method(
RPC.routingTableInfo,
result_formatters=lambda method: TableInfo.from_rpc_response,
)

ping: Method[Callable[[Union[NodeID, ENRAPI, HexStr, str]], PongPayload]] = Method(
RPC.ping,
result_formatters=lambda method: PongPayload.from_rpc_response,
mungers=[ping_munger],
)
7 changes: 6 additions & 1 deletion ddht/v5_1/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from eth_enr import ENRDB, ENRManager, default_identity_scheme_registry
from eth_keys import keys
from eth_utils import encode_hex
from eth_utils.toolz import merge
import trio

from ddht._utils import generate_node_key_file, read_node_key_file
Expand All @@ -17,6 +18,7 @@
from ddht.v5_1.events import Events
from ddht.v5_1.messages import v51_registry
from ddht.v5_1.network import Network
from ddht.v5_1.rpc_handlers import get_v51_rpc_handlers

ENR_DATABASE_DIR_NAME = "enr-db"

Expand Down Expand Up @@ -94,7 +96,10 @@ async def run(self) -> None:
network = Network(client=client, bootnodes=bootnodes,)

if self._boot_info.is_rpc_enabled:
handlers = get_core_rpc_handlers(network.routing_table)
handlers = merge(
get_core_rpc_handlers(network.routing_table),
get_v51_rpc_handlers(network),
)
rpc_server = RPCServer(self._boot_info.ipc_path, handlers)
self.manager.run_daemon_child_service(rpc_server)

Expand Down
16 changes: 3 additions & 13 deletions ddht/v5_1/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from async_service import Service
from eth_enr import ENRAPI, ENRDatabaseAPI, ENRManagerAPI
from eth_enr.constants import IP_V4_ADDRESS_ENR_KEY, UDP_PORT_ENR_KEY
from eth_enr.exceptions import OldSequenceNumber
from eth_typing import NodeID
from eth_utils import to_tuple
Expand Down Expand Up @@ -295,7 +294,7 @@ async def _manage_routing_table(self) -> None:
for enr in self._bootnodes:
if enr.node_id == self.local_node_id:
continue
endpoint = self._endpoint_for_enr(enr)
endpoint = Endpoint.from_enr(enr)
nursery.start_soon(self._bond, enr.node_id, endpoint)

with trio.move_on_after(10):
Expand All @@ -312,7 +311,7 @@ async def _manage_routing_table(self) -> None:
target_node_id = NodeID(secrets.token_bytes(32))
found_enrs = await self.recursive_find_nodes(target_node_id)
for enr in found_enrs:
endpoint = self._endpoint_for_enr(enr)
endpoint = Endpoint.from_enr(enr)
nursery.start_soon(self._bond, enr.node_id, endpoint)

async def _pong_when_pinged(self) -> None:
Expand Down Expand Up @@ -381,18 +380,9 @@ async def _serve_find_nodes(self) -> None:
#
# Utility
#
def _endpoint_for_enr(self, enr: ENRAPI) -> Endpoint:
try:
ip_address = enr[IP_V4_ADDRESS_ENR_KEY]
port = enr[UDP_PORT_ENR_KEY]
except KeyError:
raise Exception("Missing endpoint address information: ")

return Endpoint(ip_address, port)

def _endpoint_for_node_id(self, node_id: NodeID) -> Endpoint:
enr = self.enr_db.get_enr(node_id)
return self._endpoint_for_enr(enr)
return Endpoint.from_enr(enr)


@to_tuple
Expand Down
111 changes: 111 additions & 0 deletions ddht/v5_1/rpc_handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import ipaddress
from socket import inet_ntoa
from typing import Any, Iterable, Optional, Tuple, TypedDict

from eth_enr import ENR
from eth_typing import HexStr, NodeID
from eth_utils import decode_hex, is_hex, remove_0x_prefix, to_dict

from ddht.abc import RPCHandlerAPI
from ddht.endpoint import Endpoint
from ddht.rpc import RPCError, RPCHandler, RPCRequest
from ddht.v5_1.abc import NetworkAPI


class PongResponse(TypedDict):
enr_seq: int
packet_ip: str
packet_port: int


def is_hex_node_id(value: Any) -> bool:
return (
isinstance(value, str)
and is_hex(value)
and len(remove_0x_prefix(HexStr(value))) == 64
)


def validate_hex_node_id(value: Any) -> None:
if not is_hex_node_id(value):
raise RPCError(f"Invalid NodeID: {value}")


def is_endpoint(value: Any) -> bool:
if not isinstance(value, str):
return False
ip_address, _, port = value.rpartition(":")
try:
ipaddress.ip_address(ip_address)
except ValueError:
return False

if not port.isdigit():
return False

return True


def validate_endpoint(value: Any) -> None:
if not is_endpoint(value):
raise RPCError(f"Invalid Endpoint: {value}")


class PingHandler(RPCHandler[Tuple[NodeID, Optional[Endpoint]], PongResponse]):
def __init__(self, network: NetworkAPI) -> None:
self._network = network

def extract_params(self, request: RPCRequest) -> Tuple[NodeID, Optional[Endpoint]]:
try:
raw_params = request["params"]
except KeyError as err:
raise RPCError(f"Missiing call params: {err}")

if len(raw_params) != 1:
raise RPCError(
f"`ddht_ping` endpoint expects a single parameter: "
f"Got {len(raw_params)} params: {raw_params}"
)

value = raw_params[0]

node_id: NodeID
endpoint: Optional[Endpoint]

if is_hex_node_id(value):
node_id = NodeID(decode_hex(value))
endpoint = None
elif value.startswith("enode://"):
raw_node_id, _, raw_endpoint = value[8:].partition("@")

validate_hex_node_id(raw_node_id)
validate_endpoint(raw_endpoint)

node_id = NodeID(decode_hex(raw_node_id))

raw_ip_address, _, raw_port = raw_endpoint.partition(":")
ip_address = ipaddress.ip_address(raw_ip_address)
port = int(raw_port)
endpoint = Endpoint(ip_address.packed, port)
elif value.startswith("enr:"):
enr = ENR.from_repr(value)
node_id = enr.node_id
endpoint = Endpoint.from_enr(enr)
else:
raise RPCError(f"Unrecognized node identifier: {value}")

return node_id, endpoint

async def do_call(self, params: Tuple[NodeID, Optional[Endpoint]]) -> PongResponse:
node_id, endpoint = params
pong = await self._network.ping(node_id, endpoint=endpoint)
return PongResponse(
enr_seq=pong.enr_seq,
packet_ip=inet_ntoa(pong.packet_ip),
packet_port=pong.packet_port,
)


@to_dict
def get_v51_rpc_handlers(network: NetworkAPI) -> Iterable[Tuple[str, RPCHandlerAPI]]:
yield ("discv5_ping", PingHandler(network))
57 changes: 57 additions & 0 deletions tests/core/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import io
import itertools
import json

from eth_utils.toolz import take
import pytest
import trio

from ddht.rpc import RPCRequest, read_json


@pytest.fixture(name="make_raw_request")
async def _make_raw_request(ipc_path, rpc_server):
socket = await trio.open_unix_socket(str(ipc_path))
async with socket:
buffer = io.StringIO()

async def make_raw_request(raw_request: str):
with trio.fail_after(2):
data = raw_request.encode("utf8")
data_iter = iter(data)
while True:
chunk = bytes(take(1024, data_iter))
if chunk:
try:
await socket.send_all(chunk)
except trio.BrokenResourceError:
break
else:
break
return await read_json(socket.socket, buffer)

yield make_raw_request


@pytest.fixture(name="make_request")
async def _make_request(make_raw_request):
id_counter = itertools.count()

async def make_request(method, params=None):
if params is None:
params = []
request = RPCRequest(
jsonrpc="2.0", method=method, params=params, id=next(id_counter),
)
raw_request = json.dumps(request)

raw_response = await make_raw_request(raw_request)

if "error" in raw_response:
raise Exception(raw_response)
elif "result" in raw_response:
return raw_response["result"]
else:
raise Exception("Invariant")

yield make_request
Loading

0 comments on commit 64485c3

Please sign in to comment.