Skip to content

Commit

Permalink
fix(ws): Add a graceful close mechanism to handle late messages and p…
Browse files Browse the repository at this point in the history
…revent errors
  • Loading branch information
msbrogli committed Sep 6, 2024
1 parent f3e3ab5 commit ec2e335
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 50 deletions.
2 changes: 2 additions & 0 deletions hathor/websocket/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ class StreamErrorMessage(StreamBase):
class StreamBeginMessage(StreamBase):
# Message announcing the beginning of a history stream. The streamer sends it from
# `start()` before any address or vertex message.
# Constant discriminator identifying this message kind.
type: str = Field('stream:history:begin', const=True)
# Identifier of the stream this message belongs to.
id: str
# Sequence number of this message; the streamer fills it from `get_next_seq()`.
seq: int
# Sliding window size used by the streamer for flow control; optional.
window_size: Optional[int]


class StreamEndMessage(StreamBase):
# Message announcing the end of a history stream. The streamer sends it from
# `gracefully_close()` once its iterator is exhausted.
# Constant discriminator identifying this message kind.
type: str = Field('stream:history:end', const=True)
# Identifier of the stream this message belongs to.
id: str
# Sequence number of this message; the streamer fills it from `get_next_seq()`.
seq: int


class StreamVertexMessage(StreamBase):
Expand Down
1 change: 1 addition & 0 deletions hathor/websocket/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def fail_if_history_streaming_is_disabled(self) -> bool:

def _create_streamer(self, stream_id: str, search: AddressSearch, window_size: int | None) -> None:
"""Create the streamer and handle its callbacks."""
assert self._history_streamer is None
self._history_streamer = HistoryStreamer(protocol=self, stream_id=stream_id, search=search)
if window_size is not None:
if window_size < 0:
Expand Down
176 changes: 128 additions & 48 deletions hathor/websocket/streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum, auto
from typing import TYPE_CHECKING, Optional

from twisted.internet.defer import Deferred
Expand All @@ -33,6 +34,27 @@
from hathor.websocket.protocol import HathorAdminWebsocketProtocol


class StreamerState(Enum):
    """Lifecycle states of a history streamer.

    The moves allowed between states are listed in `VALID_TRANSITIONS`.
    """
    NOT_STARTED = auto()
    ACTIVE = auto()
    PAUSED = auto()
    CLOSING = auto()
    CLOSED = auto()

    def can_transition_to(self, destination: 'StreamerState') -> bool:
        """Return True when moving from this state to `destination` is allowed."""
        allowed_destinations = VALID_TRANSITIONS[self]
        return destination in allowed_destinations


# Maps every state to the set of states it may move to directly.
# CLOSED is terminal: its set of destinations is empty.
# NOTE(review): no state lists itself as a destination, so code that gates on
# `can_transition_to(StreamerState.ACTIVE)` is refused while already ACTIVE --
# confirm this is intended for callers that resume an already-active streamer.
VALID_TRANSITIONS = {
    StreamerState.NOT_STARTED: {StreamerState.ACTIVE},
    StreamerState.ACTIVE: {StreamerState.PAUSED, StreamerState.CLOSING, StreamerState.CLOSED},
    StreamerState.PAUSED: {StreamerState.ACTIVE, StreamerState.CLOSING, StreamerState.CLOSED},
    StreamerState.CLOSING: {StreamerState.CLOSED},
    StreamerState.CLOSED: set()
}


@implementer(IPushProducer)
class HistoryStreamer:
"""A producer that pushes addresses and transactions to a websocket connection.
Expand Down Expand Up @@ -72,23 +94,34 @@ def __init__(self,

self.deferred: Deferred[bool] = Deferred()

# Statistics.
# Statistics
# ----------
self.stats_log_interval = self.STATS_LOG_INTERVAL
self.stats_total_messages: int = 0
self.stats_sent_addresses: int = 0
self.stats_sent_vertices: int = 0

# Execution control.
self._started = False
self._is_running = False
self._paused = False
self._stop = False
# Execution control
# -----------------
self._state = StreamerState.NOT_STARTED
# Used to mark that the streamer is currently running its main loop and sending messages.
self._is_main_loop_running = False
# Used to mark that the streamer was paused by the transport layer.
self._is_paused_by_transport = False

# Flow control.
# Flow control
# ------------
self._next_sequence_number: int = 0
self._last_ack: int = -1
self._sliding_window_size: Optional[int] = self.DEFAULT_SLIDING_WINDOW_SIZE

def get_next_seq(self) -> int:
    """Return the next sequence number and advance the internal counter.

    Asserts that the streamer is neither CLOSING nor CLOSED.
    """
    assert self._state is not StreamerState.CLOSING
    assert self._state is not StreamerState.CLOSED
    current = self._next_sequence_number
    self._next_sequence_number = current + 1
    return current

def set_sliding_window_size(self, size: Optional[int]) -> None:
"""Set a new sliding window size for flow control. If size is none, disables flow control.
"""
Expand All @@ -102,87 +135,137 @@ def set_ack(self, ack: int) -> None:
If the new value is bigger than the previous value, the streaming might be resumed.
"""
if ack <= self._last_ack:
if ack == self._last_ack:
# We might receive outdated or duplicate ACKs, and we can safely ignore them.
return
if ack < self._last_ack:
# ACK got smaller. Something is wrong...
self.send_message(StreamErrorMessage(
id=self.stream_id,
errmsg=f'Outdated ACK received. Skipping it... (ack={ack})'
errmsg=f'Outdated ACK received (ack={ack})'
))
self.stop(False)
return
if ack >= self._next_sequence_number:
# ACK is higher than the last message sent. Something is wrong...
self.send_message(StreamErrorMessage(
id=self.stream_id,
errmsg=f'Received ACK is higher than the last sent message. Skipping it... (ack={ack})'
errmsg=f'Received ACK is higher than the last sent message (ack={ack})'
))
self.stop(False)
return
self._last_ack = ack
self.resume_if_possible()
# While CLOSING, the last message sent is the StreamEndMessage, so an ACK for it
# completes the stream successfully. In any other state a new ACK may have freed
# room in the sliding window, so try to resume sending instead.
# (The original tested `is not StreamerState.CLOSING`, which left a CLOSING stream
# open forever and stopped a live stream once everything sent so far was ACKed.)
if self._state is StreamerState.CLOSING:
    closing_ack = self._next_sequence_number - 1
    if ack == closing_ack:
        self.stop(True)
else:
    self.resume_if_possible()

def resume_if_possible(self) -> None:
if not self._started:
"""Resume sending messages if possible."""
if not self._state.can_transition_to(StreamerState.ACTIVE):
return
if self._is_main_loop_running:
return
if self._is_paused_by_transport:
return
if not self.should_pause_streaming() and not self._is_running:
self.resumeProducing()
if self.should_pause_streaming():
return
self._run()

def set_state(self, new_state: StreamerState) -> None:
    """Move the streamer to `new_state`.

    Setting the current state again is a no-op. Any other change must be accepted
    by the current state's `can_transition_to()`; this is enforced with an assert.
    """
    current = self._state
    if current != new_state:
        assert current.can_transition_to(new_state)
        self._state = new_state

def start(self) -> Deferred[bool]:
"""Start streaming items."""
assert self._state is StreamerState.NOT_STARTED

# The websocket connection somehow instantiates a twisted.web.http.HTTPChannel object
# which registers a producer. It seems the HTTPChannel is not used anymore after switching
# to websocket, but it stays registered. So we have to unregister it before registering a
# new producer.
if self.protocol.transport.producer:
self.protocol.unregisterProducer()

self.protocol.registerProducer(self, True)

assert not self._started
self._started = True
self.send_message(StreamBeginMessage(id=self.stream_id, window_size=self._sliding_window_size))
self.resumeProducing()
self.send_message(StreamBeginMessage(
id=self.stream_id,
seq=self.get_next_seq(),
window_size=self._sliding_window_size,
))
self.resume_if_possible()
return self.deferred

def stop(self, success: bool) -> None:
"""Stop streaming items."""
assert self._started
self._stop = True
self._started = False
if not self._state.can_transition_to(StreamerState.CLOSED):
# Do nothing if the streamer has already been stopped.
self.protocol.log.warn('stop called in an unexpected state', state=self._state)
return
self.set_state(StreamerState.CLOSED)
self.protocol.unregisterProducer()
self.deferred.callback(success)

def gracefully_close(self) -> None:
    """Begin a graceful shutdown of the stream.

    Sends the StreamEndMessage and moves to CLOSING so the stream can wait for
    that message's ack. Does nothing when the current state cannot move to CLOSING.
    """
    if self._state.can_transition_to(StreamerState.CLOSING):
        # The sequence number must be taken before entering CLOSING, because
        # `get_next_seq()` asserts that the streamer is not CLOSING.
        end_message = StreamEndMessage(id=self.stream_id, seq=self.get_next_seq())
        self.send_message(end_message)
        self.set_state(StreamerState.CLOSING)

def pauseProducing(self) -> None:
"""Pause streaming. Called by twisted."""
self._paused = True
if not self._state.can_transition_to(StreamerState.PAUSED):
self.protocol.log.warn('pause requested in an unexpected state', state=self._state)
return
self.set_state(StreamerState.PAUSED)
self._is_paused_by_transport = True

def stopProducing(self) -> None:
"""Stop streaming. Called by twisted."""
self._stop = True
if not self._state.can_transition_to(StreamerState.CLOSED):
self.protocol.log.warn('stopped requested in an unexpected state', state=self._state)
return
self.stop(False)

def resumeProducing(self) -> None:
"""Resume streaming. Called by twisted."""
self._paused = False
self._run()

def _run(self) -> None:
"""Run the streaming main loop."""
coro = self._async_run()
Deferred.fromCoroutine(coro)
if not self._state.can_transition_to(StreamerState.ACTIVE):
self.protocol.log.warn('resume requested in an unexpected state', state=self._state)
return
self._is_paused_by_transport = False
self.resume_if_possible()

def should_pause_streaming(self) -> bool:
    """Tell whether flow control requires the streamer to stop sending for now.

    Returns False when no sliding window is configured. Otherwise returns True
    once the next sequence number reaches `last_ack + window_size + 1`.
    """
    window = self._sliding_window_size
    if window is None:
        return False
    return self._next_sequence_number >= self._last_ack + window + 1

def _run(self) -> None:
    """Launch the streaming main loop.

    Logs a warning and does nothing when the current state cannot move to ACTIVE.
    """
    if self._state.can_transition_to(StreamerState.ACTIVE):
        # Wrap the coroutine in a Deferred to start it; the Deferred itself is not kept.
        Deferred.fromCoroutine(self._async_run())
        return
    self.protocol.log.warn('_run() called in an unexpected state', state=self._state)

async def _async_run(self):
assert not self._is_running
self._is_running = True
assert not self._is_main_loop_running
self.set_state(StreamerState.ACTIVE)
self._is_main_loop_running = True
try:
await self._async_run_unsafe()
finally:
self._is_running = False
self._is_main_loop_running = False

async def _async_run_unsafe(self):
"""Internal method that runs the streaming main loop."""
Expand All @@ -204,7 +287,7 @@ async def _async_run_unsafe(self):
self.stats_sent_addresses += 1
self.send_message(StreamAddressMessage(
id=self.stream_id,
seq=self._next_sequence_number,
seq=self.get_next_seq(),
index=item.index,
address=item.address,
subscribed=subscribed,
Expand All @@ -214,42 +297,39 @@ async def _async_run_unsafe(self):
self.stats_sent_vertices += 1
self.send_message(StreamVertexMessage(
id=self.stream_id,
seq=self._next_sequence_number,
seq=self.get_next_seq(),
data=item.vertex.to_json_extended(),
))

case _:
assert False

self._next_sequence_number += 1
if self.should_pause_streaming():
break

# The methods `pauseProducing()` and `stopProducing()` might be called during the
# call to `self.protocol.sendMessage()`. So both `_paused` and `_stop` might change
# during the loop.
if self._paused or self._stop:
break

self.stats_total_messages += 1
if self.stats_total_messages % self.stats_log_interval == 0:
self.protocol.log.info('websocket streaming statistics',
total_messages=self.stats_total_messages,
sent_vertices=self.stats_sent_vertices,
sent_addresses=self.stats_sent_addresses)

# The methods `pauseProducing()` and `stopProducing()` might be called during the
# call to `self.protocol.sendMessage()`. So the streamer state might change during
# the loop.
if self._state is not StreamerState.ACTIVE:
break

# Limit blocking of the event loop to a maximum of N seconds.
dt = self.reactor.seconds() - t0
if dt > self.max_seconds_locking_event_loop:
# Let the event loop run at least once.
await deferLater(self.reactor, 0, lambda: None)
t0 = self.reactor.seconds()

else:
if self._stop:
# If the streamer has been stopped, there is nothing else to do.
return
self.send_message(StreamEndMessage(id=self.stream_id))
self.stop(True)
# Iterator is empty so we can close the stream.
self.gracefully_close()

def send_message(self, message: StreamBase) -> None:
"""Send a message to the websocket connection."""
Expand Down
11 changes: 9 additions & 2 deletions tests/websocket/test_streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from hathor.wallet import HDWallet
from hathor.websocket.factory import HathorAdminWebsocketFactory
from hathor.websocket.iterators import AddressItem, ManualAddressSequencer, gap_limit_search
from hathor.websocket.streamer import HistoryStreamer
from hathor.websocket.streamer import HistoryStreamer, StreamerState
from tests.unittest import TestCase
from tests.utils import GENESIS_ADDRESS_B58

Expand Down Expand Up @@ -60,7 +60,7 @@ def test_streamer(self) -> None:
'data': genesis.to_json_extended(),
})
expected_result.append({'type': 'stream:history:end', 'id': stream_id})
for index, item in enumerate(expected_result[1:-1]):
for index, item in enumerate(expected_result):
item['seq'] = index

# Create both the address iterator and the GAP limit searcher.
Expand All @@ -86,6 +86,13 @@ def test_streamer(self) -> None:
# Run the streamer.
manager.reactor.advance(10)

# Check the streamer is waiting for the last ACK.
# assertEqual is required here: assertTrue(value, msg) only checks the truthiness of
# its first argument and uses the second as the failure message, so it would pass
# for any StreamerState.
self.assertEqual(streamer._state, StreamerState.CLOSING)
# An ACK that is not for the last message must keep the stream waiting.
streamer.set_ack(1)
self.assertEqual(streamer._state, StreamerState.CLOSING)
# The ACK for the last message (the end message) closes the stream.
streamer.set_ack(len(expected_result) - 1)
self.assertEqual(streamer._state, StreamerState.CLOSED)

# Check the results.
items_iter = self._parse_ws_raw(transport.value())
result = list(items_iter)
Expand Down

0 comments on commit ec2e335

Please sign in to comment.