From 03f8d71adc10c3fb2f6bfbb1d766b69f3684b92a Mon Sep 17 00:00:00 2001 From: Chi Wang Date: Sun, 6 Aug 2023 19:51:45 +0000 Subject: [PATCH 1/3] Make auto reply method pluggable --- flaml/autogen/agentchat/__init__.py | 3 +- .../contrib/math_user_proxy_agent.py | 2 +- flaml/autogen/agentchat/groupchat.py | 94 +++++++++++-------- flaml/autogen/agentchat/responsive_agent.py | 87 +++++++++++++---- notebook/autogen_agentchat_chess.ipynb | 25 +++-- notebook/autogen_agentchat_groupchat.ipynb | 15 ++- test/autogen/agentchat/test_groupchat.py | 40 +++++++- 7 files changed, 192 insertions(+), 74 deletions(-) diff --git a/flaml/autogen/agentchat/__init__.py b/flaml/autogen/agentchat/__init__.py index 27cc2e495e..bcbc643a84 100644 --- a/flaml/autogen/agentchat/__init__.py +++ b/flaml/autogen/agentchat/__init__.py @@ -2,12 +2,13 @@ from .responsive_agent import ResponsiveAgent from .assistant_agent import AssistantAgent from .user_proxy_agent import UserProxyAgent -from .groupchat import GroupChatManager +from .groupchat import GroupChat, GroupChatManager __all__ = [ "Agent", "ResponsiveAgent", "AssistantAgent", "UserProxyAgent", + "GroupChat", "GroupChatManager", ] diff --git a/flaml/autogen/agentchat/contrib/math_user_proxy_agent.py b/flaml/autogen/agentchat/contrib/math_user_proxy_agent.py index cc388f028c..702c043c91 100644 --- a/flaml/autogen/agentchat/contrib/math_user_proxy_agent.py +++ b/flaml/autogen/agentchat/contrib/math_user_proxy_agent.py @@ -165,7 +165,7 @@ def __init__( default_auto_reply=default_auto_reply, **kwargs, ) - self.register_auto_reply(Agent, self._generate_math_reply, 1) + self.register_auto_reply(Agent, MathUserProxyAgent._generate_math_reply, 1) # fixed var self._max_invalid_q_per_step = max_invalid_q_per_step diff --git a/flaml/autogen/agentchat/groupchat.py b/flaml/autogen/agentchat/groupchat.py index 51d99d012b..14faf36389 100644 --- a/flaml/autogen/agentchat/groupchat.py +++ b/flaml/autogen/agentchat/groupchat.py @@ -1,26 +1,63 @@ +from dataclasses import dataclass import sys from typing import Dict, List, Optional, Union from .agent import Agent from .responsive_agent import ResponsiveAgent -class GroupChatManager(ResponsiveAgent): - """(WIP) A chat manager agent that can manage a group chat of multiple agents.""" +@dataclass +class GroupChat: + """A group chat class that contains a list of agents and the maximum number of rounds.""" agents: List[Agent] - max_round: int + messages: List[Dict] + max_round: int = 10 - def _participant_roles(self): - return "\n".join([f"{agent.name}: {agent.system_message}" for agent in self.agents]) + @property + def agent_names(self) -> List[str]: + """Return the names of the agents in the group chat.""" + return [agent.name for agent in self.agents] + + def reset(self): + """Reset the group chat.""" + self.messages.clear() + + def agent_by_name(self, name: str) -> Agent: + """Find the next speaker based on the message.""" + return self.agents[self.agent_names.index(name)] - def _select_speaker_msg(self): + def next_agent(self, agent: Agent) -> Agent: + """Return the next agent in the list.""" + return self.agents[(self.agent_names.index(agent.name) + 1) % len(self.agents)] + + def select_speaker_msg(self): + """Return the message for selecting the next speaker.""" return f"""You are in a role play game. The following roles are available: {self._participant_roles()}. Read the following conversation. -Then select the next role from {self._agent_names} to play. Only return the role.""" +Then select the next role from {self.agent_names} to play. Only return the role.""" + + def select_speaker(self, last_speaker: Agent, selctor: ResponsiveAgent): + """Select the next speaker.""" + selctor.update_system_message(self.select_speaker_msg()) + final, name = selctor.generate_oai_reply(self.messages) + if not final: + # i = self._random.randint(0, len(self._agent_names) - 1) # randomly pick an id + return self.next_agent(last_speaker) + try: + return self.agent_by_name(name) + except ValueError: + return self.next_agent(last_speaker) + + def _participant_roles(self): + return "\n".join([f"{agent.name}: {agent.system_message}" for agent in self.agents]) + + +class GroupChatManager(ResponsiveAgent): + """(WIP) A chat manager agent that can manage a group chat of multiple agents.""" def __init__( self, - max_round: Optional[int] = 10, + groupchat: GroupChat, name: Optional[str] = "chat_manager", # unlimited consecutive auto reply by default max_consecutive_auto_reply: Optional[int] = sys.maxsize, @@ -33,56 +70,35 @@ def __init__( name=name, max_consecutive_auto_reply=max_consecutive_auto_reply, human_input_mode=human_input_mode, + system_message=system_message, **kwargs, ) - self.register_auto_reply(Agent, self._generate_reply_for_participant) - self.max_round = max_round - self._agent_names = [] - self._messages = [] + self.register_auto_reply(Agent, GroupChatManager.run_chat, context=groupchat, reset_context=GroupChat.reset) # self._random = random.Random(seed) - def _generate_reply_for_participant( + def run_chat( self, messages: Optional[List[Dict]] = None, sender: Optional[Agent] = None, + context: Optional[GroupChat] = None, ) -> Union[str, Dict, None]: - self._agent_names = [agent.name for agent in self.agents] + """Run a group chat.""" if messages is None: messages = self._oai_messages[sender] message = messages[-1] speaker = sender - for i in range(self.max_round): + for i in range(context.max_round): # set the name to speaker's name if the role is not function if message["role"] != "function": message["name"] = speaker.name - self._messages.append(message) + context.messages.append(message) # broadcast the message to all agents except the speaker - for agent in self.agents: + for agent in context.agents: if agent != speaker: self.send(message, agent, request_reply=False) - if i != self.max_round - 1: + if i != context.max_round - 1: # speaker selection msg from an agent - speaker = self._select_speaker(speaker) + speaker = context.select_speaker(speaker, self) speaker.send(speaker.generate_reply(sender=self), self, request_reply=False) message = self.last_message(speaker) return True, None - - def _select_speaker(self, last_speaker: Agent): - """Select the next speaker.""" - self.update_system_message(self._select_speaker_msg()) - final, name = self._generate_oai_reply(self._messages) - if not final: - # i = self._random.randint(0, len(self._agent_names) - 1) # randomly pick an id - return self.agents[(self._agent_names.index(last_speaker.name) + 1) % len(self._agent_names)] - try: - return self.agent_by_name(name) - except ValueError: - return self.agents[(self._agent_names.index(last_speaker.name) + 1) % len(self._agent_names)] - - def agent_by_name(self, name: str) -> Agent: - """Find the next speaker based on the message.""" - return self.agents[self._agent_names.index(name)] - - def reset(self): - super().reset() - self._messages.clear() diff --git a/flaml/autogen/agentchat/responsive_agent.py b/flaml/autogen/agentchat/responsive_agent.py index 145a1a341f..58288a391f 100644 --- a/flaml/autogen/agentchat/responsive_agent.py +++ b/flaml/autogen/agentchat/responsive_agent.py @@ -1,6 +1,7 @@ from collections import defaultdict +import copy import json -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from flaml.autogen import oai from .agent import Agent from flaml.autogen.code_utils import DEFAULT_MODEL, UNKNOWN, execute_code, extract_code, infer_lang @@ -110,12 +111,19 @@ def __init__( self._default_auto_reply = default_auto_reply self._class_specific_reply = [] self.reply_at_receive = defaultdict(bool) - self.register_auto_reply(Agent, self._generate_oai_reply) - self.register_auto_reply(Agent, self._generate_code_execution_reply) - self.register_auto_reply(Agent, self._generate_function_call_reply) - self.register_auto_reply(Agent, self._check_termination_and_human_reply) + self.register_auto_reply(Agent, ResponsiveAgent.generate_oai_reply) + self.register_auto_reply(Agent, ResponsiveAgent.generate_code_execution_reply) + self.register_auto_reply(Agent, ResponsiveAgent.generate_function_call_reply) + self.register_auto_reply(Agent, ResponsiveAgent.check_termination_and_human_reply) - def register_auto_reply(self, class_type, reply_func: Callable, position: int = 0): + def register_auto_reply( + self, + class_type, + reply_func: Callable, + position: Optional[int] = 0, + context: Optional[Any] = None, + reset_context: Optional[Callable] = None, + ): """Register a class-specific reply function. The class-specific reply function will be called when the sender is an instance of the class_type. @@ -125,9 +133,33 @@ def register_auto_reply(self, class_type, reply_func: Callable, position: int = Args: class_type (Class): the class type. reply_func (Callable): the reply function. + The function takes a recipient agent, a list of messages, a sender agent and a context as input and returns a reply message. + ```python + def reply_func( + recipient: ResponsiveAgent, + messages: Optional[List[Dict]] = None, + sender: Optional[Agent] = None, + context: Optional[Any] = None, + ) -> Union[str, Dict, None]: + ``` position (int): the position of the reply function in the reply function list. + The function registered later will be checked earlier by default. + To change the order, set the position to a positive integer. + context (Any): the context to be passed to the reply function. + When an agent is reset, the context will be reset to the original value. + reset_context (Callable): the function to reset the context. + The function returns None. Signature: ```def reset_context(context: Any)``` """ - self._class_specific_reply.insert(position, (class_type, reply_func)) + self._class_specific_reply.insert( + position, + { + "class_type": class_type, + "reply_func": reply_func, + "context": copy.copy(context), + "init_context": context, + "reset_context": reset_context, + }, + ) @property def system_message(self): @@ -362,6 +394,11 @@ def reset(self): self.clear_history() self.reset_consecutive_auto_reply_counter() self.stop_reply_at_receive() + for class_specific_reply in self._class_specific_reply: + if class_specific_reply["reset_context"] is not None: + class_specific_reply["reset_context"](class_specific_reply["context"]) + else: + class_specific_reply["context"] = copy.copy(class_specific_reply["init_context"]) def stop_reply_at_receive(self, sender: Optional[Agent] = None): """Reset the reply_at_receive of the sender.""" @@ -388,28 +425,34 @@ def clear_history(self, agent: Optional[Agent] = None): else: self._oai_messages[agent].clear() - def _generate_oai_reply( + def generate_oai_reply( self, messages: Optional[List[Dict]] = None, sender: Optional[Agent] = None, + context: Optional[Any] = None, ) -> Tuple[bool, Union[str, Dict, None]]: - if self.llm_config is False: + """Generate a reply using autogen.oai.""" + llm_config = self.llm_config if context is None else context + if llm_config is False: return False, None if messages is None: messages = self._oai_messages[sender] # TODO: #1143 handle token limit exceeded error response = oai.ChatCompletion.create( - context=messages[-1].pop("context", None), messages=self._oai_system_message + messages, **self.llm_config + context=messages[-1].pop("context", None), messages=self._oai_system_message + messages, **llm_config ) return True, oai.ChatCompletion.extract_text_or_function_call(response)[0] - def _generate_code_execution_reply( + def generate_code_execution_reply( self, messages: Optional[List[Dict]] = None, sender: Optional[Agent] = None, + context: Optional[Any] = None, ): - if self._code_execution_config is False: + """Generate a reply using code execution.""" + code_execution_config = context if context is not None else self._code_execution_config + if code_execution_config is False: return False, None if messages is None: messages = self._oai_messages[sender] @@ -426,11 +469,15 @@ def _generate_code_execution_reply( exitcode2str = "execution succeeded" if exitcode == 0 else "execution failed" return True, f"exitcode: {exitcode} ({exitcode2str})\nCode output: {logs}" - def _generate_function_call_reply( + def generate_function_call_reply( self, messages: Optional[List[Dict]] = None, sender: Optional[Agent] = None, + context: Optional[Any] = None, ): + """Generate a reply using function call.""" + if context is None: + context = self if messages is None: messages = self._oai_messages[sender] message = messages[-1] @@ -439,11 +486,15 @@ def _generate_function_call_reply( return True, func_return return False, None - def _check_termination_and_human_reply( + def check_termination_and_human_reply( self, messages: Optional[List[Dict]] = None, sender: Optional[Agent] = None, + context: Optional[Any] = None, ) -> Tuple[bool, Union[str, Dict, None]]: + """Check if the conversation should be terminated, and if human reply is provided.""" + if context is None: + context = self if messages is None: messages = self._oai_messages[sender] message = messages[-1] @@ -539,10 +590,12 @@ def generate_reply( assert messages is not None or sender is not None, "Either messages or sender must be provided." if sender is not None: for class_specifc_reply in self._class_specific_reply: - if isinstance(sender, class_specifc_reply[0]) and ( - not exclude or class_specifc_reply[1] not in exclude + if isinstance(sender, class_specifc_reply["class_type"]) and ( + not exclude or class_specifc_reply["reply_func"] not in exclude ): - final, reply = class_specifc_reply[1](messages, sender) + final, reply = class_specifc_reply["reply_func"]( + self, messages=messages, sender=sender, context=class_specifc_reply["context"] + ) if final: return reply return self._default_auto_reply diff --git a/notebook/autogen_agentchat_chess.ipynb b/notebook/autogen_agentchat_chess.ipynb index 65cf772d9b..7165e04eac 100644 --- a/notebook/autogen_agentchat_chess.ipynb +++ b/notebook/autogen_agentchat_chess.ipynb @@ -137,7 +137,7 @@ "outputs": [], "source": [ "from collections import defaultdict\n", - "from typing import Dict, List, Optional, Union\n", + "from typing import Any, Dict, List, Optional, Union\n", "\n", "sys_msg = \"\"\"You are an AI-powered chess board agent.\n", "You translate user's natural language input into legal UCI moves.\n", @@ -164,7 +164,7 @@ " llm_config={\"temperature\": 0.0, \"config_list\": config_list_gpt4},\n", " max_consecutive_auto_reply=10,\n", " )\n", - " self.register_auto_reply(autogen.ResponsiveAgent, self._generate_board_reply)\n", + " self.register_auto_reply(autogen.ResponsiveAgent, BoardAgent._generate_board_reply)\n", " self._board = board\n", " self._correct_move_messages = defaultdict(list)\n", "\n", @@ -172,6 +172,7 @@ " self,\n", " messages: Optional[List[Dict]] = None,\n", " sender: Optional[autogen.Agent] = None,\n", + " context: Optional[Any] = None,\n", " ) -> Union[str, Dict, None]:\n", " # Filter for messages that do not contain error.\n", " if messages is None:\n", @@ -179,7 +180,7 @@ " message = messages[-1]\n", " assert message.get(\"role\") == \"user\"\n", " # extract a UCI move from player's message\n", - " reply = self.generate_reply(self._correct_move_messages[sender] + [message], sender, exclude=[self._generate_board_reply])\n", + " reply = self.generate_reply(self._correct_move_messages[sender] + [message], sender, exclude=[BoardAgent._generate_board_reply])\n", " if isinstance(reply, str):\n", " uci_move = reply\n", " else:\n", @@ -242,8 +243,8 @@ " max_consecutive_auto_reply=max_turns,\n", " **kwargs,\n", " )\n", - " self.register_auto_reply(BoardAgent, self._generate_reply_for_board)\n", - " self.register_auto_reply(ChessPlayerAgent, self._generate_reply_for_player)\n", + " self.register_auto_reply(BoardAgent, ChessPlayerAgent._generate_reply_for_board)\n", + " self.register_auto_reply(ChessPlayerAgent, ChessPlayerAgent._generate_reply_for_player)\n", " self._board_agent = board_agent\n", " self.update_max_consecutive_auto_reply(self._board_agent.max_consecutive_auto_reply(), self._board_agent)\n", "\n", @@ -251,6 +252,7 @@ " self,\n", " messages: Optional[List[Dict]] = None,\n", " sender: Optional[autogen.Agent] = None,\n", + " context: Optional[Any] = None,\n", " ) -> Union[str, Dict, None]:\n", " if messages is None:\n", " messages = self._oai_messages[sender]\n", @@ -260,7 +262,7 @@ " if last_message[\"content\"].startswith(\"Error\"):\n", " # try again\n", " last_message[\"role\"] = \"system\"\n", - " return True, self.generate_reply(messages + board_state_msg, sender, exclude=[self._generate_reply_for_board])\n", + " return True, self.generate_reply(messages + board_state_msg, sender, exclude=[ChessPlayerAgent._generate_reply_for_board])\n", " else:\n", " return True, None\n", "\n", @@ -268,13 +270,14 @@ " self,\n", " messages: Optional[List[Dict]] = None,\n", " sender: Optional[autogen.Agent] = None,\n", + " context: Optional[Any] = None,\n", " ) -> Union[str, Dict, None]:\n", " if messages is None:\n", " messages = self._oai_messages[sender]\n", " # add a system message about the current state of the board.\n", " board_state_msg = [{\"role\": \"system\", \"content\": f\"Current board:\\n{self._board_agent._board}\"}]\n", " # propose a reply which will be sent to the board agent for verification.\n", - " message = self.generate_reply(messages + board_state_msg, sender, exclude=[self._generate_reply_for_player])\n", + " message = self.generate_reply(messages + board_state_msg, sender, exclude=[ChessPlayerAgent._generate_reply_for_player])\n", " if message is None:\n", " return True, None\n", " # converse with the board until a legal move is made or max allowed retries.\n", @@ -467,7 +470,13 @@ "g1f3. \n", "Aiming to control the center of the board. Your move.\n", "\n", - "--------------------------------------------------------------------------------\n", + "--------------------------------------------------------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "\u001b[33mPlayer black\u001b[0m (to BoardAgent):\n", "\n", "g8f6. \n", diff --git a/notebook/autogen_agentchat_groupchat.ipynb b/notebook/autogen_agentchat_groupchat.ipynb index 3d46611c09..c482702328 100644 --- a/notebook/autogen_agentchat_groupchat.ipynb +++ b/notebook/autogen_agentchat_groupchat.ipynb @@ -124,7 +124,6 @@ "outputs": [], "source": [ "llm_config = {\"config_list\": config_list_gpt4}\n", - "group_chat_manager = autogen.GroupChatManager(max_round=4, llm_config=llm_config)\n", "human = autogen.UserProxyAgent(\n", " name=\"Human\",\n", " system_message=\"A human admin.\",\n", @@ -138,8 +137,8 @@ " system_message=\"Code reviewer. Prevent code execution if unsafe or not well documented. Suggest changes. Otherwise, approve and return the final code to execute.\",\n", " llm_config=llm_config,\n", ")\n", - "\n", - "group_chat_manager.agents = [human, alice, bob]" + "groupchat = autogen.GroupChat(agents=[human, alice, bob], messages=[], max_round=4)\n", + "manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config)" ] }, { @@ -416,7 +415,13 @@ "\n", "Always use this script carefully because web-scraping isn't always reliable or legal on all web pages. Always ensure you have express permission or that the website's terms and conditions don't forbid this kind of usage.\n", "\n", - "--------------------------------------------------------------------------------\n", + "--------------------------------------------------------------------------------\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "\u001b[31m\n", ">>>>>>>> NO HUMAN INPUT RECEIVED.\u001b[0m\n", "\u001b[31m\n", @@ -454,7 +459,7 @@ } ], "source": [ - "human.initiate_chat(group_chat_manager, message=\"find a latest paper about generative agents\")" + "human.initiate_chat(manager, message=\"find a latest paper about generative agents\")" ] } ], diff --git a/test/autogen/agentchat/test_groupchat.py b/test/autogen/agentchat/test_groupchat.py index 33c684b93b..6873c3c69b 100644 --- a/test/autogen/agentchat/test_groupchat.py +++ b/test/autogen/agentchat/test_groupchat.py @@ -2,7 +2,6 @@ def test_chat_manager(): - group_chat_manager = autogen.GroupChatManager(max_round=2, llm_config=False) agent1 = autogen.ResponsiveAgent( "alice", max_consecutive_auto_reply=2, @@ -17,17 +16,52 @@ def test_chat_manager(): llm_config=False, default_auto_reply="This is bob speaking.", ) - group_chat_manager.agents = [agent1, agent2] + groupchat = autogen.GroupChat(agents=[agent1, agent2], messages=[], max_round=2) + group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=False) agent1.initiate_chat(group_chat_manager, message="hello") assert len(agent1.chat_messages[group_chat_manager]) == 2 + assert len(groupchat.messages) == 2 group_chat_manager.reset() + assert len(groupchat.messages) == 0 agent1.reset() agent2.reset() agent2.initiate_chat(group_chat_manager, message="hello") + assert len(groupchat.messages) == 2 + + +def test_plugin(): + # Give another Agent class ability to manage group chat + agent1 = autogen.ResponsiveAgent( + "alice", + max_consecutive_auto_reply=2, + human_input_mode="NEVER", + llm_config=False, + default_auto_reply="This is alice sepaking.", + ) + agent2 = autogen.ResponsiveAgent( + "bob", + max_consecutive_auto_reply=2, + human_input_mode="NEVER", + llm_config=False, + default_auto_reply="This is bob speaking.", + ) + groupchat = autogen.GroupChat(agents=[agent1, agent2], messages=[], max_round=2) + group_chat_manager = autogen.ResponsiveAgent(name="deputy_manager", llm_config=False) + group_chat_manager.register_auto_reply( + autogen.Agent, + reply_func=autogen.GroupChatManager.run_chat, + context=groupchat, + reset_context=autogen.GroupChat.reset, + ) + agent1.initiate_chat(group_chat_manager, message="hello") + + assert len(agent1.chat_messages[group_chat_manager]) == 2 + assert len(groupchat.messages) == 2 if __name__ == "__main__": # test_broadcast() - test_chat_manager() + # test_chat_manager() + test_plugin() From cc2e603f4056c50a34e24c02d4c569c8aa676047 Mon Sep 17 00:00:00 2001 From: Chi Wang Date: Mon, 7 Aug 2023 15:22:23 +0000 Subject: [PATCH 2/3] allow richer trigger types --- flaml/autogen/agentchat/responsive_agent.py | 55 +++++++++++++------ flaml/autogen/agentchat/user_proxy_agent.py | 2 +- .../agentchat/test_responsive_agent.py | 29 +++++++++- 3 files changed, 66 insertions(+), 20 deletions(-) diff --git a/flaml/autogen/agentchat/responsive_agent.py b/flaml/autogen/agentchat/responsive_agent.py index 58288a391f..ef2cbfd033 100644 --- a/flaml/autogen/agentchat/responsive_agent.py +++ b/flaml/autogen/agentchat/responsive_agent.py @@ -1,7 +1,7 @@ from collections import defaultdict import copy import json -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union from flaml.autogen import oai from .agent import Agent from flaml.autogen.code_utils import DEFAULT_MODEL, UNKNOWN, execute_code, extract_code, infer_lang @@ -109,7 +109,7 @@ def __init__( self._max_consecutive_auto_reply_dict = defaultdict(self.max_consecutive_auto_reply) self._function_map = {} if function_map is None else function_map self._default_auto_reply = default_auto_reply - self._class_specific_reply = [] + self._reply_func_list = [] self.reply_at_receive = defaultdict(bool) self.register_auto_reply(Agent, ResponsiveAgent.generate_oai_reply) self.register_auto_reply(Agent, ResponsiveAgent.generate_code_execution_reply) @@ -118,20 +118,24 @@ def __init__( def register_auto_reply( self, - class_type, + trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool]], reply_func: Callable, position: Optional[int] = 0, context: Optional[Any] = None, reset_context: Optional[Callable] = None, ): - """Register a class-specific reply function. + """Register a reply function. - The class-specific reply function will be called when the sender is an instance of the class_type. + The reply function will be called when the trigger matches the sender. The function registered later will be checked earlier by default. To change the order, set the position to a positive integer. Args: - class_type (Class): the class type. + trigger (Agent class, str, Agent instance, or Callable): the trigger. + - If a class is provided, the reply function will be called when the sender is an instance of the class. + - If a string is provided, the reply function will be called when the sender's name matches the string. + - If an agent instance is provided, the reply function will be called when the sender is the agent instance. + - If a callable is provided, the reply function will be called when the callable returns True. reply_func (Callable): the reply function. The function takes a recipient agent, a list of messages, a sender agent and a context as input and returns a reply message. ```python @@ -150,10 +154,10 @@ def reply_func( reset_context (Callable): the function to reset the context. The function returns None. Signature: ```def reset_context(context: Any)``` """ - self._class_specific_reply.insert( + self._reply_func_list.insert( position, { - "class_type": class_type, + "trigger": trigger, "reply_func": reply_func, "context": copy.copy(context), "init_context": context, @@ -394,11 +398,11 @@ def reset(self): self.clear_history() self.reset_consecutive_auto_reply_counter() self.stop_reply_at_receive() - for class_specific_reply in self._class_specific_reply: - if class_specific_reply["reset_context"] is not None: - class_specific_reply["reset_context"](class_specific_reply["context"]) + for reply_func_tuple in self._reply_func_list: + if reply_func_tuple["reset_context"] is not None: + reply_func_tuple["reset_context"](reply_func_tuple["context"]) else: - class_specific_reply["context"] = copy.copy(class_specific_reply["init_context"]) + reply_func_tuple["context"] = copy.copy(reply_func_tuple["init_context"]) def stop_reply_at_receive(self, sender: Optional[Agent] = None): """Reset the reply_at_receive of the sender.""" @@ -589,17 +593,32 @@ def generate_reply( """ assert messages is not None or sender is not None, "Either messages or sender must be provided." if sender is not None: - for class_specifc_reply in self._class_specific_reply: - if isinstance(sender, class_specifc_reply["class_type"]) and ( - not exclude or class_specifc_reply["reply_func"] not in exclude - ): - final, reply = class_specifc_reply["reply_func"]( - self, messages=messages, sender=sender, context=class_specifc_reply["context"] + for reply_func_tuple in self._reply_func_list: + if exclude and reply_func_tuple["reply_func"] in exclude: + continue + if self._match_trigger(reply_func_tuple["trigger"], sender): + final, reply = reply_func_tuple["reply_func"]( + self, messages=messages, sender=sender, context=reply_func_tuple["context"] ) if final: return reply return self._default_auto_reply + def _match_trigger(self, trigger, sender): + """Check if the sender matches the trigger.""" + if isinstance(trigger, str): + return trigger == sender.name + elif isinstance(trigger, type): + return isinstance(sender, trigger) + elif isinstance(trigger, Agent): + return trigger == sender + elif isinstance(trigger, Callable): + return trigger(sender) + elif isinstance(trigger, list): + return any(self._match_trigger(t, sender) for t in trigger) + else: + raise ValueError(f"Unsupported trigger type: {type(trigger)}") + def get_human_input(self, prompt: str) -> str: """Get human input. diff --git a/flaml/autogen/agentchat/user_proxy_agent.py b/flaml/autogen/agentchat/user_proxy_agent.py index 7f19d6ee9c..da73501816 100644 --- a/flaml/autogen/agentchat/user_proxy_agent.py +++ b/flaml/autogen/agentchat/user_proxy_agent.py @@ -8,7 +8,7 @@ class UserProxyAgent(ResponsiveAgent): UserProxyAgent is a subclass of ResponsiveAgent configured with `human_input_mode` to ALWAYS and `llm_config` to False. By default, the agent will prompt for human input every time a message is received. Code execution is enabled by default. LLM-based auto reply is disabled by default. - To modify auto reply, register a method with `register_class_specific_reply`. + To modify auto reply, register a method with (`register_auto_reply`)[responsive_agent#register_auto_reply]. The method should have a similar signature with `_generate_oai_reply` method. To modify the way to get human input, override `get_human_input` method. To modify the way to execute code blocks, single code block, or function call, override `execute_code_blocks`, diff --git a/test/autogen/agentchat/test_responsive_agent.py b/test/autogen/agentchat/test_responsive_agent.py index 9c9e39dad4..df961c5a07 100644 --- a/test/autogen/agentchat/test_responsive_agent.py +++ b/test/autogen/agentchat/test_responsive_agent.py @@ -2,6 +2,32 @@ from flaml.autogen.agentchat import ResponsiveAgent +def test_trigger(): + agent = ResponsiveAgent("a0", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER") + agent1 = ResponsiveAgent("a1", max_consecutive_auto_reply=0, human_input_mode="NEVER") + agent.register_auto_reply(agent1, lambda recipient, messages, sender, context: (True, "hello")) + agent1.initiate_chat(agent, message="hi") + assert agent1.last_message(agent)["content"] == "hello" + agent.register_auto_reply("a1", lambda recipient, messages, sender, context: (True, "hello a1")) + agent1.initiate_chat(agent, message="hi") + assert agent1.last_message(agent)["content"] == "hello a1" + agent.register_auto_reply( + ResponsiveAgent, lambda recipient, messages, sender, context: (True, "hello responsive agent") + ) + agent1.initiate_chat(agent, message="hi") + assert agent1.last_message(agent)["content"] == "hello responsive agent" + agent.register_auto_reply( + lambda sender: sender.name.startswith("a"), lambda recipient, messages, sender, context: (True, "hello a") + ) + agent1.initiate_chat(agent, message="hi") + assert agent1.last_message(agent)["content"] == "hello a" + agent.register_auto_reply( + lambda sender: sender.name.startswith("b"), lambda recipient, messages, sender, context: (True, "hello b") + ) + agent1.initiate_chat(agent, message="hi") + assert agent1.last_message(agent)["content"] == "hello a" + + def test_context(): agent = ResponsiveAgent("a0", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER") agent1 = ResponsiveAgent("a1", max_consecutive_auto_reply=0, human_input_mode="NEVER") @@ -117,6 +143,7 @@ def test_responsive_agent(): if __name__ == "__main__": - test_context() + test_trigger() + # test_context() # test_max_consecutive_auto_reply() # test_responsive_agent(pytest.monkeypatch) From 185e0cdda707ed232c994cb23d20d49b07408d63 Mon Sep 17 00:00:00 2001 From: Chi Wang Date: Mon, 7 Aug 2023 15:31:47 +0000 Subject: [PATCH 3/3] test list --- flaml/autogen/agentchat/responsive_agent.py | 7 +++++-- test/autogen/agentchat/test_responsive_agent.py | 12 ++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/flaml/autogen/agentchat/responsive_agent.py b/flaml/autogen/agentchat/responsive_agent.py index ef2cbfd033..143ea8a225 100644 --- a/flaml/autogen/agentchat/responsive_agent.py +++ b/flaml/autogen/agentchat/responsive_agent.py @@ -118,7 +118,7 @@ def __init__( def register_auto_reply( self, - trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool]], + trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool], List], reply_func: Callable, position: Optional[int] = 0, context: Optional[Any] = None, @@ -131,11 +131,12 @@ def register_auto_reply( To change the order, set the position to a positive integer. Args: - trigger (Agent class, str, Agent instance, or Callable): the trigger. + trigger (Agent class, str, Agent instance, callable, or list): the trigger. - If a class is provided, the reply function will be called when the sender is an instance of the class. - If a string is provided, the reply function will be called when the sender's name matches the string. - If an agent instance is provided, the reply function will be called when the sender is the agent instance. - If a callable is provided, the reply function will be called when the callable returns True. + - If a list is provided, the reply function will be called when any of the triggers in the list is activated. reply_func (Callable): the reply function. The function takes a recipient agent, a list of messages, a sender agent and a context as input and returns a reply message. ```python @@ -154,6 +155,8 @@ def reply_func( reset_context (Callable): the function to reset the context. The function returns None. Signature: ```def reset_context(context: Any)``` """ + if not isinstance(trigger, (type, str, Agent, Callable, list)): + raise ValueError("trigger must be a class, a string, an agent, a callable or a list.") self._reply_func_list.insert( position, { diff --git a/test/autogen/agentchat/test_responsive_agent.py b/test/autogen/agentchat/test_responsive_agent.py index df961c5a07..95bd0f83f3 100644 --- a/test/autogen/agentchat/test_responsive_agent.py +++ b/test/autogen/agentchat/test_responsive_agent.py @@ -26,6 +26,18 @@ def test_trigger(): ) agent1.initiate_chat(agent, message="hi") assert agent1.last_message(agent)["content"] == "hello a" + agent.register_auto_reply( + ["agent2", agent1], lambda recipient, messages, sender, context: (True, "hello agent2 or agent1") + ) + agent1.initiate_chat(agent, message="hi") + assert agent1.last_message(agent)["content"] == "hello agent2 or agent1" + agent.register_auto_reply( + ["agent2", "agent3"], lambda recipient, messages, sender, context: (True, "hello agent2 or agent3") + ) + agent1.initiate_chat(agent, message="hi") + assert agent1.last_message(agent)["content"] == "hello agent2 or agent1" + pytest.raises(ValueError, agent.register_auto_reply, 1, lambda recipient, messages, sender, context: (True, "hi")) + pytest.raises(ValueError, agent._match_trigger, 1, agent1) def test_context():