diff --git a/zigpy_znp/api.py b/zigpy_znp/api.py index 6f95a310..b8e34bd2 100644 --- a/zigpy_znp/api.py +++ b/zigpy_znp/api.py @@ -31,7 +31,11 @@ CallbackResponseListener, ) from zigpy_znp.frames import GeneralFrame -from zigpy_znp.exceptions import CommandNotRecognized, InvalidCommandResponse +from zigpy_znp.exceptions import ( + InvalidFrame, + CommandNotRecognized, + InvalidCommandResponse, +) from zigpy_znp.types.nvids import ExNvIds, OsalNvIds if typing.TYPE_CHECKING: @@ -715,7 +719,7 @@ async def connect(self, *, test_port=True) -> None: self.close() raise - LOGGER.debug("Connected to %s", self._uart.url) + LOGGER.debug("Connected to %s", self._uart.get_url()) def connection_made(self) -> None: """ @@ -792,7 +796,7 @@ def remove_listener(self, listener: BaseResponseListener) -> None: counts[OneShotResponseListener], ) - def frame_received(self, frame: GeneralFrame) -> bool | None: + def frame_received(self, frame: GeneralFrame) -> None: """ Called when a frame has been received. Returns whether or not the frame was handled by any listener. @@ -802,7 +806,7 @@ def frame_received(self, frame: GeneralFrame) -> bool | None: if frame.header not in c.COMMANDS_BY_ID: LOGGER.error("Received an unknown frame: %s", frame) - return False + raise InvalidFrame("Invalid command id") command_cls = c.COMMANDS_BY_ID[frame.header] @@ -813,7 +817,9 @@ def frame_received(self, frame: GeneralFrame) -> bool | None: # https://github.com/home-assistant/core/issues/50005 if command_cls == c.ZDO.ParentAnnceRsp.Callback: LOGGER.warning("Failed to parse broken %s as %s", frame, command_cls) - return False + raise InvalidFrame( + "Parsing frame %s ad command %s failed", frame, command_cls + ) raise @@ -844,8 +850,6 @@ def frame_received(self, frame: GeneralFrame) -> bool | None: if not matched: self._unhandled_command(command) - return matched - def _unhandled_command(self, command: t.CommandBase): """ Called when a command that is not handled by any listener is received. diff --git a/zigpy_znp/thread.py b/zigpy_znp/thread.py new file mode 100644 index 00000000..44b6fe29 --- /dev/null +++ b/zigpy_znp/thread.py @@ -0,0 +1,122 @@ +import asyncio +from concurrent.futures import ThreadPoolExecutor +import functools +import logging +import sys + +LOGGER = logging.getLogger(__name__) + + +class EventLoopThread: + """Run a parallel event loop in a separate thread.""" + + def __init__(self): + self.loop = None + self.thread_complete = None + + def run_coroutine_threadsafe(self, coroutine): + current_loop = asyncio.get_event_loop() + future = asyncio.run_coroutine_threadsafe(coroutine, self.loop) + return asyncio.wrap_future(future, loop=current_loop) + + def _thread_main(self, init_task): + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + try: + self.loop.run_until_complete(init_task) + self.loop.run_forever() + finally: + self.loop.close() + self.loop = None + + async def start(self): + current_loop = asyncio.get_event_loop() + if self.loop is not None and not self.loop.is_closed(): + return + + executor_opts = {"max_workers": 1} + if sys.version_info[:2] >= (3, 6): + executor_opts["thread_name_prefix"] = __name__ + executor = ThreadPoolExecutor(**executor_opts) + + thread_started_future = current_loop.create_future() + + async def init_task(): + current_loop.call_soon_threadsafe(thread_started_future.set_result, None) + + # Use current loop so current loop has a reference to the long-running thread + # as one of its tasks + thread_complete = current_loop.run_in_executor( + executor, self._thread_main, init_task() + ) + self.thread_complete = thread_complete + current_loop.call_soon(executor.shutdown, False) + await thread_started_future + return thread_complete + + def force_stop(self): + if self.loop is None: + return + + def cancel_tasks_and_stop_loop(): + tasks = asyncio.all_tasks(loop=self.loop) + + for task in tasks: + self.loop.call_soon_threadsafe(task.cancel) + + gather = asyncio.gather(*tasks, return_exceptions=True) + gather.add_done_callback( + lambda _: self.loop.call_soon_threadsafe(self.loop.stop) + ) + + self.loop.call_soon_threadsafe(cancel_tasks_and_stop_loop) + + +class ThreadsafeProxy: + """Proxy class which enforces threadsafe non-blocking calls + This class can be used to wrap an object to ensure any calls + using that object's methods are done on a particular event loop + """ + + def __init__(self, obj, obj_loop): + self._obj = obj + self._obj_loop = obj_loop + + def __getattr__(self, name): + func = getattr(self._obj, name) + if not callable(func): + raise TypeError( + "Can only use ThreadsafeProxy with callable attributes: {}.{}".format( + self._obj.__class__.__name__, name + ) + ) + + def func_wrapper(*args, **kwargs): + loop = self._obj_loop + curr_loop = asyncio.get_event_loop() + call = functools.partial(func, *args, **kwargs) + if loop == curr_loop: + return call() + if loop.is_closed(): + # Disconnected + LOGGER.warning("Attempted to use a closed event loop") + return + if asyncio.iscoroutinefunction(func): + future = asyncio.run_coroutine_threadsafe(call(), loop) + return asyncio.wrap_future(future, loop=curr_loop) + else: + + def check_result_wrapper(): + result = call() + if result is not None: + raise TypeError( + ( + "ThreadsafeProxy can only wrap functions with no return" + "value \nUse an async method to return values: {}.{}" + ).format(self._obj.__class__.__name__, name) + ) + + loop.call_soon_threadsafe(check_result_wrapper) + + return func_wrapper \ No newline at end of file diff --git a/zigpy_znp/uart.py b/zigpy_znp/uart.py index ec81a070..36f0cb74 100644 --- a/zigpy_znp/uart.py +++ b/zigpy_znp/uart.py @@ -10,6 +10,7 @@ import zigpy_znp.frames as frames import zigpy_znp.logger as log from zigpy_znp.types import Bytes +from zigpy_znp.thread import EventLoopThread, ThreadsafeProxy from zigpy_znp.exceptions import InvalidFrame LOGGER = logging.getLogger(__name__) @@ -25,6 +26,7 @@ def __init__(self, api, *, url: str | None = None) -> None: self._api = api self._transport = None self._connected_event = asyncio.Event() + self._connection_done_event = asyncio.Event() self.url = url @@ -46,6 +48,9 @@ def connection_lost(self, exc: Exception | None) -> None: if exc is not None: LOGGER.warning("Lost connection", exc_info=exc) + if self._connection_done_event: + self._connection_done_event.set() + if self._api is not None: self._api.connection_lost(exc) @@ -157,8 +162,11 @@ def __repr__(self) -> str: f">" ) + async def get_url(self): + return self.url + -async def connect(config: conf.ConfigType, api) -> ZnpMtProtocol: +async def _connect(config: conf.ConfigType, api) -> ZnpMtProtocol: loop = asyncio.get_running_loop() port = config[conf.CONF_DEVICE_PATH] @@ -181,3 +189,25 @@ async def connect(config: conf.ConfigType, api) -> ZnpMtProtocol: LOGGER.debug("Connected to %s at %s baud", port, baudrate) return protocol + + +async def connect( + config: conf.ConfigType, api, use_thread=True +) -> ZnpMtProtocol | ThreadsafeProxy: + if use_thread: + application = ThreadsafeProxy(api, asyncio.get_event_loop()) + thread = EventLoopThread() + await thread.start() + try: + protocol = await thread.run_coroutine_threadsafe( + _connect(config, application) + ) + except Exception: + thread.force_stop() + raise + + thread_safe_protocol = ThreadsafeProxy(protocol, thread.loop) + return thread_safe_protocol + else: + protocol = await _connect(config, api) + return protocol