Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 71 additions & 66 deletions src/wokwi_client/client_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,34 +39,31 @@ class WokwiClientSync:
tracked, so we can cancel & drain them on `disconnect()`.
"""

# Public attributes mirrored for convenience
version: str
last_pause_nanos: int # this proxy resolves via __getattr__

def __init__(self, token: str, server: str | None = None):
# Create a fresh event loop + thread (daemon so it won't prevent process exit).
# Create a new event loop for the background thread
self._loop = asyncio.new_event_loop()
# Event to signal that the event loop is running
self._loop_started_event = threading.Event()
# Start background thread running the event loop
self._thread = threading.Thread(
target=self._run_loop, args=(self._loop,), daemon=True, name="wokwi-sync-loop"
)
self._thread.start()

# Underlying async client
# **Wait until loop is fully started before proceeding** (prevents race conditions)
if not self._loop_started_event.wait(timeout=8.0): # timeout to avoid deadlock
raise RuntimeError("WokwiClientSync event loop failed to start")
# Initialize underlying async client on the running loop
self._async_client = WokwiClient(token, server)

# Mirror library version for quick access
self.version = self._async_client.version

# Track background tasks created via run_coroutine_threadsafe (serial monitors)
# Track background monitor tasks (futures) for cancellation on exit
self._bg_futures: set[Future[Any]] = set()

# Idempotent disconnect guard
# Flag to avoid double-closing
self._closed = False

@staticmethod
def _run_loop(loop: asyncio.AbstractEventLoop) -> None:
"""Background thread loop runner."""
def _run_loop(self, loop: asyncio.AbstractEventLoop) -> None:
"""Target function for the background thread: runs the asyncio event loop."""
asyncio.set_event_loop(loop)
# Signal that the loop is now running and ready to accept tasks
loop.call_soon(self._loop_started_event.set)
loop.run_forever()

# ----- Internal helpers -------------------------------------------------
Expand All @@ -75,8 +72,11 @@ def _submit(self, coro: Coroutine[Any, Any, T]) -> Future[T]:
return asyncio.run_coroutine_threadsafe(coro, self._loop)

def _call(self, coro: Coroutine[Any, Any, T]) -> T:
"""Submit a coroutine to the loop and block until it completes (or raises)."""
return self._submit(coro).result()
"""Submit a coroutine to the background loop and wait for result."""
if self._closed:
raise RuntimeError("Cannot call methods on a closed WokwiClientSync")
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
return future.result() # Block until the coroutine completes or raises

def _add_bg_future(self, fut: Future[Any]) -> None:
"""Track a background future so we can cancel & drain on shutdown."""
Expand All @@ -96,37 +96,35 @@ def connect(self) -> dict[str, Any]:
return self._call(self._async_client.connect())

def disconnect(self) -> None:
"""Disconnect and stop the background loop.

Order matters:
1) Cancel and drain background serial-monitor futures.
2) Disconnect the underlying transport.
3) Stop the loop and join the thread.
Safe to call multiple times.
"""
if self._closed:
return
self._closed = True

# (1) Cancel + drain monitors
for fut in list(self._bg_futures):
fut.cancel()
for fut in list(self._bg_futures):
with contextlib.suppress(FutureTimeoutError, Exception):
# Give each monitor a short window to handle cancellation cleanly.
fut.result(timeout=1.0)
self._bg_futures.discard(fut)

# (2) Disconnect transport
with contextlib.suppress(Exception):
self._call(self._async_client._transport.close())
fut = asyncio.run_coroutine_threadsafe(self._async_client.disconnect(), self._loop)
fut.result(timeout=2.0)

# (3) Stop loop / join thread
if self._loop.is_running():
self._loop.call_soon_threadsafe(self._loop.stop)
if self._thread.is_alive():
self._thread.join(timeout=5.0)

# (4) Close loop
with contextlib.suppress(Exception):
self._loop.close()

# (5) Mark closed at the very end
self._closed = True

# ----- Serial monitoring ------------------------------------------------
def serial_monitor(self, callback: Callable[[bytes], Any]) -> None:
"""
Expand All @@ -138,17 +136,25 @@ def serial_monitor(self, callback: Callable[[bytes], Any]) -> None:
"""

async def _runner() -> None:
async for line in monitor_lines(self._async_client._transport):
try:
maybe_awaitable = callback(line)
if inspect.isawaitable(maybe_awaitable):
await maybe_awaitable
except Exception:
# Keep the monitor alive even if the callback throws.
pass

fut = self._submit(_runner())
self._add_bg_future(fut)
try:
# **Prepare to receive serial events before enabling monitor**
# (monitor_lines will subscribe to serial events internally)
async for line in monitor_lines(self._async_client._transport):
try:
result = callback(line) # invoke callback with the raw bytes line
if inspect.isawaitable(result):
await result # await if callback is async
except Exception:
# Swallow exceptions from callback to keep monitor alive
pass
finally:
# Remove this task’s future from the set when done
self._bg_futures.discard(task_future)

# Schedule the serial monitor runner on the event loop:
task_future = asyncio.run_coroutine_threadsafe(_runner(), self._loop)
self._bg_futures.add(task_future)
# (No return value; monitoring happens in background)

def serial_monitor_cat(self, decode_utf8: bool = True, errors: str = "replace") -> None:
"""
Expand All @@ -160,34 +166,32 @@ def serial_monitor_cat(self, decode_utf8: bool = True, errors: str = "replace")
"""

async def _runner() -> None:
async for line in monitor_lines(self._async_client._transport):
try:
if decode_utf8:
try:
print(line.decode("utf-8", errors=errors), end="", flush=True)
except UnicodeDecodeError:
try:
# **Subscribe to serial events before reading output**
async for line in monitor_lines(self._async_client._transport):
try:
if decode_utf8:
# Decode bytes to string (handle errors per parameter)
text = line.decode("utf-8", errors=errors)
print(text, end="", flush=True)
else:
# Print raw bytes
print(line, end="", flush=True)
else:
print(line, end="", flush=True)
except Exception:
# Keep the monitor alive even if printing raises intermittently.
pass
except Exception:
# Swallow print errors to keep stream alive
pass
finally:
self._bg_futures.discard(task_future)

fut = self._submit(_runner())
self._add_bg_future(fut)
task_future = asyncio.run_coroutine_threadsafe(_runner(), self._loop)
self._bg_futures.add(task_future)
# (No return; printing continues in background)

def stop_serial_monitors(self) -> None:
"""
Cancel and drain all running serial monitors without disconnecting.

Useful if you want to stop printing but keep the connection alive.
"""
"""Stop all active serial monitor background tasks."""
for fut in list(self._bg_futures):
fut.cancel()
for fut in list(self._bg_futures):
with contextlib.suppress(FutureTimeoutError, Exception):
fut.result(timeout=1.0)
self._bg_futures.discard(fut)
self._bg_futures.clear()

# ----- Dynamic method wrapping -----------------------------------------
def __getattr__(self, name: str) -> Any:
Expand All @@ -197,16 +201,17 @@ def __getattr__(self, name: str) -> Any:
If the attribute on `WokwiClient` is a coroutine function, return a
sync wrapper that blocks until the coroutine completes.
"""
# Explicit methods above (serial monitors) take precedence.
# Explicit methods (like serial_monitor functions above) take precedence over __getattr__
attr = getattr(self._async_client, name)
if callable(attr):
# Get the function object from WokwiClient class (unbound) to check if coroutine
func = getattr(WokwiClient, name, None)
if func is not None and inspect.iscoroutinefunction(func):

# Wrap coroutine method to run in background loop
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
return self._call(attr(*args, **kwargs))

sync_wrapper.__name__ = name
sync_wrapper.__doc__ = func.__doc__
sync_wrapper.__doc__ = getattr(func, "__doc__", "")
return sync_wrapper
return attr
3 changes: 3 additions & 0 deletions src/wokwi_client/serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@


async def monitor_lines(transport: Transport) -> AsyncGenerator[bytes, None]:
"""
Monitor the serial output lines.
"""
await transport.request("serial-monitor:listen", {})
with EventQueue(transport, "serial-monitor:data") as queue:
while True:
Expand Down
71 changes: 42 additions & 29 deletions src/wokwi_client/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: MIT

import asyncio
import contextlib
import json
import os
import warnings
Expand Down Expand Up @@ -48,31 +49,27 @@ async def connect(self, throw_error: bool = True) -> dict[str, Any]:
raise ProtocolError(f"Unsupported protocol handshake: {hello}")
hello_msg = cast(HelloMessage, hello)
self._closed = False
# Start background message processor
# Start background message processor after successful hello.
self._recv_task = asyncio.create_task(self._background_recv(throw_error))
return {"version": hello_msg["appVersion"]}

async def close(self) -> None:
self._closed = True
if self._recv_task:
self._recv_task.cancel()
try:
with contextlib.suppress(asyncio.CancelledError):
await self._recv_task
except asyncio.CancelledError:
pass
if self._ws:
await self._ws.close()

def add_event_listener(self, event_type: str, listener: Callable[[EventMessage], Any]) -> None:
"""Register a listener for a specific event type."""
if event_type not in self._event_listeners:
self._event_listeners[event_type] = []
self._event_listeners[event_type].append(listener)

def remove_event_listener(
self, event_type: str, listener: Callable[[EventMessage], Any]
) -> None:
"""Remove a previously registered listener for a specific event type."""
if event_type in self._event_listeners:
self._event_listeners[event_type] = [
registered_listener
Expand All @@ -90,52 +87,72 @@ async def _dispatch_event(self, event_msg: EventMessage) -> None:
await result

async def request(self, command: str, params: dict[str, Any]) -> ResponseMessage:
msg_id = str(self._next_id)
self._next_id += 1
if self._ws is None:
raise WokwiError("Not connected")
msg_id = str(self._next_id)
self._next_id += 1

loop = asyncio.get_running_loop()
future: asyncio.Future[ResponseMessage] = loop.create_future()
self._response_futures[msg_id] = future

await self._ws.send(
json.dumps({"type": "command", "command": command, "params": params, "id": msg_id})
)
try:
resp_msg_resp = await future
if resp_msg_resp.get("error"):
result = resp_msg_resp["result"]
raise ServerError(result["message"])
result = resp_msg_resp.get("result", {})
raise ServerError(result.get("message", "Unknown server error"))
return resp_msg_resp
finally:
del self._response_futures[msg_id]
# Remove future mapping if still present (be defensive)
self._response_futures.pop(msg_id, None)

async def _background_recv(self, throw_error: bool = True) -> None:
async def _background_recv(self, throw_error: bool = True) -> None: # noqa: PLR0912
try:
while not self._closed and self._ws is not None:
msg: IncomingMessage = await self._recv()
if msg["type"] == MSG_TYPE_EVENT:
resp_msg_event = cast(EventMessage, msg)
await self._dispatch_event(resp_msg_event)
await self._dispatch_event(cast(EventMessage, msg))
elif msg["type"] == MSG_TYPE_RESPONSE:
resp_msg_resp = cast(ResponseMessage, msg)
future = self._response_futures.get(resp_msg_resp["id"])
resp_id = str(resp_msg_resp.get("id"))
future = self._response_futures.get(resp_id)
if future is None or future.done():
continue
future.set_result(resp_msg_resp)
except (websockets.ConnectionClosed, asyncio.CancelledError):
pass
except asyncio.CancelledError:
# Expected during shutdown via close()
raise
except websockets.ConnectionClosed as e:
# Mark closed and fail pending futures to avoid hangs.
self._closed = True
for fut in list(self._response_futures.values()):
if not fut.done():
fut.set_exception(e)
with contextlib.suppress(Exception):
if self._ws:
await self._ws.close()
if throw_error:
raise
except Exception as e:
warnings.warn(f"Background recv error: {e}", RuntimeWarning)

if throw_error:
self._closed = True
# Cancel all pending response futures
for future in self._response_futures.values():
if not future.done():
future.set_exception(e)
if self._ws:
await self._ws.close()
for fut in list(self._response_futures.values()):
if not fut.done():
fut.set_exception(e)
with contextlib.suppress(Exception):
if self._ws:
await self._ws.close()
raise
finally:
# If we’re exiting the loop and marked closed, ensure no future hangs.
if self._closed:
for fut in list(self._response_futures.values()):
if not fut.done():
fut.set_exception(RuntimeError("Transport receive loop exited"))

async def _recv(self) -> IncomingMessage:
if self._ws is None:
Expand All @@ -153,10 +170,6 @@ async def _recv(self) -> IncomingMessage:
if message["type"] == "error":
raise WokwiError(f"Server error: {message['message']}")
if message["type"] == "response" and message.get("error"):
result = (
message["result"]
if "result" in message
else {"code": -1, "message": "Unknown error"}
)
result = message.get("result", {"code": -1, "message": "Unknown error"})
raise WokwiError(f"Server error {result['code']}: {result['message']}")
return cast(IncomingMessage, message)