Skip to content

Commit 81ad1ee

Browse files
fix(client_sync): improve event loop initialization and background task management (#12)
1 parent 7f84826 commit 81ad1ee

File tree

3 files changed

+116
-95
lines changed

3 files changed

+116
-95
lines changed

src/wokwi_client/client_sync.py

Lines changed: 71 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -39,34 +39,31 @@ class WokwiClientSync:
3939
tracked, so we can cancel & drain them on `disconnect()`.
4040
"""
4141

42-
# Public attributes mirrored for convenience
43-
version: str
44-
last_pause_nanos: int # this proxy resolves via __getattr__
45-
4642
def __init__(self, token: str, server: str | None = None):
47-
# Create a fresh event loop + thread (daemon so it won't prevent process exit).
43+
# Create a new event loop for the background thread
4844
self._loop = asyncio.new_event_loop()
45+
# Event to signal that the event loop is running
46+
self._loop_started_event = threading.Event()
47+
# Start background thread running the event loop
4948
self._thread = threading.Thread(
5049
target=self._run_loop, args=(self._loop,), daemon=True, name="wokwi-sync-loop"
5150
)
5251
self._thread.start()
53-
54-
# Underlying async client
52+
# **Wait until loop is fully started before proceeding** (prevents race conditions)
53+
if not self._loop_started_event.wait(timeout=8.0): # timeout to avoid deadlock
54+
raise RuntimeError("WokwiClientSync event loop failed to start")
55+
# Initialize underlying async client on the running loop
5556
self._async_client = WokwiClient(token, server)
56-
57-
# Mirror library version for quick access
58-
self.version = self._async_client.version
59-
60-
# Track background tasks created via run_coroutine_threadsafe (serial monitors)
57+
# Track background monitor tasks (futures) for cancellation on exit
6158
self._bg_futures: set[Future[Any]] = set()
62-
63-
# Idempotent disconnect guard
59+
# Flag to avoid double-closing
6460
self._closed = False
6561

66-
@staticmethod
67-
def _run_loop(loop: asyncio.AbstractEventLoop) -> None:
68-
"""Background thread loop runner."""
62+
def _run_loop(self, loop: asyncio.AbstractEventLoop) -> None:
63+
"""Target function for the background thread: runs the asyncio event loop."""
6964
asyncio.set_event_loop(loop)
65+
# Signal that the loop is now running and ready to accept tasks
66+
loop.call_soon(self._loop_started_event.set)
7067
loop.run_forever()
7168

7269
# ----- Internal helpers -------------------------------------------------
@@ -75,8 +72,11 @@ def _submit(self, coro: Coroutine[Any, Any, T]) -> Future[T]:
7572
return asyncio.run_coroutine_threadsafe(coro, self._loop)
7673

7774
def _call(self, coro: Coroutine[Any, Any, T]) -> T:
78-
"""Submit a coroutine to the loop and block until it completes (or raises)."""
79-
return self._submit(coro).result()
75+
"""Submit a coroutine to the background loop and wait for result."""
76+
if self._closed:
77+
raise RuntimeError("Cannot call methods on a closed WokwiClientSync")
78+
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
79+
return future.result() # Block until the coroutine completes or raises
8080

8181
def _add_bg_future(self, fut: Future[Any]) -> None:
8282
"""Track a background future so we can cancel & drain on shutdown."""
@@ -96,37 +96,35 @@ def connect(self) -> dict[str, Any]:
9696
return self._call(self._async_client.connect())
9797

9898
def disconnect(self) -> None:
99-
"""Disconnect and stop the background loop.
100-
101-
Order matters:
102-
1) Cancel and drain background serial-monitor futures.
103-
2) Disconnect the underlying transport.
104-
3) Stop the loop and join the thread.
105-
Safe to call multiple times.
106-
"""
10799
if self._closed:
108100
return
109-
self._closed = True
110101

111102
# (1) Cancel + drain monitors
112103
for fut in list(self._bg_futures):
113104
fut.cancel()
114105
for fut in list(self._bg_futures):
115106
with contextlib.suppress(FutureTimeoutError, Exception):
116-
# Give each monitor a short window to handle cancellation cleanly.
117107
fut.result(timeout=1.0)
118108
self._bg_futures.discard(fut)
119109

120110
# (2) Disconnect transport
121111
with contextlib.suppress(Exception):
122-
self._call(self._async_client._transport.close())
112+
fut = asyncio.run_coroutine_threadsafe(self._async_client.disconnect(), self._loop)
113+
fut.result(timeout=2.0)
123114

124115
# (3) Stop loop / join thread
125116
if self._loop.is_running():
126117
self._loop.call_soon_threadsafe(self._loop.stop)
127118
if self._thread.is_alive():
128119
self._thread.join(timeout=5.0)
129120

121+
# (4) Close loop
122+
with contextlib.suppress(Exception):
123+
self._loop.close()
124+
125+
# (5) Mark closed at the very end
126+
self._closed = True
127+
130128
# ----- Serial monitoring ------------------------------------------------
131129
def serial_monitor(self, callback: Callable[[bytes], Any]) -> None:
132130
"""
@@ -138,17 +136,25 @@ def serial_monitor(self, callback: Callable[[bytes], Any]) -> None:
138136
"""
139137

140138
async def _runner() -> None:
141-
async for line in monitor_lines(self._async_client._transport):
142-
try:
143-
maybe_awaitable = callback(line)
144-
if inspect.isawaitable(maybe_awaitable):
145-
await maybe_awaitable
146-
except Exception:
147-
# Keep the monitor alive even if the callback throws.
148-
pass
149-
150-
fut = self._submit(_runner())
151-
self._add_bg_future(fut)
139+
try:
140+
# **Prepare to receive serial events before enabling monitor**
141+
# (monitor_lines will subscribe to serial events internally)
142+
async for line in monitor_lines(self._async_client._transport):
143+
try:
144+
result = callback(line) # invoke callback with the raw bytes line
145+
if inspect.isawaitable(result):
146+
await result # await if callback is async
147+
except Exception:
148+
# Swallow exceptions from callback to keep monitor alive
149+
pass
150+
finally:
151+
# Remove this task’s future from the set when done
152+
self._bg_futures.discard(task_future)
153+
154+
# Schedule the serial monitor runner on the event loop:
155+
task_future = asyncio.run_coroutine_threadsafe(_runner(), self._loop)
156+
self._bg_futures.add(task_future)
157+
# (No return value; monitoring happens in background)
152158

153159
def serial_monitor_cat(self, decode_utf8: bool = True, errors: str = "replace") -> None:
154160
"""
@@ -160,34 +166,32 @@ def serial_monitor_cat(self, decode_utf8: bool = True, errors: str = "replace")
160166
"""
161167

162168
async def _runner() -> None:
163-
async for line in monitor_lines(self._async_client._transport):
164-
try:
165-
if decode_utf8:
166-
try:
167-
print(line.decode("utf-8", errors=errors), end="", flush=True)
168-
except UnicodeDecodeError:
169+
try:
170+
# **Subscribe to serial events before reading output**
171+
async for line in monitor_lines(self._async_client._transport):
172+
try:
173+
if decode_utf8:
174+
# Decode bytes to string (handle errors per parameter)
175+
text = line.decode("utf-8", errors=errors)
176+
print(text, end="", flush=True)
177+
else:
178+
# Print raw bytes
169179
print(line, end="", flush=True)
170-
else:
171-
print(line, end="", flush=True)
172-
except Exception:
173-
# Keep the monitor alive even if printing raises intermittently.
174-
pass
180+
except Exception:
181+
# Swallow print errors to keep stream alive
182+
pass
183+
finally:
184+
self._bg_futures.discard(task_future)
175185

176-
fut = self._submit(_runner())
177-
self._add_bg_future(fut)
186+
task_future = asyncio.run_coroutine_threadsafe(_runner(), self._loop)
187+
self._bg_futures.add(task_future)
188+
# (No return; printing continues in background)
178189

179190
def stop_serial_monitors(self) -> None:
180-
"""
181-
Cancel and drain all running serial monitors without disconnecting.
182-
183-
Useful if you want to stop printing but keep the connection alive.
184-
"""
191+
"""Stop all active serial monitor background tasks."""
185192
for fut in list(self._bg_futures):
186193
fut.cancel()
187-
for fut in list(self._bg_futures):
188-
with contextlib.suppress(FutureTimeoutError, Exception):
189-
fut.result(timeout=1.0)
190-
self._bg_futures.discard(fut)
194+
self._bg_futures.clear()
191195

192196
# ----- Dynamic method wrapping -----------------------------------------
193197
def __getattr__(self, name: str) -> Any:
@@ -197,16 +201,17 @@ def __getattr__(self, name: str) -> Any:
197201
If the attribute on `WokwiClient` is a coroutine function, return a
198202
sync wrapper that blocks until the coroutine completes.
199203
"""
200-
# Explicit methods above (serial monitors) take precedence.
204+
# Explicit methods (like serial_monitor functions above) take precedence over __getattr__
201205
attr = getattr(self._async_client, name)
202206
if callable(attr):
207+
# Get the function object from WokwiClient class (unbound) to check if coroutine
203208
func = getattr(WokwiClient, name, None)
204209
if func is not None and inspect.iscoroutinefunction(func):
205-
210+
# Wrap coroutine method to run in background loop
206211
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
207212
return self._call(attr(*args, **kwargs))
208213

209214
sync_wrapper.__name__ = name
210-
sync_wrapper.__doc__ = func.__doc__
215+
sync_wrapper.__doc__ = getattr(func, "__doc__", "")
211216
return sync_wrapper
212217
return attr

src/wokwi_client/serial.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010

1111

1212
async def monitor_lines(transport: Transport) -> AsyncGenerator[bytes, None]:
13+
"""
14+
Monitor the serial output lines.
15+
"""
1316
await transport.request("serial-monitor:listen", {})
1417
with EventQueue(transport, "serial-monitor:data") as queue:
1518
while True:

src/wokwi_client/transport.py

Lines changed: 42 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# SPDX-License-Identifier: MIT
44

55
import asyncio
6+
import contextlib
67
import json
78
import os
89
import warnings
@@ -48,31 +49,27 @@ async def connect(self, throw_error: bool = True) -> dict[str, Any]:
4849
raise ProtocolError(f"Unsupported protocol handshake: {hello}")
4950
hello_msg = cast(HelloMessage, hello)
5051
self._closed = False
51-
# Start background message processor
52+
# Start background message processor after successful hello.
5253
self._recv_task = asyncio.create_task(self._background_recv(throw_error))
5354
return {"version": hello_msg["appVersion"]}
5455

5556
async def close(self) -> None:
5657
self._closed = True
5758
if self._recv_task:
5859
self._recv_task.cancel()
59-
try:
60+
with contextlib.suppress(asyncio.CancelledError):
6061
await self._recv_task
61-
except asyncio.CancelledError:
62-
pass
6362
if self._ws:
6463
await self._ws.close()
6564

6665
def add_event_listener(self, event_type: str, listener: Callable[[EventMessage], Any]) -> None:
67-
"""Register a listener for a specific event type."""
6866
if event_type not in self._event_listeners:
6967
self._event_listeners[event_type] = []
7068
self._event_listeners[event_type].append(listener)
7169

7270
def remove_event_listener(
7371
self, event_type: str, listener: Callable[[EventMessage], Any]
7472
) -> None:
75-
"""Remove a previously registered listener for a specific event type."""
7673
if event_type in self._event_listeners:
7774
self._event_listeners[event_type] = [
7875
registered_listener
@@ -90,52 +87,72 @@ async def _dispatch_event(self, event_msg: EventMessage) -> None:
9087
await result
9188

9289
async def request(self, command: str, params: dict[str, Any]) -> ResponseMessage:
93-
msg_id = str(self._next_id)
94-
self._next_id += 1
9590
if self._ws is None:
9691
raise WokwiError("Not connected")
92+
msg_id = str(self._next_id)
93+
self._next_id += 1
94+
9795
loop = asyncio.get_running_loop()
9896
future: asyncio.Future[ResponseMessage] = loop.create_future()
9997
self._response_futures[msg_id] = future
98+
10099
await self._ws.send(
101100
json.dumps({"type": "command", "command": command, "params": params, "id": msg_id})
102101
)
103102
try:
104103
resp_msg_resp = await future
105104
if resp_msg_resp.get("error"):
106-
result = resp_msg_resp["result"]
107-
raise ServerError(result["message"])
105+
result = resp_msg_resp.get("result", {})
106+
raise ServerError(result.get("message", "Unknown server error"))
108107
return resp_msg_resp
109108
finally:
110-
del self._response_futures[msg_id]
109+
# Remove future mapping if still present (be defensive)
110+
self._response_futures.pop(msg_id, None)
111111

112-
async def _background_recv(self, throw_error: bool = True) -> None:
112+
async def _background_recv(self, throw_error: bool = True) -> None: # noqa: PLR0912
113113
try:
114114
while not self._closed and self._ws is not None:
115115
msg: IncomingMessage = await self._recv()
116116
if msg["type"] == MSG_TYPE_EVENT:
117-
resp_msg_event = cast(EventMessage, msg)
118-
await self._dispatch_event(resp_msg_event)
117+
await self._dispatch_event(cast(EventMessage, msg))
119118
elif msg["type"] == MSG_TYPE_RESPONSE:
120119
resp_msg_resp = cast(ResponseMessage, msg)
121-
future = self._response_futures.get(resp_msg_resp["id"])
120+
resp_id = str(resp_msg_resp.get("id"))
121+
future = self._response_futures.get(resp_id)
122122
if future is None or future.done():
123123
continue
124124
future.set_result(resp_msg_resp)
125-
except (websockets.ConnectionClosed, asyncio.CancelledError):
126-
pass
125+
except asyncio.CancelledError:
126+
# Expected during shutdown via close()
127+
raise
128+
except websockets.ConnectionClosed as e:
129+
# Mark closed and fail pending futures to avoid hangs.
130+
self._closed = True
131+
for fut in list(self._response_futures.values()):
132+
if not fut.done():
133+
fut.set_exception(e)
134+
with contextlib.suppress(Exception):
135+
if self._ws:
136+
await self._ws.close()
137+
if throw_error:
138+
raise
127139
except Exception as e:
128140
warnings.warn(f"Background recv error: {e}", RuntimeWarning)
129-
130141
if throw_error:
131142
self._closed = True
132-
# Cancel all pending response futures
133-
for future in self._response_futures.values():
134-
if not future.done():
135-
future.set_exception(e)
136-
if self._ws:
137-
await self._ws.close()
143+
for fut in list(self._response_futures.values()):
144+
if not fut.done():
145+
fut.set_exception(e)
146+
with contextlib.suppress(Exception):
147+
if self._ws:
148+
await self._ws.close()
138149
raise
150+
finally:
151+
# If we’re exiting the loop and marked closed, ensure no future hangs.
152+
if self._closed:
153+
for fut in list(self._response_futures.values()):
154+
if not fut.done():
155+
fut.set_exception(RuntimeError("Transport receive loop exited"))
139156

140157
async def _recv(self) -> IncomingMessage:
141158
if self._ws is None:
@@ -153,10 +170,6 @@ async def _recv(self) -> IncomingMessage:
153170
if message["type"] == "error":
154171
raise WokwiError(f"Server error: {message['message']}")
155172
if message["type"] == "response" and message.get("error"):
156-
result = (
157-
message["result"]
158-
if "result" in message
159-
else {"code": -1, "message": "Unknown error"}
160-
)
173+
result = message.get("result", {"code": -1, "message": "Unknown error"})
161174
raise WokwiError(f"Server error {result['code']}: {result['message']}")
162175
return cast(IncomingMessage, message)

0 commit comments

Comments
 (0)