From 008763aab058e255104b6b5de2c680f628ec379e Mon Sep 17 00:00:00 2001 From: Robert Brisita <986796+rbrisita@users.noreply.github.com> Date: Mon, 15 Apr 2024 20:13:54 -0400 Subject: [PATCH] Fixing asyncio queue creation and usage by decoupling app, queues, and uvicorn config. --- software/source/server/app.py | 110 ++++++++++++++++++++++++ software/source/server/queues.py | 43 ++++++++++ software/source/server/server.py | 138 ++++--------------------------- 3 files changed, 168 insertions(+), 123 deletions(-) create mode 100644 software/source/server/app.py create mode 100644 software/source/server/queues.py diff --git a/software/source/server/app.py b/software/source/server/app.py new file mode 100644 index 00000000..130a2601 --- /dev/null +++ b/software/source/server/app.py @@ -0,0 +1,110 @@ +from fastapi import FastAPI, Request +from fastapi.responses import PlainTextResponse +from starlette.websockets import WebSocket, WebSocketDisconnect +import asyncio +from .utils.logs import setup_logging +from .utils.logs import logger +import traceback +import json +from ..utils.print_markdown import print_markdown +from .queues import Queues + + +setup_logging() + +app = FastAPI() + +from_computer, from_user, to_device = Queues.get() + + +@app.get("/ping") +async def ping(): + return PlainTextResponse("pong") + + +@app.websocket("/") +async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + receive_task = asyncio.create_task(receive_messages(websocket)) + send_task = asyncio.create_task(send_messages(websocket)) + try: + await asyncio.gather(receive_task, send_task) + except Exception as e: + logger.debug(traceback.format_exc()) + logger.info(f"Connection lost. Error: {e}") + + +@app.post("/") +async def add_computer_message(request: Request): + body = await request.json() + text = body.get("text") + if not text: + return {"error": "Missing 'text' in request body"}, 422 + message = {"role": "user", "type": "message", "content": text} + await from_user.put({"role": "user", "type": "message", "start": True}) + await from_user.put(message) + await from_user.put({"role": "user", "type": "message", "end": True}) + + +async def receive_messages(websocket: WebSocket): + while True: + try: + try: + data = await websocket.receive() + except Exception as e: + print(str(e)) + return + + if "text" in data: + try: + data = json.loads(data["text"]) + if data["role"] == "computer": + from_computer.put( + data + ) # To be handled by interpreter.computer.run + elif data["role"] == "user": + await from_user.put(data) + else: + raise ("Unknown role:", data) + except json.JSONDecodeError: + pass # data is not JSON, leave it as is + elif "bytes" in data: + data = data["bytes"] # binary data + await from_user.put(data) + except WebSocketDisconnect as e: + if e.code == 1000: + logger.info("Websocket connection closed normally.") + return + else: + raise + + +async def send_messages(websocket: WebSocket): + while True: + try: + message = await to_device.get() + # print(f"Sending to the device: {type(message)} {str(message)[:100]}") + + if isinstance(message, dict): + await websocket.send_json(message) + elif isinstance(message, bytes): + await websocket.send_bytes(message) + else: + raise TypeError("Message must be a dict or bytes") + except Exception as e: + if message: + # Make sure to put the message back in the queue if you failed to send it + await to_device.put(message) + raise + +# TODO: These two methods should change to lifespan +@app.on_event("startup") +async def startup_event(): + print("") + print_markdown("\n*Ready.*\n") + print("") + + +@app.on_event("shutdown") +async def shutdown_event(): + print_markdown("*Server is shutting down*") diff --git a/software/source/server/queues.py b/software/source/server/queues.py new file mode 100644 index 00000000..42c5435c --- /dev/null +++ b/software/source/server/queues.py @@ -0,0 +1,43 @@ +import asyncio +import queue + +''' +Queues are created on demand and should +be accessed inside the currect event loop +from a asyncio.run(co()) call. +''' + +class _ReadOnly(type): + @property + def from_computer(cls): + if not cls._from_computer: + # Sync queue because interpreter.run is synchronous. + cls._from_computer = queue.Queue() + return cls._from_computer + + @property + def from_user(cls): + if not cls._from_user: + cls._from_user = asyncio.Queue() + return cls._from_user + + @property + def to_device(cls): + if not cls._to_device: + cls._to_device = asyncio.Queue() + return cls._to_device + + +class Queues(metaclass=_ReadOnly): + # Queues used in server and app + # Just for computer messages from the device. + _from_computer = None + + # Just for user messages from the device. + _from_user = None + + # For messages we send. + _to_device = None + + def get(): + return Queues.from_computer, Queues.from_user, Queues.to_device diff --git a/software/source/server/server.py b/software/source/server/server.py index c4dd0367..63b9f275 100644 --- a/software/source/server/server.py +++ b/software/source/server/server.py @@ -5,14 +5,10 @@ import traceback from platformdirs import user_data_dir import json -import queue import os import datetime from .utils.bytes_to_wav import bytes_to_wav import re -from fastapi import FastAPI, Request -from fastapi.responses import PlainTextResponse -from starlette.websockets import WebSocket, WebSocketDisconnect import asyncio from .utils.kernel import put_kernel_messages_into_queue from .i import configure_interpreter @@ -20,8 +16,8 @@ from ..utils.accumulator import Accumulator from .utils.logs import setup_logging from .utils.logs import logger - from ..utils.print_markdown import print_markdown +from .queues import Queues os.environ["STT_RUNNER"] = "server" os.environ["TTS_RUNNER"] = "server" @@ -40,8 +36,6 @@ accumulator = Accumulator() -app = FastAPI() - app_dir = user_data_dir("01") conversation_history_path = os.path.join(app_dir, "conversations", "user.json") @@ -56,14 +50,6 @@ def is_full_sentence(text): def split_into_sentences(text): return re.split(r"(?<=[.!?])\s+", text) - -# Queues -from_computer = ( - queue.Queue() -) # Just for computer messages from the device. Sync queue because interpreter.run is synchronous -from_user = asyncio.Queue() # Just for user messages from the device. -to_device = asyncio.Queue() # For messages we send. - # Switch code executor to device if that's set if os.getenv("CODE_RUNNER") == "device": @@ -76,9 +62,11 @@ class Python: def __init__(self): self.halt = False - def run(self, code): + async def run(self, code): """Generator that yields a dictionary in LMC Format.""" + from_computer, _, to_device = Queues.get() + # Prepare the data message = { "role": "assistant", @@ -89,7 +77,7 @@ def run(self, code): # Unless it was just sent to the device, send it wrapped in flags if not (interpreter.messages and interpreter.messages[-1] == message): - to_device.put( + await to_device.put( { "role": "assistant", "type": "code", @@ -97,8 +85,8 @@ def run(self, code): "start": True, } ) - to_device.put(message) - to_device.put( + await to_device.put(message) + await to_device.put( { "role": "assistant", "type": "code", @@ -130,86 +118,9 @@ def terminate(self): interpreter = configure_interpreter(interpreter) -@app.get("/ping") -async def ping(): - return PlainTextResponse("pong") - - -@app.websocket("/") -async def websocket_endpoint(websocket: WebSocket): - await websocket.accept() - receive_task = asyncio.create_task(receive_messages(websocket)) - send_task = asyncio.create_task(send_messages(websocket)) - try: - await asyncio.gather(receive_task, send_task) - except Exception as e: - logger.debug(traceback.format_exc()) - logger.info(f"Connection lost. Error: {e}") - - -@app.post("/") -async def add_computer_message(request: Request): - body = await request.json() - text = body.get("text") - if not text: - return {"error": "Missing 'text' in request body"}, 422 - message = {"role": "user", "type": "message", "content": text} - await from_user.put({"role": "user", "type": "message", "start": True}) - await from_user.put(message) - await from_user.put({"role": "user", "type": "message", "end": True}) - - -async def receive_messages(websocket: WebSocket): - while True: - try: - try: - data = await websocket.receive() - except Exception as e: - print(str(e)) - return - if "text" in data: - try: - data = json.loads(data["text"]) - if data["role"] == "computer": - from_computer.put( - data - ) # To be handled by interpreter.computer.run - elif data["role"] == "user": - await from_user.put(data) - else: - raise ("Unknown role:", data) - except json.JSONDecodeError: - pass # data is not JSON, leave it as is - elif "bytes" in data: - data = data["bytes"] # binary data - await from_user.put(data) - except WebSocketDisconnect as e: - if e.code == 1000: - logger.info("Websocket connection closed normally.") - return - else: - raise - - -async def send_messages(websocket: WebSocket): - while True: - message = await to_device.get() - # print(f"Sending to the device: {type(message)} {str(message)[:100]}") - - try: - if isinstance(message, dict): - await websocket.send_json(message) - elif isinstance(message, bytes): - await websocket.send_bytes(message) - else: - raise TypeError("Message must be a dict or bytes") - except: - # Make sure to put the message back in the queue if you failed to send it - await to_device.put(message) - raise - - async def listener(): + from_computer, from_user, to_device = Queues.get() + while True: try: while True: @@ -250,6 +161,7 @@ async def listener(): time.sleep(15) + # stt is a bound method text = stt(audio_file_path) print("> ", text) message = {"role": "user", "type": "message", "content": text} @@ -367,10 +279,11 @@ async def stream_tts_to_device(sentence): return for chunk in stream_tts(sentence): - await to_device.put(chunk) + await Queues.to_device.put(chunk) def stream_tts(sentence): + # tts is a bound method audio_file = tts(sentence) with open(audio_file, "rb") as f: @@ -392,23 +305,6 @@ def stream_tts(sentence): import os from importlib import import_module -# these will be overwritten -HOST = "" -PORT = 0 - - -@app.on_event("startup") -async def startup_event(): - server_url = f"{HOST}:{PORT}" - print("") - print_markdown("\n*Ready.*\n") - print("") - - -@app.on_event("shutdown") -async def shutdown_event(): - print_markdown("*Server is shutting down*") - async def main( server_host, @@ -423,11 +319,6 @@ async def main( tts_service, stt_service, ): - global HOST - global PORT - PORT = server_port - HOST = server_host - # Setup services application_directory = user_data_dir("01") services_directory = os.path.join(application_directory, "services") @@ -470,6 +361,7 @@ async def main( service_instance = ServiceClass(config) globals()[service] = getattr(service_instance, service) + # llm is a bound method interpreter.llm.completions = llm # Start listening @@ -477,9 +369,9 @@ async def main( # Start watching the kernel if it's your job to do that if True: # in the future, code can run on device. for now, just server. - asyncio.create_task(put_kernel_messages_into_queue(from_computer)) + asyncio.create_task(put_kernel_messages_into_queue(Queues.from_computer)) - config = Config(app, host=server_host, port=int(server_port), lifespan="on") + config = Config("source.server.app:app", host=server_host, port=int(server_port), lifespan="on") server = Server(config) await server.serve()