Skip to content

Commit

Permalink
updated websocket protocol and server (#473)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpacker authored Nov 17, 2023
1 parent 34b8b51 commit 85d0e0d
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 5 deletions.
5 changes: 4 additions & 1 deletion memgpt/server/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
def condition_to_stop_receiving(response):
"""Determines when to stop listening to the server"""
return response.get("type") == "agent_response_end"
if response.get("type") in ["agent_response_end", "agent_response_error", "command_response", "server_error"]:
return True
else:
return False


def print_server_response(response):
Expand Down
3 changes: 2 additions & 1 deletion memgpt/server/websocket_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,12 @@ def server_agent_function_message(msg):
# Client -> server


def client_user_message(msg):
def client_user_message(msg, agent_name=None):
return json.dumps(
{
"type": "user_message",
"message": msg,
"agent_name": agent_name,
}
)

Expand Down
43 changes: 40 additions & 3 deletions memgpt/server/websocket_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import json
import traceback

import websockets

Expand All @@ -15,7 +16,9 @@ def __init__(self, host="localhost", port=DEFAULT_PORT):
self.host = host
self.port = port
self.interface = SyncWebSocketInterface()

self.agent = None
self.agent_name = None

def run_step(self, user_message, first_message=False, no_verify=False):
while True:
Expand All @@ -41,16 +44,27 @@ async def handle_client(self, websocket, path):
message = await websocket.recv()

# Assuming the message is a JSON string
data = json.loads(message)

if data["type"] == "command":
try:
data = json.loads(message)
except:
print(f"[server] bad data from client:\n{data}")
await websocket.send(protocol.server_command_response(f"Error: bad data from client - {str(data)}"))
continue

if "type" not in data:
print(f"[server] bad data from client (JSON but no type):\n{data}")
await websocket.send(protocol.server_command_response(f"Error: bad data from client - {str(data)}"))

elif data["type"] == "command":
# Create a new agent
if data["command"] == "create_agent":
try:
self.agent = self.create_new_agent(data["config"])
await websocket.send(protocol.server_command_response("OK: Agent initialized"))
except Exception as e:
self.agent = None
print(f"[server] self.create_new_agent failed with:\n{e}")
print(f"{traceback.format_exc()}")
await websocket.send(protocol.server_command_response(f"Error: Failed to init agent - {str(e)}"))

# Load an existing agent
Expand All @@ -59,9 +73,11 @@ async def handle_client(self, websocket, path):
if agent_name is not None:
try:
self.agent = self.load_agent(agent_name)
self.agent_name = agent_name
await websocket.send(protocol.server_command_response(f"OK: Agent '{agent_name}' loaded"))
except Exception as e:
print(f"[server] self.load_agent failed with:\n{e}")
print(f"{traceback.format_exc()}")
self.agent = None
await websocket.send(
protocol.server_command_response(f"Error: Failed to load agent '{agent_name}' - {str(e)}")
Expand All @@ -76,6 +92,26 @@ async def handle_client(self, websocket, path):
elif data["type"] == "user_message":
user_message = data["message"]

if "agent_name" in data:
agent_name = data["agent_name"]
# If the agent requested the same one that's already loading?
if self.agent_name is None or self.agent_name != data["agent_name"]:
try:
print(f"[server] loading agent {agent_name}")
self.agent = self.load_agent(agent_name)
self.agent_name = agent_name
# await websocket.send(protocol.server_command_response(f"OK: Agent '{agent_name}' loaded"))
except Exception as e:
print(f"[server] self.load_agent failed with:\n{e}")
print(f"{traceback.format_exc()}")
self.agent = None
await websocket.send(
protocol.server_command_response(f"Error: Failed to load agent '{agent_name}' - {str(e)}")
)
else:
await websocket.send(protocol.server_agent_response_error("agent_name was not specified in the request"))
continue

if self.agent is None:
await websocket.send(protocol.server_agent_response_error("No agent has been initialized"))
else:
Expand All @@ -84,6 +120,7 @@ async def handle_client(self, websocket, path):
self.run_step(user_message)
except Exception as e:
print(f"[server] self.run_step failed with:\n{e}")
print(f"{traceback.format_exc()}")
await websocket.send(protocol.server_agent_response_error(f"self.run_step failed with: {e}"))

await asyncio.sleep(1) # pause before sending the terminating message, w/o this messages may be missed
Expand Down

0 comments on commit 85d0e0d

Please sign in to comment.