Skip to content

Commit

Permalink
feat: FastAPI Implementation of Sotopia Part Two (w websocket) (#252)
Browse files Browse the repository at this point in the history
* api doc

* add PUT

* add an temp example for websocket

* websocket

* update readme

* Update README.md

* update websocket live simulation api doc

* [autofix.ci] apply automated fixes

* update websocket doc

* add api server with websocket as well as a client

* fix mypy errors

* support stopping the chat

* add 404 to the status code

* fix mypy issue

* update the returned message types

* redesign websocket api

* update websocket, fix mypy error

* add example of using websocket

* clean code & change to existing functions for simulation

* fix typing mismatch

* update doc & mypy type fix

* add type check for run_async_server

* move example

---------

Co-authored-by: Hao Zhu <prokilchu@gmail.com>
Co-authored-by: Zhe Su <360307598@qq.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
4 people authored Dec 5, 2024
1 parent 61f190e commit 693f792
Show file tree
Hide file tree
Showing 6 changed files with 754 additions and 123 deletions.
97 changes: 97 additions & 0 deletions examples/experimental/websocket/websocket_test_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""
A test client for the WebSocket server
"""

import json
from sotopia.database import EnvironmentProfile, AgentProfile

import asyncio
import websockets
import sys
from pathlib import Path


class WebSocketClient:
def __init__(self, uri: str, token: str, client_id: int):
self.uri = uri
self.token = token
self.client_id = client_id
self.message_file = Path(f"message_{client_id}.txt")

async def save_message(self, message: str) -> None:
"""Save received message to a file"""
with open(self.message_file, "a", encoding="utf-8") as f:
f.write(f"{message}\n")

async def connect(self) -> None:
"""Establish and maintain websocket connection"""
uri_with_token = f"{self.uri}?token=test_token_{self.client_id}"

try:
async with websockets.connect(uri_with_token) as websocket:
print(f"Client {self.client_id}: Connected to {self.uri}")

# Send initial message
# Note: You'll need to implement the logic to get agent_ids and env_id
# This is just an example structure
agent_ids = [agent.pk for agent in AgentProfile.find().all()[:2]]
env_id = EnvironmentProfile.find().all()[0].pk
start_message = {
"type": "START_SIM",
"data": {
"env_id": env_id, # Replace with actual env_id
"agent_ids": agent_ids, # Replace with actual agent_ids
},
}
await websocket.send(json.dumps(start_message))
print(f"Client {self.client_id}: Sent START_SIM message")

# Receive and process messages
while True:
try:
message = await websocket.recv()
print(
f"\nClient {self.client_id} received message:",
json.dumps(json.loads(message), indent=2),
)
assert isinstance(message, str)
await self.save_message(message)
except websockets.ConnectionClosed:
print(f"Client {self.client_id}: Connection closed")
break
except Exception as e:
print(f"Client {self.client_id} error:", str(e))
break

except Exception as e:
print(f"Client {self.client_id} connection error:", str(e))


async def main() -> None:
# Create multiple WebSocket clients
num_clients = 0
uri = "ws://localhost:8800/ws/simulation"

# Create and store client instances
clients = [
WebSocketClient(uri=uri, token=f"test_token_{i}", client_id=i)
for i in range(num_clients)
]
clients.append(WebSocketClient(uri=uri, token="test_token_10", client_id=10))
clients.append(
WebSocketClient(uri=uri, token="test_token_10", client_id=10)
) # test duplicate token

# Create tasks for each client
tasks = [asyncio.create_task(client.connect()) for client in clients]

# Wait for all tasks to complete
await asyncio.gather(*tasks)


if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
print("\nShutting down clients...")
sys.exit(0)
209 changes: 122 additions & 87 deletions sotopia/server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import itertools
import logging
from typing import Literal, Sequence, Type
from typing import Literal, Sequence, Type, AsyncGenerator, Union

import gin
import rich
Expand All @@ -25,7 +25,7 @@
unweighted_aggregate_evaluate,
)
from sotopia.generation_utils.generate import LLM_Name, agenerate_script
from sotopia.messages import AgentAction, Message, Observation
from sotopia.messages import AgentAction, Message, Observation, SimpleMessage
from sotopia.messages.message_classes import (
ScriptBackground,
ScriptEnvironmentResponse,
Expand Down Expand Up @@ -104,6 +104,12 @@ def run_sync_server(
return messages


def flatten_listed_messages(
messages: list[list[tuple[str, str, Message]]],
) -> list[tuple[str, str, Message]]:
return list(itertools.chain.from_iterable(messages))


@gin.configurable
async def arun_one_episode(
env: ParallelSotopiaEnv,
Expand All @@ -113,102 +119,125 @@ async def arun_one_episode(
json_in_script: bool = False,
tag: str | None = None,
push_to_db: bool = False,
) -> list[tuple[str, str, Message]]:
streaming: bool = False,
) -> Union[
list[tuple[str, str, Message]],
AsyncGenerator[list[list[tuple[str, str, Message]]], None],
]:
agents = Agents({agent.agent_name: agent for agent in agent_list})
environment_messages = env.reset(agents=agents, omniscient=omniscient)
agents.reset()

messages: list[list[tuple[str, str, Message]]] = []

# Main Event Loop
done = False
messages.append(
[
("Environment", agent_name, environment_messages[agent_name])
for agent_name in env.agents
]
)
# set goal for agents
for index, agent_name in enumerate(env.agents):
agents[agent_name].goal = env.profile.agent_goals[index]
rewards: list[list[float]] = []
reasons: list[str] = []
while not done:
# gather agent messages
agent_messages: dict[str, AgentAction] = dict()
actions = await asyncio.gather(
*[
agents[agent_name].aact(environment_messages[agent_name])
for agent_name in env.agents
]
)
if script_like:
# manually mask one message
agent_mask = env.action_mask
for idx in range(len(agent_mask)):
print("Current mask: ", agent_mask)
if agent_mask[idx] == 0:
print("Action not taken: ", actions[idx])
actions[idx] = AgentAction(action_type="none", argument="")
else:
print("Current action taken: ", actions[idx])

# actions = cast(list[AgentAction], actions)
for idx, agent_name in enumerate(env.agents):
agent_messages[agent_name] = actions[idx]

messages[-1].append((agent_name, "Environment", agent_messages[agent_name]))
async def generate_messages() -> (
AsyncGenerator[list[list[tuple[str, str, Message]]], None]
):
environment_messages = env.reset(agents=agents, omniscient=omniscient)
agents.reset()
messages: list[list[tuple[str, str, Message]]] = []

# send agent messages to environment
(
environment_messages,
rewards_in_turn,
terminated,
___,
info,
) = await env.astep(agent_messages)
# Main Event Loop
done = False
messages.append(
[
("Environment", agent_name, environment_messages[agent_name])
for agent_name in env.agents
]
)
# print("Environment message: ", environment_messages)
# exit(0)
rewards.append([rewards_in_turn[agent_name] for agent_name in env.agents])
reasons.append(
" ".join(info[agent_name]["comments"] for agent_name in env.agents)
yield messages

# set goal for agents
for index, agent_name in enumerate(env.agents):
agents[agent_name].goal = env.profile.agent_goals[index]
rewards: list[list[float]] = []
reasons: list[str] = []
while not done:
# gather agent messages
agent_messages: dict[str, AgentAction] = dict()
actions = await asyncio.gather(
*[
agents[agent_name].aact(environment_messages[agent_name])
for agent_name in env.agents
]
)
if script_like:
# manually mask one message
agent_mask = env.action_mask
for idx in range(len(agent_mask)):
if agent_mask[idx] == 0:
actions[idx] = AgentAction(action_type="none", argument="")
else:
pass

# actions = cast(list[AgentAction], actions)
for idx, agent_name in enumerate(env.agents):
agent_messages[agent_name] = actions[idx]

messages[-1].append(
(agent_name, "Environment", agent_messages[agent_name])
)

# send agent messages to environment
(
environment_messages,
rewards_in_turn,
terminated,
___,
info,
) = await env.astep(agent_messages)
messages.append(
[
("Environment", agent_name, environment_messages[agent_name])
for agent_name in env.agents
]
)

yield messages
rewards.append([rewards_in_turn[agent_name] for agent_name in env.agents])
reasons.append(
" ".join(info[agent_name]["comments"] for agent_name in env.agents)
)
done = all(terminated.values())

epilog = EpisodeLog(
environment=env.profile.pk,
agents=[agent.profile.pk for agent in agent_list],
tag=tag,
models=[env.model_name, agent_list[0].model_name, agent_list[1].model_name],
messages=[
[(m[0], m[1], m[2].to_natural_language()) for m in messages_in_turn]
for messages_in_turn in messages
],
reasoning=info[env.agents[0]]["comments"],
rewards=[info[agent_name]["complete_rating"] for agent_name in env.agents],
rewards_prompt=info["rewards_prompt"]["overall_prompt"],
)
done = all(terminated.values())
rich.print(epilog.rewards_prompt)
agent_profiles, conversation = epilog.render_for_humans()
for agent_profile in agent_profiles:
rich.print(agent_profile)
for message in conversation:
rich.print(message)

if streaming:
# yield the rewards and reasonings
messages.append(
[("Evaluation", "Rewards", SimpleMessage(message=str(epilog.rewards)))]
)
messages.append(
[("Evaluation", "Reasoning", SimpleMessage(message=epilog.reasoning))]
)
yield messages

# TODO: clean up this part
epilog = EpisodeLog(
environment=env.profile.pk,
agents=[agent.profile.pk for agent in agent_list],
tag=tag,
models=[env.model_name, agent_list[0].model_name, agent_list[1].model_name],
messages=[
[(m[0], m[1], m[2].to_natural_language()) for m in messages_in_turn]
for messages_in_turn in messages
],
reasoning=info[env.agents[0]]["comments"],
rewards=[info[agent_name]["complete_rating"] for agent_name in env.agents],
rewards_prompt=info["rewards_prompt"]["overall_prompt"],
)
rich.print(epilog.rewards_prompt)
agent_profiles, conversation = epilog.render_for_humans()
for agent_profile in agent_profiles:
rich.print(agent_profile)
for message in conversation:
rich.print(message)
if push_to_db:
try:
epilog.save()
except Exception as e:
logging.error(f"Failed to save episode log: {e}")

if push_to_db:
try:
epilog.save()
except Exception as e:
logging.error(f"Failed to save episode log: {e}")
# flatten nested list messages
return list(itertools.chain(*messages))
if streaming:
return generate_messages()
else:
async for last_messages in generate_messages():
pass
return flatten_listed_messages(last_messages)


@gin.configurable
Expand Down Expand Up @@ -310,7 +339,13 @@ def get_agent_class(
else [await i for i in episode_futures]
)

return batch_results
if len(batch_results) > 0:
first_result = batch_results[0]
assert isinstance(
first_result, list
), f"Unexpected result type: {type(first_result)}"

return batch_results # type: ignore


async def arun_one_script(
Expand Down
Loading

0 comments on commit 693f792

Please sign in to comment.