Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent task cancellation from propagating to ASH #628

Merged
merged 3 commits into from
Jun 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 46 additions & 14 deletions bellows/ash.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import abc
import asyncio
import binascii
from collections.abc import Coroutine
import dataclasses
import enum
import logging
import sys
import time
import typing

if sys.version_info[:2] < (3, 11):
from async_timeout import timeout as asyncio_timeout # pragma: no cover
Expand Down Expand Up @@ -55,7 +57,7 @@ class Reserved(enum.IntEnum):

# Maximum number of DATA frames the NCP can transmit without having received
# acknowledgements
TX_K = 1
TX_K = 1 # TODO: investigate why this cannot be raised without causing a firmware crash

# Maximum number of consecutive timeouts allowed while waiting to receive an ACK before
# going to the FAILED state. The value 0 prevents the NCP from entering the error state
Expand All @@ -81,6 +83,23 @@ def generate_random_sequence(length: int) -> bytes:
# Since the sequence is static for every frame, we only need to generate it once
PSEUDO_RANDOM_DATA_SEQUENCE = generate_random_sequence(256)

# Eager task execution (``eager_start``) is only available in ``asyncio.Task`` on
# Python 3.12+. Older interpreters fall back to the regular ``asyncio.create_task``.
if sys.version_info[:2] < (3, 12):
    # NOTE: the fallback is the stdlib function itself, so it does not accept the
    # ``loop`` keyword that the 3.12+ implementation below offers.
    create_eager_task = asyncio.create_task
else:
    # The TypeVar's declared name must match the variable it is bound to
    _T = typing.TypeVar("_T")

    def create_eager_task(
        coro: Coroutine[typing.Any, typing.Any, _T],
        *,
        name: str | None = None,
        loop: asyncio.AbstractEventLoop | None = None,
    ) -> asyncio.Task[_T]:
        """Create a task from a coroutine and schedule it to run immediately.

        :param coro: The coroutine to wrap in a task.
        :param name: Optional name passed through to the task.
        :param loop: Event loop to bind the task to. Defaults to the currently
            running loop.
        :return: The newly created task, constructed with ``eager_start=True``.
        :raises RuntimeError: If ``loop`` is not provided and no event loop is
            running (raised by ``asyncio.get_running_loop``).
        """
        if loop is None:
            loop = asyncio.get_running_loop()

        return asyncio.Task(coro, loop=loop, name=name, eager_start=True)


class NcpState(enum.Enum):
CONNECTED = "connected"
Expand Down Expand Up @@ -463,15 +482,14 @@ def data_received(self, data: bytes) -> None:
def _handle_ack(self, frame: DataFrame | AckFrame) -> None:
# Note that ackNum is the number of the next frame the receiver expects and it
# is one greater than the last frame received.
ack_num = (frame.ack_num - 1) % 8
for ack_num_offset in range(-TX_K, 0):
ack_num = (frame.ack_num + ack_num_offset) % 8
fut = self._pending_data_frames.get(ack_num)

fut = self._pending_data_frames.get(ack_num)
if fut is None or fut.done():
continue

if fut is None or fut.done():
return

# _LOGGER.debug("Resolving frame %d", ack_num)
self._pending_data_frames[ack_num].set_result(True)
self._pending_data_frames[ack_num].set_result(True)

def frame_received(self, frame: AshFrame) -> None:
_LOGGER.debug("Received frame %r", frame)
Expand Down Expand Up @@ -537,13 +555,16 @@ def error_frame_received(self, frame: ErrorFrame) -> None:
self._ncp_state = NcpState.FAILED

# Cancel all pending requests
exc = NcpFailure(code=self._ncp_reset_code)
self._enter_failed_state(self._ncp_reset_code)

def _enter_failed_state(self, reset_code: t.NcpResetCode) -> None:
exc = NcpFailure(code=reset_code)

for fut in self._pending_data_frames.values():
if not fut.done():
fut.set_exception(exc)

self._ezsp_protocol.reset_received(frame.reset_code)
self._ezsp_protocol.reset_received(reset_code)

def _write_frame(
self,
Expand Down Expand Up @@ -582,7 +603,7 @@ async def _send_data_frame(self, frame: AshFrame) -> None:
for attempt in range(ACK_TIMEOUTS):
if self._ncp_state == NcpState.FAILED:
_LOGGER.debug(
"NCP is in a failed state, not re-sending: %r", frame
"NCP is in a failed state, not sending: %r", frame
)
raise NcpFailure(
t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT
Expand Down Expand Up @@ -618,6 +639,9 @@ async def _send_data_frame(self, frame: AshFrame) -> None:
self._change_ack_timeout((7 / 8) * self._t_rx_ack + 0.5 * delta)

if attempt >= ACK_TIMEOUTS - 1:
self._enter_failed_state(
t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT
)
raise
except NcpFailure:
_LOGGER.debug(
Expand All @@ -635,6 +659,9 @@ async def _send_data_frame(self, frame: AshFrame) -> None:
self._change_ack_timeout(2 * self._t_rx_ack)

if attempt >= ACK_TIMEOUTS - 1:
self._enter_failed_state(
t.NcpResetCode.ERROR_EXCEEDED_MAXIMUM_ACK_TIMEOUT_COUNT
)
raise
else:
# Whenever an acknowledgement is received, t_rx_ack is set to
Expand All @@ -649,9 +676,14 @@ async def _send_data_frame(self, frame: AshFrame) -> None:
self._pending_data_frames.pop(frm_num)

async def send_data(self, data: bytes) -> None:
await self._send_data_frame(
# All of the other fields will be set during transmission/retries
DataFrame(frm_num=None, re_tx=None, ack_num=None, ezsp_frame=data)
# Sending data is a critical operation and cannot really be cancelled
await asyncio.shield(
create_eager_task(
self._send_data_frame(
# All of the other fields will be set during transmission/retries
DataFrame(frm_num=None, re_tx=None, ack_num=None, ezsp_frame=data)
)
)
)

def send_reset(self) -> None:
Expand Down
Loading