From b1ab3efe017219b26ea9eebb963ef5743f19c73b Mon Sep 17 00:00:00 2001 From: Piper Merriam Date: Tue, 8 Sep 2020 17:41:50 -0600 Subject: [PATCH 1/2] Implement core JSON-RPC API --- ddht/abc.py | 31 +++- ddht/app.py | 15 ++ ddht/boot_info.py | 15 ++ ddht/cli_parser.py | 14 ++ ddht/rpc.py | 253 +++++++++++++++++++++++++++ ddht/rpc_handlers.py | 92 ++++++++++ ddht/tools/factories/boot_info.py | 2 + ddht/tools/web3.py | 85 +++++++++ ddht/typing.py | 4 +- ddht/v5/app.py | 31 ++-- ddht/v5_1/app.py | 37 ++-- setup.py | 1 + tests/conftest.py | 5 + tests/core/test_core_rpc_handlers.py | 215 +++++++++++++++++++++++ tests/core/v5_1/test_network.py | 2 + tox.ini | 1 + 16 files changed, 760 insertions(+), 43 deletions(-) create mode 100644 ddht/app.py create mode 100644 ddht/rpc.py create mode 100644 ddht/rpc_handlers.py create mode 100644 ddht/tools/web3.py create mode 100644 tests/core/test_core_rpc_handlers.py diff --git a/ddht/abc.py b/ddht/abc.py index 754335dc..66216d2e 100644 --- a/ddht/abc.py +++ b/ddht/abc.py @@ -11,15 +11,18 @@ Optional, Tuple, Type, + TypedDict, TypeVar, ) +from async_service import ServiceAPI from eth_enr.abc import IdentitySchemeAPI from eth_typing import NodeID import trio from ddht.base_message import BaseMessage -from ddht.typing import IDNonce, SessionKeys +from ddht.boot_info import BootInfo +from ddht.typing import JSON, IDNonce, SessionKeys TAddress = TypeVar("TAddress", bound="AddressAPI") @@ -244,3 +247,29 @@ def register( self, handshake_scheme_class: Type[HandshakeSchemeAPI] ) -> Type[HandshakeSchemeAPI]: ... + + +class RPCRequest(TypedDict, total=False): + jsonrpc: str + method: str + params: List[Any] + id: int + + +class RPCResponse(TypedDict, total=False): + id: int + jsonrpc: str + result: JSON + error: str + + +class RPCHandlerAPI(ABC): + @abstractmethod + async def __call__(self, request: RPCRequest) -> RPCResponse: + ...
+ + +class ApplicationAPI(ServiceAPI): + @abstractmethod + def __init__(self, boot_info: BootInfo) -> None: + ... diff --git a/ddht/app.py b/ddht/app.py new file mode 100644 index 00000000..dab8f682 --- /dev/null +++ b/ddht/app.py @@ -0,0 +1,15 @@ +import logging + +from async_service import Service + +from ddht.abc import ApplicationAPI +from ddht.boot_info import BootInfo + + +class BaseApplication(Service, ApplicationAPI): + logger = logging.getLogger("ddht.DDHT") + + _boot_info: BootInfo + + def __init__(self, boot_info: BootInfo) -> None: + self._boot_info = boot_info diff --git a/ddht/boot_info.py b/ddht/boot_info.py index 9acb3dc2..b98f5f5b 100644 --- a/ddht/boot_info.py +++ b/ddht/boot_info.py @@ -26,6 +26,8 @@ class BootInfoKwargs(TypedDict, total=False): private_key: Optional[keys.PrivateKey] is_ephemeral: bool is_upnp_enabled: bool + is_rpc_enabled: bool + ipc_path: pathlib.Path def _cli_args_to_boot_info_kwargs(args: argparse.Namespace) -> BootInfoKwargs: @@ -76,6 +78,15 @@ def _cli_args_to_boot_info_kwargs(args: argparse.Namespace) -> BootInfoKwargs: else: private_key = None + ipc_path: pathlib.Path + + if args.ipc_path is not None: + ipc_path = args.ipc_path + else: + ipc_path = base_dir / "jsonrpc.ipc" + + is_rpc_enabled = args.disable_jsonrpc is not True + return BootInfoKwargs( protocol_version=protocol_version, base_dir=base_dir, @@ -85,6 +96,8 @@ def _cli_args_to_boot_info_kwargs(args: argparse.Namespace) -> BootInfoKwargs: private_key=private_key, is_ephemeral=is_ephemeral, is_upnp_enabled=is_upnp_enabled, + is_rpc_enabled=is_rpc_enabled, + ipc_path=ipc_path, ) @@ -98,6 +111,8 @@ class BootInfo: private_key: Optional[keys.PrivateKey] is_ephemeral: bool is_upnp_enabled: bool + is_rpc_enabled: bool + ipc_path: pathlib.Path @classmethod def from_cli_args(cls, args: Sequence[str]) -> "BootInfo": diff --git a/ddht/cli_parser.py b/ddht/cli_parser.py index 61639f9c..c569503e 100644 --- a/ddht/cli_parser.py +++ b/ddht/cli_parser.py @@ -25,6 +25,7 @@ 
ddht_parser = parser.add_argument_group("core") logging_parser = parser.add_argument_group("logging") network_parser = parser.add_argument_group("network") +jsonrpc_parser = parser.add_argument_group("jsonrpc") # @@ -109,3 +110,16 @@ def __call__( help="IP address to listen on", dest="bootnodes", ) + + +# +# JSON-RPC +# +jsonrpc_parser.add_argument( + "--disable-jsonrpc", help="Disable the JSON-RPC server", +) +jsonrpc_parser.add_argument( + "--ipc-path", + type=pathlib.Path, + help="Path where the IPC socket will be opened for serving JSON-RPC", +) diff --git a/ddht/rpc.py b/ddht/rpc.py new file mode 100644 index 00000000..f36d9889 --- /dev/null +++ b/ddht/rpc.py @@ -0,0 +1,253 @@ +from abc import abstractmethod +import collections +import io +import json +import logging +import pathlib +from typing import Any, Dict, Generic, Mapping, Optional, Tuple, TypeVar, cast + +from async_service import Service +from eth_utils import ValidationError +import trio + +from ddht.abc import RPCHandlerAPI, RPCRequest, RPCResponse +from ddht.typing import JSON + +NEW_LINE = "\n" + + +def strip_non_json_prefix(raw_request: str) -> Tuple[str, str]: + if raw_request and raw_request[0] != "{": + prefix, bracket, rest = raw_request.partition("{") + return prefix.strip(), bracket + rest + else: + return "", raw_request + + +logger = logging.getLogger("ddht.rpc") +decoder = json.JSONDecoder() + + +async def read_json( + socket: trio.socket.SocketType, + buffer: io.StringIO, + decoder: json.JSONDecoder = decoder, +) -> JSON: + request: JSON + + while True: + data = await socket.recv(1024) + buffer.write(data.decode()) + + bad_prefix, raw_request = strip_non_json_prefix(buffer.getvalue()) + if bad_prefix: + logger.info("Client started request with non json data: %r", bad_prefix) + await write_error(socket, f"Cannot parse json: {bad_prefix}") + continue + + try: + request, offset = decoder.raw_decode(raw_request) + except json.JSONDecodeError: + # invalid json request, keep reading data 
until a valid json is formed + if raw_request: + logger.debug( + "Invalid JSON, waiting for rest of message: %r", raw_request, + ) + else: + await trio.sleep(0.01) + continue + + # TODO: more efficient algorithm can be used here by + # manipulating the buffer such that we can seek back to the + # correct position for *new* data to come in. + buffer.seek(0) + buffer.write(raw_request[offset:]) + buffer.truncate() + + break + + return request + + +async def write_error(socket: trio.socket.SocketType, message: str) -> None: + json_error = json.dumps({"error": message}) + await socket.send(json_error.encode("utf8")) + + +def validate_request(request: Mapping[Any, Any]) -> None: + try: + version = request["jsonrpc"] + except KeyError as err: + raise ValidationError("Missing 'jsonrpc' key") from err + else: + if version != "2.0": + raise ValidationError(f"Invalid version: {version}") + + if "method" not in request: + raise ValidationError("Missing 'method' key") + if "params" in request: + if not isinstance(request["params"], list): + raise ValidationError("Missing 'method' key") + + +def generate_response( + request: RPCRequest, result: Any, error: Optional[str] +) -> RPCResponse: + response = RPCResponse( + id=request.get("id", -1), jsonrpc=request.get("jsonrpc", "2.0"), + ) + + if result is None and error is None: + raise TypeError("Must supply either result or error for JSON-RPC response") + if result is not None and error is not None: + raise TypeError( + "Must not supply both a result and an error for JSON-RPC response" + ) + elif result is not None: + response["result"] = result + elif error is not None: + response["error"] = str(error) + else: + raise Exception("Unreachable code path") + + return response + + +TParams = TypeVar("TParams") +TResult = TypeVar("TResult") + + +class RPCError(Exception): + ... 
+ + +class RPCHandler(RPCHandlerAPI, Generic[TParams, TResult]): + async def __call__(self, request: RPCRequest) -> RPCResponse: + try: + params = self.extract_params(request) + result = await self.do_call(params) + except RPCError as err: + return self.generate_error_response(request, str(err)) + else: + return self.generate_success_response(request, result) + + @abstractmethod + async def do_call(self, params: TParams) -> TResult: + ... + + def extract_params(self, request: RPCRequest) -> TParams: + return request.get("params", []) # type: ignore + + def generate_success_response( + self, request: RPCRequest, result: Any + ) -> RPCResponse: + return generate_response(request, result, None) + + def generate_error_response(self, request: RPCRequest, error: str) -> RPCResponse: + return generate_response(request, None, error) + + +class UnknownMethodHandler(RPCHandlerAPI): + async def __call__(self, request: RPCRequest) -> RPCResponse: + return generate_response( + request, None, f"Unknown RPC method: {request['method']}" + ) + + +fallback_handler = UnknownMethodHandler() + + +class RPCServer(Service): + logger = logging.getLogger("alexandria.rpc.RPCServer") + _handlers: Dict[str, RPCHandlerAPI] + + def __init__( + self, ipc_path: pathlib.Path, handlers: Dict[str, RPCHandlerAPI] + ) -> None: + self.ipc_path = ipc_path + self._handlers = handlers + self._serving = trio.Event() + + async def wait_serving(self) -> None: + await self._serving.wait() + + async def run(self) -> None: + self.manager.run_daemon_task(self.serve, self.ipc_path) + try: + await self.manager.wait_finished() + finally: + self.ipc_path.unlink() + + async def execute_rpc(self, request: RPCRequest) -> str: + method = request["method"] + + self.logger.debug("RPCServer handling request: %s", method) + + handler = self._handlers.get(method, fallback_handler) + try: + response = await handler(request) + except Exception as err: + self.logger.error("Error handling request: %s error: %s", request, err) + 
self.logger.debug("Error handling request: %s", request, exc_info=True) + response = generate_response(request, None, f"Unexpected Error: {err}") + finally: + return json.dumps(response) + + async def serve(self, ipc_path: pathlib.Path) -> None: + self.logger.info("Starting RPC server over IPC socket: %s", ipc_path) + + with trio.socket.socket(trio.socket.AF_UNIX, trio.socket.SOCK_STREAM) as sock: + # TODO: unclear if the following stuff is necessary: + # ################################################### + # These options help fix an issue with the socket reporting itself + # already being used since it accepts many client connection. + # https://stackoverflow.com/questions/6380057/python-binding-socket-address-already-in-use + # sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + # ################################################### + await sock.bind(str(ipc_path)) + # Allow up to 10 pending connections. + sock.listen(10) + + self._serving.set() + + while self.manager.is_running: + conn, addr = await sock.accept() + self.logger.debug("Server accepted connection: %r", addr) + self.manager.run_task(self._handle_connection, conn) + + async def _handle_connection(self, socket: trio.socket.SocketType) -> None: + buffer = io.StringIO() + + with socket: + while True: + request = await read_json(socket, buffer) + + if not isinstance(request, collections.abc.Mapping): + logger.debug("Invalid payload: %s", type(request)) + await write_error(socket, "Invalid Request: not a mapping") + continue + + if not request: + self.logger.debug("Client sent empty request") + await write_error(socket, "Invalid Request: empty") + continue + + try: + validate_request(request) + except ValidationError as err: + await write_error(socket, str(err)) + continue + + try: + result = await self.execute_rpc(cast(RPCRequest, request)) + except Exception as e: + self.logger.exception("Unrecognized exception while executing RPC") + await write_error(socket, "unknown failure: " + str(e)) + 
else: + if not result.endswith(NEW_LINE): + result += NEW_LINE + + try: + await socket.send(result.encode()) + except BrokenPipeError: + break diff --git a/ddht/rpc_handlers.py b/ddht/rpc_handlers.py new file mode 100644 index 00000000..f8ae9b7c --- /dev/null +++ b/ddht/rpc_handlers.py @@ -0,0 +1,92 @@ +from typing import Iterator, Mapping, Tuple, TypedDict + +from eth_enr import ENRAPI +from eth_typing import HexStr +from eth_utils import encode_hex, to_dict + +from ddht.abc import RoutingTableAPI, RPCHandlerAPI, RPCRequest +from ddht.rpc import RPCError, RPCHandler + + +class BucketInfo(TypedDict): + idx: int + nodes: Tuple[HexStr, ...] + replacement_cache: Tuple[HexStr, ...] + is_full: bool + + +class TableInfoResponse(TypedDict): + center_node_id: HexStr + num_buckets: int + bucket_size: int + buckets: Mapping[int, BucketInfo] + + +class RoutingTableInfoHandler(RPCHandler[None, TableInfoResponse]): + def __init__(self, routing_table: RoutingTableAPI) -> None: + self._routing_table = routing_table + + def extract_params(self, request: RPCRequest) -> None: + if request.get("params"): + raise RPCError(f"Unexpected RPC params: {request['params']}",) + return None + + async def do_call(self, params: None) -> TableInfoResponse: + stats = TableInfoResponse( + center_node_id=encode_hex(self._routing_table.center_node_id), + num_buckets=len(self._routing_table.buckets), + bucket_size=self._routing_table.bucket_size, + buckets=self._bucket_stats(), + ) + return stats + + @to_dict + def _bucket_stats(self) -> Iterator[Tuple[int, BucketInfo]]: + buckets_and_replacement_caches = zip( + self._routing_table.buckets, self._routing_table.replacement_caches, + ) + for idx, (bucket, replacement_cache) in enumerate( + buckets_and_replacement_caches, 1 + ): + if bucket: + yield ( + idx, + BucketInfo( + idx=idx, + nodes=tuple(encode_hex(node_id) for node_id in bucket), + replacement_cache=tuple( + encode_hex(node_id) for node_id in replacement_cache + ), + is_full=(len(bucket) >= 
self._routing_table.bucket_size), + ), + ) + + +class NodeInfoResponse(TypedDict): + node_id: HexStr + enr: str + + +class NodeInfoHandler(RPCHandler[None, NodeInfoResponse]): + _node_id_hex: HexStr + + def __init__(self, enr: ENRAPI) -> None: + self._enr = enr + + def extract_params(self, request: RPCRequest) -> None: + if request.get("params"): + raise RPCError(f"Unexpected RPC params: {request['params']}") + return None + + async def do_call(self, params: None) -> NodeInfoResponse: + return NodeInfoResponse( + node_id=encode_hex(self._enr.node_id), enr=repr(self._enr), + ) + + +@to_dict +def get_core_rpc_handlers( + enr: ENRAPI, routing_table: RoutingTableAPI +) -> Iterator[Tuple[str, RPCHandlerAPI]]: + yield ("discv5_routingTableInfo", RoutingTableInfoHandler(routing_table)) + yield ("discv5_nodeInfo", NodeInfoHandler(enr)) diff --git a/ddht/tools/factories/boot_info.py b/ddht/tools/factories/boot_info.py index 6fee3b10..1ecdf05d 100644 --- a/ddht/tools/factories/boot_info.py +++ b/ddht/tools/factories/boot_info.py @@ -30,3 +30,5 @@ class Meta: bootnodes = factory.LazyAttribute(lambda o: BOOTNODES[o.protocol_version]) is_ephemeral = False is_upnp_enabled = True + is_rpc_enabled = True + ipc_path = factory.LazyAttribute(lambda o: o.base_dir / "jsonrpc.ipc") diff --git a/ddht/tools/web3.py b/ddht/tools/web3.py new file mode 100644 index 00000000..810bb580 --- /dev/null +++ b/ddht/tools/web3.py @@ -0,0 +1,85 @@ +from typing import Callable, Mapping, NamedTuple, Tuple + +try: + import web3 # noqa: F401 +except ImportError: + raise ImportError("The web3.py library is required") + + +from eth_enr import ENR, ENRAPI +from eth_typing import NodeID +from eth_utils import decode_hex +from web3.method import Method +from web3.module import ModuleV2 +from web3.types import RPCEndpoint + +from ddht.rpc_handlers import BucketInfo as BucketInfoDict +from ddht.rpc_handlers import NodeInfoResponse, TableInfoResponse + + +class NodeInfo(NamedTuple): + node_id: NodeID + enr: 
ENRAPI + + @classmethod + def from_rpc_response(cls, response: NodeInfoResponse) -> "NodeInfo": + return cls( + node_id=NodeID(decode_hex(response["node_id"])), + enr=ENR.from_repr(response["enr"]), + ) + + +class BucketInfo(NamedTuple): + idx: int + nodes: Tuple[NodeID, ...] + replacement_cache: Tuple[NodeID, ...] + is_full: bool + + @classmethod + def from_rpc_response(cls, response: BucketInfoDict) -> "BucketInfo": + return cls( + idx=response["idx"], + nodes=tuple( + NodeID(decode_hex(node_id_hex)) for node_id_hex in response["nodes"] + ), + replacement_cache=tuple( + NodeID(decode_hex(node_id_hex)) + for node_id_hex in response["replacement_cache"] + ), + is_full=response["is_full"], + ) + + +class TableInfo(NamedTuple): + center_node_id: NodeID + num_buckets: int + bucket_size: int + buckets: Mapping[int, BucketInfo] + + @classmethod + def from_rpc_response(cls, response: TableInfoResponse) -> "TableInfo": + return cls( + center_node_id=NodeID(decode_hex(response["center_node_id"])), + num_buckets=response["num_buckets"], + bucket_size=response["bucket_size"], + buckets={ + int(idx): BucketInfo.from_rpc_response(bucket_stats) + for idx, bucket_stats in response["buckets"].items() + }, + ) + + +class RPC: + nodeInfo = RPCEndpoint("discv5_nodeInfo") + routingTableInfo = RPCEndpoint("discv5_routingTableInfo") + + +# TODO: why does mypy think ModuleV2 is of `Any` type? 
+class DiscoveryV5Module(ModuleV2): # type: ignore + get_node_info: Method[Callable[[], NodeInfo]] = Method( + RPC.nodeInfo, result_formatters=lambda method: NodeInfo.from_rpc_response, + ) + get_routing_table_info: Method[Callable[[], TableInfo]] = Method( + RPC.routingTableInfo, + result_formatters=lambda method: TableInfo.from_rpc_response, + ) diff --git a/ddht/typing.py b/ddht/typing.py index 37668e97..ee949957 100644 --- a/ddht/typing.py +++ b/ddht/typing.py @@ -1,5 +1,5 @@ import ipaddress -from typing import NamedTuple, NewType, Tuple, Union +from typing import Any, Dict, List, NamedTuple, NewType, Tuple, Union AES128Key = NewType("AES128Key", bytes) Nonce = NewType("Nonce", bytes) @@ -16,3 +16,5 @@ class SessionKeys(NamedTuple): AnyIPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] ENR_KV = Tuple[bytes, Union[int, bytes]] + +JSON = Union[Dict[Any, Any], str, int, List[Any]] diff --git a/ddht/v5/app.py b/ddht/v5/app.py index a1c8fcd5..2bb9e9d8 100644 --- a/ddht/v5/app.py +++ b/ddht/v5/app.py @@ -1,6 +1,3 @@ -import logging - -from async_service import Service, run_trio_service from eth.db.backends.level import LevelDB from eth_enr import ENRDB, ENRManager, default_identity_scheme_registry from eth_enr.exceptions import OldSequenceNumber @@ -9,6 +6,7 @@ import trio from ddht._utils import generate_node_key_file, read_node_key_file +from ddht.app import BaseApplication from ddht.base_message import AnyInboundMessage, AnyOutboundMessage from ddht.boot_info import BootInfo from ddht.constants import ( @@ -37,9 +35,6 @@ from ddht.v5.packer import Packer from ddht.v5.routing_table_manager import RoutingTableManager -logger = logging.getLogger("ddht.DDHT") - - ENR_DATABASE_DIR_NAME = "enr-db" @@ -54,13 +49,7 @@ def get_local_private_key(boot_info: BootInfo) -> keys.PrivateKey: return boot_info.private_key -class Application(Service): - logger = logger - _boot_info: BootInfo - - def __init__(self, boot_info: BootInfo) -> None: - self._boot_info = 
boot_info - +class Application(BaseApplication): async def run(self) -> None: identity_scheme_registry = default_identity_scheme_registry message_type_registry = v5_registry @@ -163,11 +152,11 @@ async def run(self) -> None: endpoint_vote_send_channel=endpoint_vote_channels[0], ) - logger.info(f"DDHT base dir: {self._boot_info.base_dir}") - logger.info("Starting discovery service...") - logger.info(f"Listening on {listen_on}:{port}") - logger.info(f"Local Node ID: {encode_hex(enr_manager.enr.node_id)}") - logger.info(f"Local ENR: {enr_manager.enr}") + self.logger.info(f"DDHT base dir: {self._boot_info.base_dir}") + self.logger.info("Starting discovery service...") + self.logger.info(f"Listening on {listen_on}:{port}") + self.logger.info(f"Local Node ID: {encode_hex(enr_manager.enr.node_id)}") + self.logger.info(f"Local ENR: {enr_manager.enr}") services = ( datagram_sender, @@ -181,6 +170,6 @@ async def run(self) -> None: ) await sock.bind((str(listen_on), port)) with sock: - async with trio.open_nursery() as nursery: - for service in services: - nursery.start_soon(run_trio_service, service) + for service in services: + self.manager.run_daemon_child_service(service) + await self.manager.wait_finished() diff --git a/ddht/v5_1/app.py b/ddht/v5_1/app.py index 764551a4..50e005bc 100644 --- a/ddht/v5_1/app.py +++ b/ddht/v5_1/app.py @@ -1,6 +1,3 @@ -import logging - -from async_service import Service, run_trio_service from eth.db.backends.level import LevelDB from eth_enr import ENRDB, ENRManager, default_identity_scheme_registry from eth_keys import keys @@ -8,9 +5,12 @@ import trio from ddht._utils import generate_node_key_file, read_node_key_file +from ddht.app import BaseApplication from ddht.boot_info import BootInfo from ddht.constants import DEFAULT_LISTEN, IP_V4_ADDRESS_ENR_KEY from ddht.endpoint import Endpoint +from ddht.rpc import RPCServer +from ddht.rpc_handlers import get_core_rpc_handlers from ddht.typing import AnyIPAddress from ddht.upnp import 
UPnPService from ddht.v5_1.client import Client @@ -18,9 +18,6 @@ from ddht.v5_1.messages import v51_registry from ddht.v5_1.network import Network -logger = logging.getLogger("ddht.DDHT") - - ENR_DATABASE_DIR_NAME = "enr-db" @@ -35,13 +32,7 @@ def get_local_private_key(boot_info: BootInfo) -> keys.PrivateKey: return boot_info.private_key -class Application(Service): - logger = logger - _boot_info: BootInfo - - def __init__(self, boot_info: BootInfo) -> None: - self._boot_info = boot_info - +class Application(BaseApplication): async def _update_enr_ip_from_upnp( self, enr_manager: ENRManager, upnp_service: UPnPService ) -> None: @@ -102,13 +93,19 @@ async def run(self) -> None: ) network = Network(client=client, bootnodes=bootnodes,) - logger.info("Protocol-Version: %s", self._boot_info.protocol_version.value) - logger.info("DDHT base dir: %s", self._boot_info.base_dir) - logger.info("Starting discovery service...") - logger.info("Listening on %s:%d", listen_on, port) - logger.info("Local Node ID: %s", encode_hex(enr_manager.enr.node_id)) - logger.info( + if self._boot_info.is_rpc_enabled: + handlers = get_core_rpc_handlers(network.routing_table) + rpc_server = RPCServer(self._boot_info.ipc_path, handlers) + self.manager.run_daemon_child_service(rpc_server) + + self.logger.info("Protocol-Version: %s", self._boot_info.protocol_version.value) + self.logger.info("DDHT base dir: %s", self._boot_info.base_dir) + self.logger.info("Starting discovery service...") + self.logger.info("Listening on %s:%d", listen_on, port) + self.logger.info("Local Node ID: %s", encode_hex(enr_manager.enr.node_id)) + self.logger.info( "Local ENR: seq=%d enr=%s", enr_manager.enr.sequence_number, enr_manager.enr ) - await run_trio_service(network) + self.manager.run_daemon_child_service(network) + self.manager.run_daemon_child_service(rpc_server) diff --git a/setup.py b/setup.py index b213bd84..efbaefca 100644 --- a/setup.py +++ b/setup.py @@ -31,6 +31,7 @@ "twine", "ipython", ], + "web3": 
["web3>=5.12.1,<6"], } extras_require["dev"] = ( diff --git a/tests/conftest.py b/tests/conftest.py index 2977cbae..456eade9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,11 @@ def xdg_home(monkeypatch): yield pathlib.Path(temp_xdg) +@pytest.fixture +def ipc_path(xdg_home): + return xdg_home / "jsonrpc.ipc" + + @pytest_trio.trio_fixture async def socket_pair(): sending_socket = trio.socket.socket( diff --git a/tests/core/test_core_rpc_handlers.py b/tests/core/test_core_rpc_handlers.py new file mode 100644 index 00000000..ee351450 --- /dev/null +++ b/tests/core/test_core_rpc_handlers.py @@ -0,0 +1,215 @@ +import io +import itertools +import json + +from async_service import background_trio_service +from eth_enr.tools.factories import ENRFactory +from eth_utils import decode_hex +import pytest +import trio +from web3 import IPCProvider, Web3 + +from ddht.constants import ROUTING_TABLE_BUCKET_SIZE +from ddht.kademlia import KademliaRoutingTable +from ddht.rpc import RPCRequest, RPCServer, read_json +from ddht.rpc_handlers import get_core_rpc_handlers +from ddht.tools.factories.node_id import NodeIDFactory +from ddht.tools.web3 import DiscoveryV5Module + + +@pytest.fixture +def enr(): + return ENRFactory() + + +@pytest.fixture +def routing_table(enr): + return KademliaRoutingTable(enr.node_id, ROUTING_TABLE_BUCKET_SIZE) + + +@pytest.fixture +async def rpc_server(ipc_path, routing_table, enr): + server = RPCServer(ipc_path, get_core_rpc_handlers(enr, routing_table)) + async with background_trio_service(server): + await server.wait_serving() + yield server + + +@pytest.fixture +def w3(rpc_server, ipc_path): + return Web3(IPCProvider(ipc_path), modules={"discv5": (DiscoveryV5Module,)}) + + +@pytest.fixture(name="make_request") +async def _make_request(ipc_path, rpc_server): + socket = await trio.open_unix_socket(str(ipc_path)) + async with socket: + buffer = io.StringIO() + id_counter = itertools.count() + + async def make_request(method, 
params=None): + if params is None: + params = [] + request = RPCRequest( + jsonrpc="2.0", method=method, params=params, id=next(id_counter), + ) + raw_request = json.dumps(request) + + with trio.fail_after(2): + await socket.send_all(raw_request.encode("utf8")) + raw_response = await read_json(socket.socket, buffer) + + if "error" in raw_response: + raise Exception(raw_response) + elif "result" in raw_response: + return raw_response["result"] + else: + raise Exception("Invariant") + + yield make_request + + +@pytest.mark.trio +async def test_rpc_nodeInfo(make_request, enr): + node_info = await make_request("discv5_nodeInfo") + assert decode_hex(node_info["node_id"]) == enr.node_id + assert node_info["enr"] == repr(enr) + + +@pytest.mark.trio +async def test_rpc_nodeInfo_web3(w3, enr, rpc_server): + with trio.fail_after(2): + node_info = await trio.to_thread.run_sync(w3.discv5.get_node_info) + assert node_info.node_id == enr.node_id + assert node_info.enr == enr + + +@pytest.mark.trio +async def test_rpc_tableInfo(make_request, routing_table): + local_node_id = routing_table.center_node_id + # 16/16 at furthest distance + for _ in range(routing_table.bucket_size * 2): + routing_table.update(NodeIDFactory.at_log_distance(local_node_id, 256)) + # 16/8 at next bucket + for _ in range(int(routing_table.bucket_size * 1.5)): + routing_table.update(NodeIDFactory.at_log_distance(local_node_id, 255)) + # 16/4 at next bucket + for _ in range(int(routing_table.bucket_size * 1.25)): + routing_table.update(NodeIDFactory.at_log_distance(local_node_id, 254)) + # 16 in this one + for _ in range(int(routing_table.bucket_size)): + routing_table.update(NodeIDFactory.at_log_distance(local_node_id, 253)) + # 8 in this one + for _ in range(int(routing_table.bucket_size // 2)): + routing_table.update(NodeIDFactory.at_log_distance(local_node_id, 252)) + # 4 in this one + for _ in range(int(routing_table.bucket_size // 4)): + routing_table.update(NodeIDFactory.at_log_distance(local_node_id, 
251)) + + table_info = await make_request("discv5_routingTableInfo") + assert decode_hex(table_info["center_node_id"]) == routing_table.center_node_id + assert table_info["bucket_size"] == routing_table.bucket_size + assert table_info["num_buckets"] == routing_table.num_buckets + assert len(table_info["buckets"]) == 6 + + bucket_256 = table_info["buckets"]["256"] + bucket_255 = table_info["buckets"]["255"] + bucket_254 = table_info["buckets"]["254"] + bucket_253 = table_info["buckets"]["253"] + bucket_252 = table_info["buckets"]["252"] + bucket_251 = table_info["buckets"]["251"] + + assert bucket_256["idx"] == 256 + assert bucket_256["is_full"] is True + assert len(bucket_256["nodes"]) == routing_table.bucket_size + assert len(bucket_256["replacement_cache"]) == routing_table.bucket_size + + assert bucket_255["idx"] == 255 + assert bucket_255["is_full"] is True + assert len(bucket_255["nodes"]) == routing_table.bucket_size + assert len(bucket_255["replacement_cache"]) == routing_table.bucket_size // 2 + + assert bucket_254["idx"] == 254 + assert bucket_254["is_full"] is True + assert len(bucket_254["nodes"]) == routing_table.bucket_size + assert len(bucket_254["replacement_cache"]) == routing_table.bucket_size // 4 + + assert bucket_253["idx"] == 253 + assert bucket_253["is_full"] is True + assert len(bucket_253["nodes"]) == routing_table.bucket_size + assert not bucket_253["replacement_cache"] + + assert bucket_252["idx"] == 252 + assert bucket_252["is_full"] is False + assert len(bucket_252["nodes"]) == routing_table.bucket_size // 2 + assert not bucket_252["replacement_cache"] + + assert bucket_251["idx"] == 251 + assert bucket_251["is_full"] is False + assert len(bucket_251["nodes"]) == routing_table.bucket_size // 4 + assert not bucket_251["replacement_cache"] + + +@pytest.mark.trio +async def test_rpc_tableInfo_web3(w3, routing_table, rpc_server): + local_node_id = routing_table.center_node_id + # 16/16 at furthest distance + for _ in 
range(routing_table.bucket_size * 2): + routing_table.update(NodeIDFactory.at_log_distance(local_node_id, 256)) + # 16/8 at next bucket + for _ in range(int(routing_table.bucket_size * 1.5)): + routing_table.update(NodeIDFactory.at_log_distance(local_node_id, 255)) + # 16/4 at next bucket + for _ in range(int(routing_table.bucket_size * 1.25)): + routing_table.update(NodeIDFactory.at_log_distance(local_node_id, 254)) + # 16 in this one + for _ in range(int(routing_table.bucket_size)): + routing_table.update(NodeIDFactory.at_log_distance(local_node_id, 253)) + # 8 in this one + for _ in range(int(routing_table.bucket_size // 2)): + routing_table.update(NodeIDFactory.at_log_distance(local_node_id, 252)) + # 4 in this one + for _ in range(int(routing_table.bucket_size // 4)): + routing_table.update(NodeIDFactory.at_log_distance(local_node_id, 251)) + + table_info = await trio.to_thread.run_sync(w3.discv5.get_routing_table_info) + assert table_info.center_node_id == routing_table.center_node_id + assert table_info.bucket_size == routing_table.bucket_size + assert table_info.num_buckets == routing_table.num_buckets + assert len(table_info.buckets) == 6 + bucket_256 = table_info.buckets[256] + bucket_255 = table_info.buckets[255] + bucket_254 = table_info.buckets[254] + bucket_253 = table_info.buckets[253] + bucket_252 = table_info.buckets[252] + bucket_251 = table_info.buckets[251] + + assert bucket_256.idx == 256 + assert bucket_256.is_full is True + assert len(bucket_256.nodes) == routing_table.bucket_size + assert len(bucket_256.replacement_cache) == routing_table.bucket_size + + assert bucket_255.idx == 255 + assert bucket_255.is_full is True + assert len(bucket_255.nodes) == routing_table.bucket_size + assert len(bucket_255.replacement_cache) == routing_table.bucket_size // 2 + + assert bucket_254.idx == 254 + assert bucket_254.is_full is True + assert len(bucket_254.nodes) == routing_table.bucket_size + assert len(bucket_254.replacement_cache) == 
routing_table.bucket_size // 4 + + assert bucket_253.idx == 253 + assert bucket_253.is_full is True + assert len(bucket_253.nodes) == routing_table.bucket_size + assert not bucket_253.replacement_cache + + assert bucket_252.idx == 252 + assert bucket_252.is_full is False + assert len(bucket_252.nodes) == routing_table.bucket_size // 2 + assert not bucket_253.replacement_cache + + assert bucket_251.idx == 251 + assert bucket_251.is_full is False + assert len(bucket_251.nodes) == routing_table.bucket_size // 4 + assert not bucket_253.replacement_cache diff --git a/tests/core/v5_1/test_network.py b/tests/core/v5_1/test_network.py index 7f7acb50..39370ace 100644 --- a/tests/core/v5_1/test_network.py +++ b/tests/core/v5_1/test_network.py @@ -190,6 +190,7 @@ async def test_network_pings_oldest_routing_table(tester, alice, bob, autojump_c trio.current_time() - ROUTING_TABLE_KEEP_ALIVE - 1 ) await trio.sleep(ROUTING_TABLE_KEEP_ALIVE) + await trio.sleep(ROUTING_TABLE_KEEP_ALIVE) assert alice_network.routing_table._contains(bob.node_id, False) assert not alice_network.routing_table._contains(carol.node_id, False) @@ -201,6 +202,7 @@ async def test_network_pings_oldest_routing_table(tester, alice, bob, autojump_c trio.current_time() - ROUTING_TABLE_KEEP_ALIVE - 1 ) await trio.sleep(ROUTING_TABLE_KEEP_ALIVE) + await trio.sleep(ROUTING_TABLE_KEEP_ALIVE) assert not alice_network.routing_table._contains(bob.node_id, False) assert not alice_network.routing_table._contains(carol.node_id, False) diff --git a/tox.ini b/tox.ini index 7a30ea3a..be792674 100644 --- a/tox.ini +++ b/tox.ini @@ -33,6 +33,7 @@ basepython = py38: python3.8 extras= test + web3 docs: doc whitelist_externals=make From a1e1e52b9dd02e4ef99811e7cc5918b769e4ca79 Mon Sep 17 00:00:00 2001 From: Piper Merriam Date: Wed, 23 Sep 2020 18:12:41 -0600 Subject: [PATCH 2/2] PR feedback --- ddht/rpc.py | 96 +++++++++++++++------------- ddht/rpc_handlers.py | 2 +- tests/core/test_core_rpc_handlers.py | 93 
+++++++++++++++++++++------ 3 files changed, 124 insertions(+), 67 deletions(-) diff --git a/ddht/rpc.py b/ddht/rpc.py index f36d9889..a6297f53 100644 --- a/ddht/rpc.py +++ b/ddht/rpc.py @@ -4,13 +4,14 @@ import json import logging import pathlib -from typing import Any, Dict, Generic, Mapping, Optional, Tuple, TypeVar, cast +from typing import Any, Dict, Generic, Mapping, Tuple, TypeVar, cast from async_service import Service from eth_utils import ValidationError import trio from ddht.abc import RPCHandlerAPI, RPCRequest, RPCResponse +from ddht.exceptions import DecodingError from ddht.typing import JSON NEW_LINE = "\n" @@ -28,6 +29,9 @@ def strip_non_json_prefix(raw_request: str) -> Tuple[str, str]: decoder = json.JSONDecoder() +MAXIMUM_RPC_PAYLOAD_SIZE = 1024 * 1024 # 1 MB + + async def read_json( socket: trio.socket.SocketType, buffer: io.StringIO, @@ -43,7 +47,15 @@ async def read_json( if bad_prefix: logger.info("Client started request with non json data: %r", bad_prefix) await write_error(socket, f"Cannot parse json: {bad_prefix}") - continue + raise DecodingError(f"Invalid JSON payload: prefix={bad_prefix}") + + if len(raw_request) > MAXIMUM_RPC_PAYLOAD_SIZE: + error_msg = ( + f"RPC payload exceeds maximum size: {len(raw_request)} " + f"> {MAXIMUM_RPC_PAYLOAD_SIZE}" + ) + await write_error(socket, error_msg) + raise DecodingError(error_msg) try: request, offset = decoder.raw_decode(raw_request) @@ -57,9 +69,6 @@ async def read_json( await trio.sleep(0.01) continue - # TODO: more efficient algorithm can be used here by - # manipulating the buffer such that we can seek back to the - # correct position for *new* data to come in. 
buffer.seek(0) buffer.write(raw_request[offset:]) buffer.truncate() @@ -87,28 +96,24 @@ def validate_request(request: Mapping[Any, Any]) -> None: raise ValidationError("Missing 'method' key") if "params" in request: if not isinstance(request["params"], list): - raise ValidationError("Missing 'method' key") + raise ValidationError( + f"The `params` value must be a list. Got: {type(request['params'])}" + ) -def generate_response( - request: RPCRequest, result: Any, error: Optional[str] -) -> RPCResponse: +def generate_error_response(request: RPCRequest, error: str) -> RPCResponse: response = RPCResponse( - id=request.get("id", -1), jsonrpc=request.get("jsonrpc", "2.0"), + id=request.get("id", -1), + jsonrpc=request.get("jsonrpc", "2.0"), + error=str(error), ) + return response - if result is None and error is None: - raise TypeError("Must supply either result or error for JSON-RPC response") - if result is not None and error is not None: - raise TypeError( - "Must not supply both a result and an error for JSON-RPC response" - ) - elif result is not None: - response["result"] = result - elif error is not None: - response["error"] = str(error) - else: - raise Exception("Unreachable code path") + +def generate_success_response(request: RPCRequest, result: Any,) -> RPCResponse: + response = RPCResponse( + id=request.get("id", -1), jsonrpc=request.get("jsonrpc", "2.0"), result=result, + ) return response @@ -122,35 +127,38 @@ class RPCError(Exception): class RPCHandler(RPCHandlerAPI, Generic[TParams, TResult]): + """ + Class to simplify some boilerplate when writing an RPCHandlerAPI + implementation. 
+ """ + async def __call__(self, request: RPCRequest) -> RPCResponse: try: params = self.extract_params(request) result = await self.do_call(params) except RPCError as err: - return self.generate_error_response(request, str(err)) + return generate_error_response(request, str(err)) else: - return self.generate_success_response(request, result) + return generate_success_response(request, result) @abstractmethod async def do_call(self, params: TParams) -> TResult: + """ + The return value of this function will be used as the `result` key in a + success response. To return an error response, raise an + :class:`ddht.rpc.RPCError` which will be used as the `error` key in the + response. + """ ... def extract_params(self, request: RPCRequest) -> TParams: return request.get("params", []) # type: ignore - def generate_success_response( - self, request: RPCRequest, result: Any - ) -> RPCResponse: - return generate_response(request, result, None) - - def generate_error_response(self, request: RPCRequest, error: str) -> RPCResponse: - return generate_response(request, None, error) - class UnknownMethodHandler(RPCHandlerAPI): async def __call__(self, request: RPCRequest) -> RPCResponse: - return generate_response( - request, None, f"Unknown RPC method: {request['method']}" + return generate_error_response( + request, f"Unknown RPC method: {request['method']}" ) @@ -176,7 +184,7 @@ async def run(self) -> None: try: await self.manager.wait_finished() finally: - self.ipc_path.unlink() + self.ipc_path.unlink(missing_ok=True) async def execute_rpc(self, request: RPCRequest) -> str: method = request["method"] @@ -189,7 +197,7 @@ async def execute_rpc(self, request: RPCRequest) -> str: except Exception as err: self.logger.error("Error handling request: %s error: %s", request, err) self.logger.debug("Error handling request: %s", request, exc_info=True) - response = generate_response(request, None, f"Unexpected Error: {err}") + response = generate_error_response(request, f"Unexpected 
Error: {err}") finally: return json.dumps(response) @@ -197,14 +205,8 @@ async def serve(self, ipc_path: pathlib.Path) -> None: self.logger.info("Starting RPC server over IPC socket: %s", ipc_path) with trio.socket.socket(trio.socket.AF_UNIX, trio.socket.SOCK_STREAM) as sock: - # TODO: unclear if the following stuff is necessary: - # ################################################### - # These options help fix an issue with the socket reporting itself - # already being used since it accepts many client connection. - # https://stackoverflow.com/questions/6380057/python-binding-socket-address-already-in-use - # sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - # ################################################### await sock.bind(str(ipc_path)) + # Allow up to 10 pending connections. sock.listen(10) @@ -220,7 +222,11 @@ async def _handle_connection(self, socket: trio.socket.SocketType) -> None: with socket: while True: - request = await read_json(socket, buffer) + try: + request = await read_json(socket, buffer) + except DecodingError: + # If the connection receives bad JSON, close the connection. 
+ return if not isinstance(request, collections.abc.Mapping): logger.debug("Invalid payload: %s", type(request)) @@ -248,6 +254,6 @@ async def _handle_connection(self, socket: trio.socket.SocketType) -> None: result += NEW_LINE try: - await socket.send(result.encode()) + await socket.send(result.encode("utf8")) except BrokenPipeError: break diff --git a/ddht/rpc_handlers.py b/ddht/rpc_handlers.py index f8ae9b7c..4cfd041d 100644 --- a/ddht/rpc_handlers.py +++ b/ddht/rpc_handlers.py @@ -46,7 +46,7 @@ def _bucket_stats(self) -> Iterator[Tuple[int, BucketInfo]]: self._routing_table.buckets, self._routing_table.replacement_caches, ) for idx, (bucket, replacement_cache) in enumerate( - buckets_and_replacement_caches, 1 + buckets_and_replacement_caches, start=1 ): if bucket: yield ( diff --git a/tests/core/test_core_rpc_handlers.py b/tests/core/test_core_rpc_handlers.py index ee351450..6c115abf 100644 --- a/tests/core/test_core_rpc_handlers.py +++ b/tests/core/test_core_rpc_handlers.py @@ -5,13 +5,14 @@ from async_service import background_trio_service from eth_enr.tools.factories import ENRFactory from eth_utils import decode_hex +from eth_utils.toolz import take import pytest import trio from web3 import IPCProvider, Web3 from ddht.constants import ROUTING_TABLE_BUCKET_SIZE from ddht.kademlia import KademliaRoutingTable -from ddht.rpc import RPCRequest, RPCServer, read_json +from ddht.rpc import MAXIMUM_RPC_PAYLOAD_SIZE, RPCRequest, RPCServer, read_json from ddht.rpc_handlers import get_core_rpc_handlers from ddht.tools.factories.node_id import NodeIDFactory from ddht.tools.web3 import DiscoveryV5Module @@ -40,33 +41,52 @@ def w3(rpc_server, ipc_path): return Web3(IPCProvider(ipc_path), modules={"discv5": (DiscoveryV5Module,)}) -@pytest.fixture(name="make_request") -async def _make_request(ipc_path, rpc_server): +@pytest.fixture(name="make_raw_request") +async def _make_raw_request(ipc_path, rpc_server): socket = await trio.open_unix_socket(str(ipc_path)) async with 
socket: buffer = io.StringIO() - id_counter = itertools.count() - - async def make_request(method, params=None): - if params is None: - params = [] - request = RPCRequest( - jsonrpc="2.0", method=method, params=params, id=next(id_counter), - ) - raw_request = json.dumps(request) + async def make_raw_request(raw_request: str): with trio.fail_after(2): - await socket.send_all(raw_request.encode("utf8")) - raw_response = await read_json(socket.socket, buffer) + data = raw_request.encode("utf8") + data_iter = iter(data) + while True: + chunk = bytes(take(1024, data_iter)) + if chunk: + try: + await socket.send_all(chunk) + except trio.BrokenResourceError: + break + else: + break + return await read_json(socket.socket, buffer) + + yield make_raw_request + + +@pytest.fixture(name="make_request") +async def _make_request(make_raw_request): + id_counter = itertools.count() + + async def make_request(method, params=None): + if params is None: + params = [] + request = RPCRequest( + jsonrpc="2.0", method=method, params=params, id=next(id_counter), + ) + raw_request = json.dumps(request) + + raw_response = await make_raw_request(raw_request) - if "error" in raw_response: - raise Exception(raw_response) - elif "result" in raw_response: - return raw_response["result"] - else: - raise Exception("Invariant") + if "error" in raw_response: + raise Exception(raw_response) + elif "result" in raw_response: + return raw_response["result"] + else: + raise Exception("Invariant") - yield make_request + yield make_request @pytest.mark.trio @@ -76,6 +96,37 @@ async def test_rpc_nodeInfo(make_request, enr): assert node_info["enr"] == repr(enr) +@pytest.mark.parametrize( + "raw_request", ("just-a-raw-string",), +) +@pytest.mark.trio +async def test_rpc_closes_connection_on_bad_data(make_raw_request, raw_request): + response = await make_raw_request(raw_request) + assert "error" in response + + with pytest.raises(ConnectionResetError): + try: + await make_raw_request("should-not-work") + 
except trio.TooSlowError as err: + raise ConnectionResetError(str(err)) + + +@pytest.mark.parametrize( + "raw_request", ("just-a-raw-string",), +) +@pytest.mark.trio +async def test_rpc_closes_connection_on_too_large_data(make_raw_request, raw_request): + too_long_string = "too-long-string:" + "0" * MAXIMUM_RPC_PAYLOAD_SIZE + response = await make_raw_request(json.dumps({"key": too_long_string})) + assert "error" in response + + with pytest.raises(ConnectionResetError): + try: + await make_raw_request("should-not-work") + except trio.TooSlowError as err: + raise ConnectionResetError(str(err)) + + @pytest.mark.trio async def test_rpc_nodeInfo_web3(w3, enr, rpc_server): with trio.fail_after(2):