Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def __init__(
self.register_reply([Agent, None], ConversableAgent.generate_oai_reply)
self.register_reply([Agent, None], ConversableAgent.generate_code_execution_reply)
self.register_reply([Agent, None], ConversableAgent.generate_function_call_reply)
self.register_reply([Agent, None], ConversableAgent.generate_async_function_call_reply)
self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply)

def register_reply(
Expand Down Expand Up @@ -661,6 +662,28 @@ def generate_function_call_reply(
return True, func_return
return False, None

async def generate_async_function_call_reply(
self,
messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None,
config: Optional[Any] = None,
):
"""Generate a reply using async function call."""
if config is None:
config = self
if messages is None:
messages = self._oai_messages[sender]
message = messages[-1]
if "function_call" in message:
func_call = message["function_call"]
func_name = func_call.get("name", "")
func = self._function_map.get(func_name, None)
if func and asyncio.coroutines.iscoroutinefunction(func):
_, func_return = await self.a_execute_function(func_call)
return True, func_return

return False, None

def check_termination_and_human_reply(
self,
messages: Optional[List[Dict]] = None,
Expand Down Expand Up @@ -1002,6 +1025,56 @@ def execute_function(self, func_call):
"content": str(content),
}

async def a_execute_function(self, func_call):
"""Execute an async function call and return the result.

Override this function to modify the way async functions are executed.

Args:
func_call: a dictionary extracted from openai message at key "function_call" with keys "name" and "arguments".

Returns:
A tuple of (is_exec_success, result_dict).
is_exec_success (boolean): whether the execution is successful.
result_dict: a dictionary with keys "name", "role", and "content". Value of "role" is "function".
"""
func_name = func_call.get("name", "")
func = self._function_map.get(func_name, None)

is_exec_success = False
if func is not None:
# Extract arguments from a json-like string and put it into a dict.
input_string = self._format_json_str(func_call.get("arguments", "{}"))
try:
arguments = json.loads(input_string)
except json.JSONDecodeError as e:
arguments = None
content = f"Error: {e}\n You argument should follow json format."

# Try to execute the function
if arguments is not None:
print(
colored(f"\n>>>>>>>> EXECUTING ASYNC FUNCTION {func_name}...", "magenta"),
flush=True,
)
try:
if asyncio.coroutines.iscoroutinefunction(func):
content = await func(**arguments)
else:
# Fallback to sync function if the function is not async
content = func(**arguments)
is_exec_success = True
except Exception as e:
content = f"Error: {e}"
else:
content = f"Error: Function {func_name} not found."

return is_exec_success, {
"name": func_name,
"role": "function",
"content": str(content),
}

def generate_init_message(self, **context) -> Union[str, Dict]:
"""Generate the initial message for the agent.

Expand Down
50 changes: 50 additions & 0 deletions autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,12 @@ def __init__(
system_message=system_message,
**kwargs,
)
# Order of register_reply is important.
# Allow sync chat if initiated using initiate_chat
self.register_reply(Agent, GroupChatManager.run_chat, config=groupchat, reset_config=GroupChat.reset)
# Allow async chat if initiated using a_initiate_chat
self.register_reply(Agent, GroupChatManager.a_run_chat, config=groupchat, reset_config=GroupChat.reset)

# self._random = random.Random(seed)

def run_chat(
Expand Down Expand Up @@ -177,3 +182,48 @@ def run_chat(
speaker.send(reply, self, request_reply=False)
message = self.last_message(speaker)
return True, None

async def a_run_chat(
self,
messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None,
config: Optional[GroupChat] = None,
):
"""Run a group chat asynchronously."""
if messages is None:
messages = self._oai_messages[sender]
message = messages[-1]
speaker = sender
groupchat = config
for i in range(groupchat.max_round):
# set the name to speaker's name if the role is not function
if message["role"] != "function":
message["name"] = speaker.name
groupchat.messages.append(message)
# broadcast the message to all agents except the speaker
for agent in groupchat.agents:
if agent != speaker:
await self.a_send(message, agent, request_reply=False, silent=True)
if i == groupchat.max_round - 1:
# the last round
break
try:
# select the next speaker
speaker = groupchat.select_speaker(speaker, self)
# let the speaker speak
reply = await speaker.a_generate_reply(sender=self)
except KeyboardInterrupt:
# let the admin agent speak if interrupted
if groupchat.admin_name in groupchat.agent_names:
# admin agent is one of the participants
speaker = groupchat.agent_by_name(groupchat.admin_name)
reply = await speaker.a_generate_reply(sender=self)
else:
# admin agent is not found in the participants
raise
if reply is None:
break
# The speaker sends the message without requesting a reply
await speaker.a_send(reply, self, request_reply=False)
message = self.last_message(speaker)
return True, None
61 changes: 61 additions & 0 deletions test/test_function_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,68 @@ def get_number():
assert user.execute_function(func_call)[1]["content"] == "42"


@pytest.mark.asyncio
async def test_a_execute_function():
from autogen.agentchat import UserProxyAgent
import time

# Create an async function
async def add_num(num_to_be_added):
given_num = 10
time.sleep(1)
return num_to_be_added + given_num

user = UserProxyAgent(name="test", function_map={"add_num": add_num})
correct_args = {"name": "add_num", "arguments": '{ "num_to_be_added": 5 }'}

# Asset coroutine doesn't match.
assert user.execute_function(func_call=correct_args)[1]["content"] != "15"
# Asset awaited coroutine does match.
assert (await user.a_execute_function(func_call=correct_args))[1]["content"] == "15"

# function name called is wrong or doesn't exist
wrong_func_name = {"name": "subtract_num", "arguments": '{ "num_to_be_added": 5 }'}
assert "Error: Function" in (await user.a_execute_function(func_call=wrong_func_name))[1]["content"]

# arguments passed is not in correct json format
wrong_json_format = {
"name": "add_num",
"arguments": '{ "num_to_be_added": 5, given_num: 10 }',
} # should be "given_num" with quotes
assert (
"You argument should follow json format."
in (await user.a_execute_function(func_call=wrong_json_format))[1]["content"]
)

# function execution error with wrong arguments passed
wrong_args = {"name": "add_num", "arguments": '{ "num_to_be_added": 5, "given_num": 10 }'}
assert "Error: " in (await user.a_execute_function(func_call=wrong_args))[1]["content"]

# 2. test calling a class method
class AddNum:
def __init__(self, given_num):
self.given_num = given_num

def add(self, num_to_be_added):
self.given_num = num_to_be_added + self.given_num
return self.given_num

user = UserProxyAgent(name="test", function_map={"add_num": AddNum(given_num=10).add})
func_call = {"name": "add_num", "arguments": '{ "num_to_be_added": 5 }'}
assert (await user.a_execute_function(func_call=func_call))[1]["content"] == "15"
assert (await user.a_execute_function(func_call=func_call))[1]["content"] == "20"

# 3. test calling a function with no arguments
def get_number():
return 42

user = UserProxyAgent("user", function_map={"get_number": get_number})
func_call = {"name": "get_number", "arguments": "{}"}
assert (await user.a_execute_function(func_call))[1]["content"] == "42"


if __name__ == "__main__":
test_json_extraction()
test_execute_function()
test_a_execute_function()
test_eval_math_responses()