
Commit fadd43e

feature(example): add a new demo script and fix some bugs (#7)
* add run_chat and fix some bugs
* delete some Chinese notes
1 parent b8fffc7 commit fadd43e

File tree

15 files changed, +379 −52 lines


examples/run_chat.py

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+#!/usr/bin/env python3
+"""
+Simple script to run a multi-agent chat conversation
+Usage: python examples/run_chat.py
+"""
+
+import asyncio
+import os
+import sys
+from pathlib import Path
+
+# Add the project root to Python path
+project_root = Path(__file__).parent.parent
+sys.path.insert(0, str(project_root))
+
+from tiny_chat.server import ChatServer
+from tiny_chat.messages import ChatBackground
+
+
+async def main():
+    """Run a simple multi-agent conversation"""
+
+    # Get API key from environment variable
+    api_key = os.getenv("OPENAI_API_KEY")
+    if not api_key:
+        print("Warning: OPENAI_API_KEY not set. Some features may not work.")
+
+    # Create chat server
+    server = ChatServer(api_key=api_key)
+
+    # Define agent configurations
+    agent_configs = [
+        {
+            "name": "Alice",
+            "type": "llm",
+            "model": "gpt-4o-mini",
+            "goal": "Be friendly and helpful in the conversation"
+        },
+        {
+            "name": "Bob",
+            "type": "llm",
+            "model": "gpt-4o-mini",
+            "goal": "Ask thoughtful questions and share interesting ideas"
+        }
+    ]
+
+    # Create a simple background
+    background = ChatBackground(
+        scenario="Two friends meeting at a coffee shop",
+        p1_background="Alice is a software engineer who loves hiking",
+        p2_background="Bob is a teacher who enjoys reading science fiction",
+        p1_goal="Have a pleasant conversation about weekend plans",
+        p2_goal="Discuss recent books and outdoor activities",
+        p1_name="Alice",
+        p2_name="Bob"
+    )
+
+    print("Starting multi-agent conversation...")
+    print("=" * 50)
+
+    # Run the conversation
+    await server.run_conversation(
+        agent_configs=agent_configs,
+        background=background,
+        max_turns=10,
+        enable_evaluation=True
+    )
+
+
+if __name__ == "__main__":
+    asyncio.run(main())

tiny_chat/__init__.py

Lines changed: 9 additions & 11 deletions
@@ -5,20 +5,19 @@
 __version__ = '0.0.1'

 from .agents import LLMAgent
-from .environment import ChatEnvironment
-from .evaluators import (
+from .envs import TinyChatEnvironment
+from .evaluator import (
     EpisodeLLMEvaluator,
     EvaluationDimension,
     EvaluationForMultipleAgents,
     EvaluationForTwoAgents,
     Evaluator,
     RuleBasedTerminatedEvaluator,
-    SotopiaDimensions,
+    TinyChatDimensions,
     unweighted_aggregate_evaluate,
 )
-from .generator import MessageGenerator
 from .messages import AgentAction, ChatBackground, Message, Observation, SimpleMessage
-from .profile import AgentProfile, RelationshipProfile, RelationshipType
+from .profile import BaseAgentProfile, BaseEnvironmentProfile, BaseRelationshipProfile
 from .server import ChatServer

 __all__ = [
@@ -28,15 +27,14 @@
     'Observation',
     'AgentAction',
     'ChatBackground',
-    'ChatEnvironment',
-    'MessageGenerator',
-    'AgentProfile',
-    'RelationshipProfile',
-    'RelationshipType',
+    'TinyChatEnvironment',
+    'BaseAgentProfile',
+    'BaseEnvironmentProfile',
+    'BaseRelationshipProfile',
     'Evaluator',
     'RuleBasedTerminatedEvaluator',
     'EpisodeLLMEvaluator',
-    'SotopiaDimensions',
+    'TinyChatDimensions',
     'EvaluationDimension',
     'EvaluationForTwoAgents',
     'EvaluationForMultipleAgents',
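
With these renames, the package's public surface changes for any downstream code. A minimal sketch of the updated top-level imports (assuming tiny_chat is importable; the pre-commit names appear only in the comment):

# Pre-commit imports such as ChatEnvironment, SotopiaDimensions and
# AgentProfile no longer resolve; the exports listed above are used instead.
from tiny_chat import (
    BaseAgentProfile,
    BaseEnvironmentProfile,
    BaseRelationshipProfile,
    TinyChatDimensions,
    TinyChatEnvironment,
)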

tiny_chat/agents/__init__.py

Lines changed: 5 additions & 2 deletions
@@ -1,3 +1,6 @@
-from .agents import LLMAgent
+from .agents import LLMAgent, HumanAgent

-__all__ = ['LLMAgent']
+__all__ = [
+    'LLMAgent',
+    'HumanAgent',
+]

tiny_chat/agents/agents.py

Lines changed: 21 additions & 7 deletions
@@ -1,6 +1,6 @@
-from tiny_chat.generator import generate_agent
+from tiny_chat.generator import agenerate_action
 from tiny_chat.messages import AgentAction, Message, Observation
-from tiny_chat.profile import AgentProfile
+from tiny_chat.profile import BaseAgentProfile


 class LLMAgent:
@@ -10,17 +10,20 @@ def __init__(
         self,
         agent_name: str,
         model: str = 'gpt-4o-mini',
+        agent_number: int = 1,
         api_key: str | None = None,
         goal: str | None = None,
-        agent_profile: AgentProfile | None = None,
+        agent_profile: BaseAgentProfile | None = None,
+        script_like: bool = False,
     ):
         self.agent_name = agent_name
         self.model = model
+        self.agent_number = agent_number
         self.api_key = api_key
         self._goal = goal or 'Have a natural conversation'
         self.agent_profile = agent_profile
         self.message_history: list[Message] = []
-
+        self.script_like = script_like
     @property
     def goal(self) -> str:
         """Get the agent's goal"""
@@ -32,12 +35,20 @@ def goal(self, goal: str) -> None:
         self._goal = goal

     async def act(self, obs: Observation) -> AgentAction:
-        self.recv_message('Environment', obs)
+        self.receive_message('Environment', obs)

         if len(obs.available_actions) == 1 and 'none' in obs.available_actions:
             return AgentAction(action_type='none', argument='')
         else:
-            action = await generate_agent()
+            action = await agenerate_action(
+                model_name=self.model,
+                history=self._build_context(obs),
+                turn_number=obs.turn_number,
+                action_types=obs.available_actions,
+                agent=self.agent_name,
+                goal=self.goal,
+                script_like=self.script_like,
+            )
             return action

     def _build_context(self, observation: Observation) -> str:
@@ -67,6 +78,9 @@ def reset(self) -> None:
         """Reset agent state"""
         self.message_history.clear()

-    def receive_message(self, message: Message) -> None:
+    def receive_message(self, source: str, message: Message) -> None:
         """Receive a message and add to history"""
         self.message_history.append(message)
+
+
+class HumanAgent:
+    pass
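
A short driver sketch for the updated LLMAgent, showing the new receive_message(source, message) signature and the act() path through agenerate_action. The Observation field names are taken from their use elsewhere in this commit; treating "speak" as an available action, the SimpleMessage constructor field, and a configured OPENAI_API_KEY are assumptions, not shown in this diff.

import asyncio

from tiny_chat.agents import LLMAgent
from tiny_chat.messages import Observation, SimpleMessage


async def one_turn() -> None:
    # agent_number and script_like keep their new defaults (1 and False)
    agent = LLMAgent(agent_name="Alice", model="gpt-4o-mini", goal="Greet the other speaker")

    # Messages are now pushed in with an explicit source argument
    # (SimpleMessage is assumed to take a single 'message' field).
    agent.receive_message("Bob", SimpleMessage(message="Hi Alice, how was your weekend?"))

    obs = Observation(
        last_turn="Bob said: 'Hi Alice, how was your weekend?'",
        turn_number=1,
        available_actions=["speak", "none"],  # "speak" assumed to be a valid ActionType
    )
    action = await agent.act(obs)  # routes through agenerate_action; needs an LLM backend
    print(action.action_type, action.argument)


if __name__ == "__main__":
    asyncio.run(one_turn())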

tiny_chat/envs/environment.py

Lines changed: 40 additions & 1 deletion
@@ -6,7 +6,7 @@

 from pydantic import validate_call

-from tiny_chat.evaluators import Evaluator, unweighted_aggregate_evaluate
+from tiny_chat.evaluator import Evaluator, unweighted_aggregate_evaluate
 from tiny_chat.messages import (
     ActionType,
     AgentAction,
@@ -211,6 +211,45 @@ def reset(
             ),
             ),
         }
+    def get_turn_number(self) -> int:
+        return self.turn_number
+
+    def is_terminated(self) -> bool:
+        return self.turn_number >= self.max_turns
+
+    def get_observation(self, agent_name: str) -> Observation:
+        # get last turn
+        last_turn = ""
+        if self.turn_number > 0 and self.inbox:
+            last_actions = {}
+            for source, message in self.inbox:
+                if isinstance(message, AgentAction) and source != 'Environment':
+                    last_actions[source] = message
+            if last_actions:
+                last_turn = _actions_to_natural_language(last_actions)
+            else:
+                last_turn = self.background.to_natural_language()
+        else:
+            last_turn = self.background.to_natural_language()
+
+        agent_index = 0
+        if hasattr(self, 'agents') and self.agents:
+            try:
+                agent_index = self.agents.index(agent_name)
+            except ValueError:
+                agent_index = 0
+
+        available_actions = ['none']
+        if hasattr(self, 'action_mask') and len(self.action_mask) > agent_index:
+            if self.action_mask[agent_index]:
+                available_actions = list(self.available_action_types)
+
+        obs = Observation(
+            last_turn=last_turn,
+            turn_number=self.turn_number,
+            available_actions=available_actions,
+        )
+        return obs

     def _get_agent_background(self, agent: Any, agent_id: int) -> str:
         """Extract background information from agent"""

tiny_chat/evaluator/__init__.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+from .evaluators import (
+    EpisodeLLMEvaluator,
+    EvaluationDimension,
+    EvaluationForMultipleAgents,
+    EvaluationForTwoAgents,
+    Evaluator,
+    RuleBasedTerminatedEvaluator,
+    TinyChatDimensions,
+    unweighted_aggregate_evaluate,
+)
+
+__all__ = [
+    'Evaluator',
+    'EpisodeLLMEvaluator',
+    'EvaluationDimension',
+    'EvaluationForMultipleAgents',
+    'EvaluationForTwoAgents',
+    'RuleBasedTerminatedEvaluator',
+    'TinyChatDimensions',
+    'unweighted_aggregate_evaluate',
+]

tiny_chat/evaluator/evaluators.py

Lines changed: 4 additions & 4 deletions
@@ -5,7 +5,7 @@

 from pydantic import BaseModel, validate_call

-from .messages import (
+from tiny_chat.messages import (
     AgentAction,
     Message,
     ScriptEnvironmentResponse,
@@ -98,11 +98,11 @@ class EpisodeLLMEvaluator(Evaluator, Generic[T_eval_dim]):
     def __init__(
         self,
         model_name: str,
-        response_format_class: type[EvaluationForTwoAgents[T_eval_dim]],
+        # response_format_class: type[EvaluationForTwoAgents[T_eval_dim]],
     ) -> None:
         self.model_name = model_name
         self.prompt = ''
-        self.response_format_class = response_format_class
+        # self.response_format_class = response_format_class

     def __call__(
         self, turn_number: int, messages: list[tuple[str, Message]]
@@ -155,7 +155,7 @@ class EvaluationDimension(BaseModel):
     pass


-class SotopiaDimensions(BaseModel):
+class TinyChatDimensions(BaseModel):
     """Evaluation dimensions used in Sotopia"""

     overall_score: tuple[str, float] = ('Overall score', 0.0)
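
Because response_format_class is commented out of __init__, the episode evaluator is now constructed from a model name alone; how the response format is determined internally is not shown in this diff. A minimal construction sketch:

from tiny_chat.evaluator import EpisodeLLMEvaluator

# Before this commit the constructor also took a response_format_class
# (type[EvaluationForTwoAgents[...]]); after it, only the model name remains.
evaluator = EpisodeLLMEvaluator("gpt-4o-mini")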

tiny_chat/generator/__init__.py

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+from .generate_template import agenerate_action, agenerate_goal
+
+from .output_parsers import (
+    EnvResponse,
+    StrOutputParser,
+    ScriptOutputParser,
+    PydanticOutputParser,
+    ListOfIntOutputParser,
+)
+__all__ = [
+    "EnvResponse",
+    "StrOutputParser",
+    "ScriptOutputParser",
+    "PydanticOutputParser",
+    "ListOfIntOutputParser",
+    "agenerate_action",
+    "agenerate_goal",
+]

tiny_chat/generator/generate_template.py

Lines changed: 8 additions & 8 deletions
@@ -11,22 +11,22 @@
 from pydantic import validate_call
 from rich import print
 from rich.logging import RichHandler
-from sotopia.generation_utils.output_parsers import (
+from tiny_chat.generator.output_parsers import (
     EnvResponse,
     OutputParser,
     OutputType,
     PydanticOutputParser,
     ScriptOutputParser,
     StrOutputParser,
 )
-from sotopia.utils import format_docstring
+from tiny_chat.utils import format_docstring

 from tiny_chat.messages import ActionType, AgentAction, ScriptBackground
-from tiny_chat.messages.message_classes import (
+from tiny_chat.messages import (
     ScriptInteraction,
     ScriptInteractionReturnType,
 )
-from tiny_chat.profile import EnvironmentProfile, RelationshipProfile
+from tiny_chat.profile import BaseEnvironmentProfile, BaseRelationshipProfile

 # Configure logger
 log = logging.getLogger('sotopia.generation')
@@ -187,7 +187,7 @@ async def agenerate_env_profile(
     temperature: float = 0.7,
     bad_output_process_model: str | None = None,
     use_fixed_model_version: bool = True,
-) -> EnvironmentProfile:
+) -> BaseEnvironmentProfile:
     """
     Using langchain to generate the background
     """
@@ -204,7 +204,7 @@
             inspiration_prompt=inspiration_prompt,
             examples=examples,
         ),
-        output_parser=PydanticOutputParser(pydantic_object=EnvironmentProfile),
+        output_parser=PydanticOutputParser(pydantic_object=BaseEnvironmentProfile),
         temperature=temperature,
         bad_output_process_model=bad_output_process_model,
         use_fixed_model_version=use_fixed_model_version,
@@ -217,7 +217,7 @@ async def agenerate_relationship_profile(
     agents_profiles: list[str],
     bad_output_process_model: str | None = None,
     use_fixed_model_version: bool = True,
-) -> tuple[RelationshipProfile, str]:
+) -> tuple[BaseRelationshipProfile, str]:
     """
     Using langchain to generate the background
     """
@@ -232,7 +232,7 @@
         input_values=dict(
             agent_profile=agent_profile,
         ),
-        output_parser=PydanticOutputParser(pydantic_object=RelationshipProfile),
+        output_parser=PydanticOutputParser(pydantic_object=BaseRelationshipProfile),
         bad_output_process_model=bad_output_process_model,
         use_fixed_model_version=use_fixed_model_version,
     )
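
The profile generators now return the Base-prefixed classes. A hedged call sketch: only temperature, bad_output_process_model and use_fixed_model_version are visible in this diff, so the model_name and inspiration_prompt parameter names are inferred from the call sites and are assumptions; an OpenAI-compatible backend with OPENAI_API_KEY set is also assumed.

import asyncio

from tiny_chat.generator.generate_template import agenerate_env_profile
from tiny_chat.profile import BaseEnvironmentProfile


async def make_profile() -> BaseEnvironmentProfile:
    # Parameter names other than temperature are assumed (see note above).
    return await agenerate_env_profile(
        model_name="gpt-4o-mini",
        inspiration_prompt="two friends planning a weekend hike",
        temperature=0.7,
    )


if __name__ == "__main__":
    print(asyncio.run(make_profile()))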
