From 82c9605938c92da5ddf76cb749f1a6adbb4cf22a Mon Sep 17 00:00:00 2001 From: Marcelo Salhab Brogliato Date: Thu, 11 Jul 2024 12:51:27 -0500 Subject: [PATCH] feat(wallet): Add a vertex history streamer to the wallet websocket API --- hathor/builder/resources_builder.py | 3 +- hathor/websocket/factory.py | 62 +++---- hathor/websocket/protocol.py | 119 ++++++++++++- hathor/websocket/streamer.py | 254 ++++++++++++++++++++++++++++ 4 files changed, 401 insertions(+), 37 deletions(-) create mode 100644 hathor/websocket/streamer.py diff --git a/hathor/builder/resources_builder.py b/hathor/builder/resources_builder.py index 88c38ff98..4454dcf40 100644 --- a/hathor/builder/resources_builder.py +++ b/hathor/builder/resources_builder.py @@ -261,7 +261,8 @@ def create_resources(self) -> server.Site: # Websocket resource assert self.manager.tx_storage.indexes is not None - ws_factory = HathorAdminWebsocketFactory(metrics=self.manager.metrics, + ws_factory = HathorAdminWebsocketFactory(manager=self.manager, + metrics=self.manager.metrics, address_index=self.manager.tx_storage.indexes.addresses) ws_factory.start() root.putChild(b'ws', WebSocketResource(ws_factory)) diff --git a/hathor/websocket/factory.py b/hathor/websocket/factory.py index 2c7aa2d16..1d0e592ba 100644 --- a/hathor/websocket/factory.py +++ b/hathor/websocket/factory.py @@ -13,7 +13,7 @@ # limitations under the License. from collections import defaultdict, deque -from typing import Any, Optional, Union +from typing import Any, Optional from autobahn.exception import Disconnected from autobahn.twisted.websocket import WebSocketServerFactory @@ -22,11 +22,12 @@ from hathor.conf import HathorSettings from hathor.indexes import AddressIndex +from hathor.manager import HathorManager from hathor.metrics import Metrics from hathor.p2p.rate_limiter import RateLimiter from hathor.pubsub import EventArguments, HathorEvents from hathor.reactor import get_global_reactor -from hathor.util import json_dumpb, json_loadb, json_loads +from hathor.util import json_dumpb from hathor.websocket.protocol import HathorAdminWebsocketProtocol settings = HathorSettings() @@ -85,11 +86,15 @@ class HathorAdminWebsocketFactory(WebSocketServerFactory): def buildProtocol(self, addr): return self.protocol(self) - def __init__(self, metrics: Optional[Metrics] = None, address_index: Optional[AddressIndex] = None): + def __init__(self, + manager: HathorManager, + metrics: Optional[Metrics] = None, + address_index: Optional[AddressIndex] = None): """ :param metrics: If not given, a new one is created. :type metrics: :py:class:`hathor.metrics.Metrics` """ + self.manager = manager self.reactor = get_global_reactor() # Opened websocket connections so I can broadcast messages later # It contains only connections that have finished handshaking. @@ -300,44 +305,33 @@ def process_deque(self, data_type): data_type=data_type) break - def handle_message(self, connection: HathorAdminWebsocketProtocol, data: Union[bytes, str]) -> None: - """ General message handler, detects type and deletages to specific handler.""" - if isinstance(data, bytes): - message = json_loadb(data) - else: - message = json_loads(data) - # we only handle ping messages for now - if message['type'] == 'ping': - self._handle_ping(connection, message) - elif message['type'] == 'subscribe_address': - self._handle_subscribe_address(connection, message) - elif message['type'] == 'unsubscribe_address': - self._handle_unsubscribe_address(connection, message) - - def _handle_ping(self, connection: HathorAdminWebsocketProtocol, message: dict[Any, Any]) -> None: - """ Handler for ping message, should respond with a simple {"type": "pong"}""" - payload = json_dumpb({'type': 'pong'}) - connection.sendMessage(payload, False) - def _handle_subscribe_address(self, connection: HathorAdminWebsocketProtocol, message: dict[Any, Any]) -> None: """ Handler for subscription to an address, consideirs subscription limits.""" - addr: str = message['address'] + address: str = message['address'] + success, errmsg = self.subscribe_address(connection, address) + response = { + 'type': 'subscribe_address', + 'address': address, + 'success': success, + } + if not success: + response['message'] = errmsg + connection.sendMessage(json_dumpb(response), False) + + def subscribe_address(self, connection: HathorAdminWebsocketProtocol, address: str) -> tuple[bool, str]: + """Subscribe an address to send real time updates to a websocket connection.""" subs: set[str] = connection.subscribed_to if self.max_subs_addrs_conn is not None and len(subs) >= self.max_subs_addrs_conn: - payload = json_dumpb({'message': 'Reached maximum number of subscribed ' - f'addresses ({self.max_subs_addrs_conn}).', - 'type': 'subscribe_address', 'success': False}) + return False, f'Reached maximum number of subscribed addresses ({self.max_subs_addrs_conn}).' + elif self.max_subs_addrs_empty is not None and ( self.address_index and _count_empty(subs, self.address_index) >= self.max_subs_addrs_empty ): - payload = json_dumpb({'message': 'Reached maximum number of subscribed ' - f'addresses without output ({self.max_subs_addrs_empty}).', - 'type': 'subscribe_address', 'success': False}) - else: - self.address_connections[addr].add(connection) - connection.subscribed_to.add(addr) - payload = json_dumpb({'type': 'subscribe_address', 'success': True}) - connection.sendMessage(payload, False) + return False, f'Reached maximum number of subscribed empty addresses ({self.max_subs_addrs_empty}).' + + self.address_connections[address].add(connection) + connection.subscribed_to.add(address) + return True, '' def _handle_unsubscribe_address(self, connection: HathorAdminWebsocketProtocol, message: dict[Any, Any]) -> None: """ Handler for unsubscribing from an address, also removes address connection set if it ends up empty.""" diff --git a/hathor/websocket/protocol.py b/hathor/websocket/protocol.py index 5429b9506..6772335f3 100644 --- a/hathor/websocket/protocol.py +++ b/hathor/websocket/protocol.py @@ -12,11 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Any, Union from autobahn.twisted.websocket import WebSocketServerProtocol from structlog import get_logger +from hathor.util import json_dumpb, json_loadb, json_loads +from hathor.websocket.streamer import GAPLimitSearch, HistoryStreamer, ManualAddressSequencer, XPubAddressSequencer + if TYPE_CHECKING: from hathor.websocket.factory import HathorAdminWebsocketFactory @@ -30,23 +33,135 @@ class HathorAdminWebsocketProtocol(WebSocketServerProtocol): can send the data update to the clients """ + MAX_GAP_LIMIT: int = 30 + def __init__(self, factory: 'HathorAdminWebsocketFactory') -> None: self.log = logger.new() self.factory = factory self.subscribed_to: set[str] = set() + self._history_streamer: HistoryStreamer | None = None + self._manual_address_iter: ManualAddressSequencer | None = None super().__init__() def onConnect(self, request): + """Called by the websocket protocol when the connection is opened but it is still pending handshaking.""" self.log.info('connection opened, starting handshake...', request=request) def onOpen(self) -> None: + """Called by the websocket protocol when the connection is established.""" self.factory.on_client_open(self) self.log.info('connection established') def onClose(self, wasClean, code, reason): + """Called by the websocket protocol when the connection is closed.""" self.factory.on_client_close(self) self.log.info('connection closed', reason=reason) def onMessage(self, payload: Union[bytes, str], isBinary: bool) -> None: + """Called by the websocket protocol when a new message is received.""" self.log.debug('new message', payload=payload.hex() if isinstance(payload, bytes) else payload) - self.factory.handle_message(self, payload) + if isinstance(payload, bytes): + message = json_loadb(payload) + else: + message = json_loads(payload) + + _type = message.get('type') + + if _type == 'ping': + self._handle_ping(message) + elif _type == 'subscribe_address': + self.factory._handle_subscribe_address(self, message) + elif _type == 'unsubscribe_address': + self.factory._handle_unsubscribe_address(self, message) + elif _type == 'request:history:xpub': + self._open_history_xpub_streamer(message) + elif _type == 'request:history:manual': + self._open_history_manual_streamer(message) + + def _handle_ping(self, message: dict[Any, Any]) -> None: + """Handle ping message, should respond with a simple {"type": "pong"}""" + payload = json_dumpb({'type': 'pong'}) + self.sendMessage(payload, False) + + def _open_history_xpub_streamer(self, message: dict[Any, Any]) -> None: + """Handle request to stream transactions using an xpub.""" + stream_id = message['id'] + + if self._history_streamer is not None: + self.sendMessage(json_dumpb({ + 'type': 'stream:history:error', + 'id': stream_id, + 'errmsg': 'Streaming is already opened.' + })) + return + + xpub = message['xpub'] + gap_limit = message.get('gap-limit', 20) + first_index = message.get('first-index', 0) + if gap_limit > self.MAX_GAP_LIMIT: + self.sendMessage(json_dumpb({ + 'type': 'stream:history:error', + 'id': stream_id, + 'errmsg': f'GAP limit is too big. Maximum: {self.MAX_GAP_LIMIT}' + })) + return + + address_iter = XPubAddressSequencer(xpub, first_index=first_index) + search = GAPLimitSearch(self.factory.manager, address_iter, gap_limit) + self._history_streamer = HistoryStreamer(protocol=self, stream_id=stream_id, search=search) + deferred = self._history_streamer.start() + deferred.addBoth(self._streamer_callback) + + def _open_history_manual_streamer(self, message: dict[Any, Any]) -> None: + """Handle request to stream transactions using a list of addresses.""" + stream_id = message['id'] + addresses = message.get('addresses', []) + first = message.get('first', False) + last = message.get('last', False) + + if self._history_streamer is not None: + if first or self._history_streamer.stream_id != stream_id: + self.sendMessage(json_dumpb({ + 'type': 'stream:history:error', + 'id': stream_id, + 'errmsg': 'Streaming is already opened.' + })) + return + + assert self._manual_address_iter is not None + self._manual_address_iter.add_addresses(addresses, last) + return + + gap_limit = message.get('gap-limit', 20) + if gap_limit > self.MAX_GAP_LIMIT: + self.sendMessage(json_dumpb({ + 'type': 'stream:history:error', + 'id': stream_id, + 'errmsg': f'GAP limit is too big. Maximum: {self.MAX_GAP_LIMIT}' + })) + return + + if not first: + self.sendMessage(json_dumpb({ + 'type': 'stream:history:error', + 'id': stream_id, + 'errmsg': 'Streaming not found. You must send first=true in your first message.' + })) + return + + address_iter = ManualAddressSequencer() + address_iter.add_addresses(enumerate(addresses), last) + search = GAPLimitSearch(self.factory.manager, address_iter, gap_limit) + self._manual_address_iter = address_iter + self._history_streamer = HistoryStreamer(protocol=self, stream_id=stream_id, search=search) + deferred = self._history_streamer.start() + deferred.addBoth(self._streamer_callback) + + def _streamer_callback(self, success: bool) -> None: + """Callback used to identify when the streamer has ended.""" + self._history_streamer = None + self._manual_address_iter = None + + def subscribe_address(self, address: str) -> tuple[bool, str]: + """Subscribe to receive real-time messages for all vertices related to an address.""" + return self.factory.subscribe_address(self, address) diff --git a/hathor/websocket/streamer.py b/hathor/websocket/streamer.py new file mode 100644 index 000000000..622834d2a --- /dev/null +++ b/hathor/websocket/streamer.py @@ -0,0 +1,254 @@ +# Copyright 2024 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import deque +from collections.abc import AsyncIterable +from typing import TYPE_CHECKING, Any, AsyncIterator + +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IPushProducer +from twisted.internet.task import deferLater +from zope.interface import implementer + +from hathor.manager import HathorManager +from hathor.types import AddressB58 +from hathor.util import json_dumpb + +if TYPE_CHECKING: + from hathor.websocket.protocol import HathorAdminWebsocketProtocol + + +class ManualAddressSequencer(AsyncIterable[tuple[int, AddressB58]]): + """An async iterable that yields addresses from a list. More addresses + can be added while the iterator is being consumed. + """ + + def __init__(self) -> None: + self.pending_addresses: deque[tuple[int, AddressB58]] = deque() + self.await_items: Deferred | None = None + + # Flag to mark when all addresses have been received so the iterator + # can stop yielding after the pending list of addresses is empty. + self._stop = False + + def _resume_iter(self) -> None: + """Resume yield addresses.""" + if self.await_items is None: + return + if not self.await_items.called: + self.await_items.callback(None) + + def add_addresses(self, addresses: list[tuple[int, AddressB58]], last: bool) -> None: + """Add more addresses to be yielded. If `last` is true, the iterator + will stop when the pending list of items gets empty.""" + self.pending_addresses.extend(addresses) + if last: + self._stop = True + self._resume_iter() + + def __aiter__(self) -> AsyncIterator[tuple[int, AddressB58]]: + """Return an async iterator.""" + return self._async_iter() + + async def _async_iter(self) -> AsyncIterator[tuple[int, AddressB58]]: + """Internal method that implements the async iterator.""" + while True: + while self.pending_addresses: + idx, address = self.pending_addresses.popleft() + yield (idx, address) + + if self._stop: + break + + self.await_items = Deferred() + await self.await_items + + +class XPubAddressSequencer(AsyncIterable[tuple[int, AddressB58]]): + """An async iterable that yields addresses derived from an xpub. + """ + def __init__(self, xpub: str, *, first_index: int = 0) -> None: + from pycoin.networks.registry import network_for_netcode + + from hathor.wallet.hd_wallet import _register_pycoin_networks + _register_pycoin_networks() + network = network_for_netcode('htr') + + self.xpub = network.parse.bip32(xpub) + self.first_index = first_index + + def __aiter__(self) -> AsyncIterator[tuple[int, AddressB58]]: + """Return an async iterator.""" + return self._async_iter() + + async def _async_iter(self) -> AsyncIterator[tuple[int, AddressB58]]: + """Internal method that implements the async iterator.""" + idx = self.first_index + while True: + key = self.xpub.subkey(idx) + yield (idx, AddressB58(key.address())) + idx += 1 + + +class GAPLimitSearch(AsyncIterable[tuple[str, Any]]): + """An async iterable that yields addresses and vertices, stopping when the gap limit is reached. + """ + def __init__(self, manager: HathorManager, address_iter: AsyncIterable[tuple[int, AddressB58]], gap_limit: int): + self.manager = manager + self.address_iter = address_iter + self.gap_limit = gap_limit + + def __aiter__(self) -> AsyncIterator[tuple[str, Any]]: + """Return an async iterator.""" + return self._async_iter() + + async def _async_iter(self) -> AsyncIterator[tuple[str, Any]]: + """Internal method that implements the async iterator.""" + assert self.manager.tx_storage.indexes is not None + assert self.manager.tx_storage.indexes.addresses is not None + addresses_index = self.manager.tx_storage.indexes.addresses + empty_addresses_counter = 0 + async for address_idx, address in self.address_iter: + yield ('address', (address_idx, address)) + + vertex_counter = 0 + for vertex_id in addresses_index.get_sorted_from_address(address): + tx = self.manager.tx_storage.get_transaction(vertex_id) + yield ('vertex', tx.to_json_extended()) + vertex_counter += 1 + + if vertex_counter == 0: + empty_addresses_counter += 1 + if empty_addresses_counter >= self.gap_limit: + break + else: + empty_addresses_counter = 0 + + +@implementer(IPushProducer) +class HistoryStreamer: + """A producer that pushes addresses and transactions to a websocket connection. + Each pushed address is automatically subscribed for real-time updates. + """ + def __init__(self, + *, + protocol: 'HathorAdminWebsocketProtocol', + stream_id: str, + search: GAPLimitSearch) -> None: + self.protocol = protocol + self.stream_id = stream_id + self.search_iter = aiter(search) + + self.reactor = self.protocol.factory.manager.reactor + + self.max_seconds_locking_event_loop = 1 + + self._paused = False + self._stop = False + + def start(self) -> Deferred[bool]: + """Start streaming items.""" + self.send_message({'type': 'stream:history:begin', 'id': self.stream_id}) + + # The websocket connection somehow instantiates an twisted.web.http.HTTPChannel object + # which register a producer. It seems the HTTPChannel is not used anymore after switching + # to websocket but it keep registered. So we have to unregister before registering a new + # producer. + self.protocol.unregisterProducer() + + self.protocol.registerProducer(self, True) + self.deferred: Deferred[bool] = Deferred() + self.resumeProducing() + return self.deferred + + def stop(self, success: bool) -> None: + """Stop streaming items.""" + self.protocol.unregisterProducer() + self.deferred.callback(success) + + def pauseProducing(self) -> None: + """Pause streaming. Called by twisted.""" + self._paused = True + + def stopProducing(self) -> None: + """Stop streaming. Called by twisted.""" + self._stop = True + self.stop(False) + + def resumeProducing(self) -> None: + """Resume streaming. Called by twisted.""" + self._paused = False + self._run() + + def _run(self) -> None: + """Run the streaming main loop.""" + coro = self._async_run() + Deferred.fromCoroutine(coro) + + async def _async_run(self): + """Internal method that runs the streaming main loop.""" + t0 = self.reactor.seconds() + + async for _type, data in self.search_iter: + if _type == 'address': + address_index, address = data + subscribed, errmsg = self.protocol.subscribe_address(address) + + if not subscribed: + self.send_message({ + 'type': 'stream:history:error', + 'id': self.stream_id, + 'errmsg': f'Address subscription failed: {errmsg}' + }) + self.stop(False) + return + + self.send_message({ + 'type': 'stream:history:address', + 'id': self.stream_id, + 'index': address_index, + 'address': address, + 'subscribed': subscribed, + }) + + elif _type == 'vertex': + self.send_message({ + 'type': 'stream:history:vertex', + 'id': self.stream_id, + 'data': data, + }) + + else: + assert False + + dt = self.reactor.seconds() - t0 + if dt > self.max_seconds_locking_event_loop: + # Let the event loop run at least once. + await deferLater(self.reactor, 0, lambda: None) + t0 = self.reactor.seconds() + + # The methods `pauseProducing()` and `stopProducing()` might be called during the + # call to `self.protocol.sendMessage()`. So both `_paused` and `_stop` might change + # during the loop. + if self._paused or self._stop: + break + + else: + self.send_message({'type': 'stream:history:end', 'id': self.stream_id}) + self.stop(True) + + def send_message(self, message: dict) -> None: + """Send a message to the websocket connection.""" + payload = json_dumpb(message) + self.protocol.sendMessage(payload)