Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/addtional fast apis for non-streaming simulation and managing relationshio #265

Merged
merged 15 commits into from
Dec 11, 2024
18 changes: 9 additions & 9 deletions examples/experimental/websocket/websocket_test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@


class WebSocketClient:
def __init__(self, uri: str, token: str, client_id: int):
self.uri = uri
def __init__(self, url: str, token: str, client_id: int):
self.url = url
self.token = token
self.client_id = client_id
self.message_file = Path(f"message_{client_id}.txt")
Expand All @@ -25,11 +25,11 @@ async def save_message(self, message: str) -> None:

async def connect(self) -> None:
"""Establish and maintain websocket connection"""
uri_with_token = f"{self.uri}?token=test_token_{self.client_id}"
url_with_token = f"{self.url}?token=test_token_{self.client_id}"

try:
async with websockets.connect(uri_with_token) as websocket:
print(f"Client {self.client_id}: Connected to {self.uri}")
async with websockets.connect(url_with_token) as websocket:
print(f"Client {self.client_id}: Connected to {self.url}")

# Send initial message
# Note: You'll need to implement the logic to get agent_ids and env_id
Expand Down Expand Up @@ -70,16 +70,16 @@ async def connect(self) -> None:
async def main() -> None:
# Create multiple WebSocket clients
num_clients = 0
uri = "ws://localhost:8800/ws/simulation"
url = "ws://localhost:8800/ws/simulation"

# Create and store client instances
clients = [
WebSocketClient(uri=uri, token=f"test_token_{i}", client_id=i)
WebSocketClient(url=url, token=f"test_token_{i}", client_id=i)
for i in range(num_clients)
]
clients.append(WebSocketClient(uri=uri, token="test_token_10", client_id=10))
clients.append(WebSocketClient(url=url, token="test_token_10", client_id=10))
clients.append(
WebSocketClient(uri=uri, token="test_token_10", client_id=10)
WebSocketClient(url=url, token="test_token_10", client_id=10)
) # test duplicate token

# Create tasks for each client
Expand Down
122 changes: 122 additions & 0 deletions examples/fast_api_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Example curl command to call the simulate endpoint:
import requests
import time

BASE_URL = "http://localhost:8080"


def _create_mock_agent_profile() -> None:
agent1_data = {
"first_name": "John",
"last_name": "Doe",
"occupation": "test_occupation",
"gender": "test_gender",
"pk": "tmppk_agent1",
"tag": "test_tag",
}
response = requests.post(
f"{BASE_URL}/agents/",
headers={"Content-Type": "application/json"},
json=agent1_data,
)
assert response.status_code == 200

agent2_data = {
"first_name": "Jane",
"last_name": "Doe",
"occupation": "test_occupation",
"gender": "test_gender",
"pk": "tmppk_agent2",
"tag": "test_tag",
}
response = requests.post(
f"{BASE_URL}/agents/",
headers={"Content-Type": "application/json"},
json=agent2_data,
)
assert response.status_code == 200


def _create_mock_env_profile() -> None:
env_data = {
"codename": "test_codename",
"scenario": "A",
"agent_goals": [
"B",
"C",
],
"pk": "tmppk_env_profile",
"tag": "test_tag",
}
response = requests.post(
f"{BASE_URL}/scenarios/",
headers={"Content-Type": "application/json"},
json=env_data,
)
assert response.status_code == 200


_create_mock_agent_profile()
_create_mock_env_profile()


data = {
"env_id": "tmppk_env_profile",
"agent_ids": ["tmppk_agent1", "tmppk_agent2"],
"models": ["custom/structured-llama3.2:1b@http://localhost:8000/v1"] * 3,
"max_turns": 10,
"tag": "test_tag",
}
try:
response = requests.post(
f"{BASE_URL}/simulate/", headers={"Content-Type": "application/json"}, json=data
)
print(response)
assert response.status_code == 202
assert isinstance(response.content.decode(), str)
episode_pk = response.content.decode()
print(episode_pk)
max_retries = 200
retry_count = 0
while retry_count < max_retries:
try:
response = requests.get(f"{BASE_URL}/simulation_status/{episode_pk}")
assert response.status_code == 200
status = response.content.decode()
print(status)
if status == "Error":
raise Exception("Error running simulation")
elif status == "Completed":
break
# Status is "Started", keep polling
time.sleep(1)
retry_count += 1
except Exception as e:
print(f"Error checking simulation status: {e}")
time.sleep(1)
retry_count += 1
else:
raise TimeoutError("Simulation timed out after 10 retries")

finally:
try:
response = requests.delete(f"{BASE_URL}/agents/tmppk_agent1")
assert response.status_code == 200
except Exception as e:
print(e)
try:
response = requests.delete(f"{BASE_URL}/agents/tmppk_agent2")
assert response.status_code == 200
except Exception as e:
print(e)
try:
response = requests.delete(f"{BASE_URL}/scenarios/tmppk_env_profile")
assert response.status_code == 200
except Exception as e:
print(e)

try:
response = requests.delete(f"{BASE_URL}/episodes/{episode_pk}")
assert response.status_code == 200
except Exception as e:
print(e)
4 changes: 3 additions & 1 deletion sotopia/database/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from redis_om import JsonModel, Migrator
from .annotators import Annotator
from .env_agent_combo_storage import EnvAgentComboStorage
from .logs import AnnotationForEpisode, EpisodeLog
from .logs import AnnotationForEpisode, EpisodeLog, NonStreamingSimulationStatus
from .persistent_profile import (
AgentProfile,
EnvironmentProfile,
Expand Down Expand Up @@ -44,6 +44,7 @@
"AgentProfile",
"EnvironmentProfile",
"EpisodeLog",
"NonStreamingSimulationStatus",
"EnvAgentComboStorage",
"AnnotationForEpisode",
"Annotator",
Expand Down Expand Up @@ -73,6 +74,7 @@
"EvaluationDimensionBuilder",
"CustomEvaluationDimension",
"CustomEvaluationDimensionList",
"NonStreamingSimulationStatus",
]

InheritedJsonModel = TypeVar("InheritedJsonModel", bound="JsonModel")
Expand Down
7 changes: 6 additions & 1 deletion sotopia/database/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,15 @@
from pydantic import model_validator
from redis_om import JsonModel
from redis_om.model.model import Field

from typing import Literal
from sotopia.database.persistent_profile import AgentProfile


class NonStreamingSimulationStatus(JsonModel):
episode_pk: str = Field(index=True)
status: Literal["Started", "Error", "Completed"]


class EpisodeLog(JsonModel):
# Note that we did not validate the following constraints:
# 1. The number of turns in messages and rewards should be the same or off by 1
Expand Down
5 changes: 5 additions & 0 deletions sotopia/database/persistent_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ class RelationshipProfile(JsonModel):
description="0 means stranger, 1 means know_by_name, 2 means acquaintance, 3 means friend, 4 means romantic_relationship, 5 means family_member",
) # this could be improved by limiting str to a relationship Enum
background_story: str | None = Field(default_factory=lambda: None)
tag: str = Field(
index=True,
default_factory=lambda: "",
description="The tag of the relationship, used for searching, could be convenient to document relationship profiles from different works and sources",
)


class EnvironmentList(JsonModel):
Expand Down
2 changes: 1 addition & 1 deletion sotopia/database/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _map_gender_to_adj(gender: str) -> str:
"Nonbinary": "nonbinary",
}
if gender:
return gender_to_adj[gender]
return gender_to_adj.get(gender, "")
else:
return ""

Expand Down
2 changes: 1 addition & 1 deletion sotopia/envs/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _map_gender_to_adj(gender: str) -> str:
"Nonbinary": "nonbinary",
}
if gender:
return gender_to_adj[gender]
return gender_to_adj.get(gender, "")
else:
return ""

Expand Down
16 changes: 13 additions & 3 deletions sotopia/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
ScriptWritingAgent,
)
from sotopia.agents.base_agent import BaseAgent
from sotopia.database import EpisodeLog
from sotopia.database import EpisodeLog, NonStreamingSimulationStatus
from sotopia.envs import ParallelSotopiaEnv
from sotopia.envs.evaluators import (
EvaluationForTwoAgents,
Expand Down Expand Up @@ -119,12 +119,15 @@ async def arun_one_episode(
json_in_script: bool = False,
tag: str | None = None,
push_to_db: bool = False,
episode_pk: str | None = None,
streaming: bool = False,
simulation_status: NonStreamingSimulationStatus | None = None,
) -> Union[
list[tuple[str, str, Message]],
AsyncGenerator[list[list[tuple[str, str, Message]]], None],
]:
agents = Agents({agent.agent_name: agent for agent in agent_list})
print(f"Running episode with tag: {tag}------------------")

async def generate_messages() -> (
AsyncGenerator[list[list[tuple[str, str, Message]]], None]
Expand Down Expand Up @@ -188,7 +191,7 @@ async def generate_messages() -> (
for agent_name in env.agents
]
)

print(f"Messages: {messages}")
yield messages
rewards.append([rewards_in_turn[agent_name] for agent_name in env.agents])
reasons.append(
Expand Down Expand Up @@ -228,7 +231,14 @@ async def generate_messages() -> (

if push_to_db:
try:
epilog.save()
if episode_pk:
epilog.pk = episode_pk
epilog.save()
else:
epilog.save()
if simulation_status:
simulation_status.status = "Completed"
simulation_status.save()
except Exception as e:
logging.error(f"Failed to save episode log: {e}")

Expand Down
11 changes: 11 additions & 0 deletions sotopia/ui/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,17 @@

## FastAPI Server

To run the FastAPI server, you can use the following command:
```bash
uv run rq worker
uv run fastapi run sotopia/ui/fastapi_server.py --workers 4 --port 8080
```

Here is also an example of using the FastAPI server:
```bash
uv run python examples/fast_api_example.py
```

The API server is a FastAPI application that is used to connect the Sotopia UI to the Sotopia backend.
This could also help with other projects that need to connect to the Sotopia backend through HTTP requests.

Expand Down
Loading
Loading