diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index 5f032203b766..242bf5cdeaa1 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -1,5 +1,4 @@ import asyncio -import logging import uuid from abc import ABC, abstractmethod from typing import Any, AsyncGenerator, Callable, Dict, List, Mapping, Sequence @@ -15,7 +14,6 @@ ) from pydantic import BaseModel, ValidationError -from ... import EVENT_LOGGER_NAME from ...base import ChatAgent, TaskResult, Team, TerminationCondition from ...messages import ( BaseAgentEvent, @@ -27,11 +25,16 @@ ) from ...state import TeamState from ._chat_agent_container import ChatAgentContainer -from ._events import GroupChatPause, GroupChatReset, GroupChatResume, GroupChatStart, GroupChatTermination +from ._events import ( + GroupChatPause, + GroupChatReset, + GroupChatResume, + GroupChatStart, + GroupChatTermination, + SerializableException, +) from ._sequential_routed_agent import SequentialRoutedAgent -event_logger = logging.getLogger(EVENT_LOGGER_NAME) - class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]): """The base class for group chat teams. @@ -447,13 +450,26 @@ async def stop_runtime() -> None: try: # This will propagate any exceptions raised. await self._runtime.stop_when_idle() - finally: + # Put a termination message in the queue to indicate that the group chat is stopped for whatever reason + # but not due to an exception. + await self._output_message_queue.put( + GroupChatTermination( + message=StopMessage( + content="The group chat is stopped.", source=self._group_chat_manager_name + ) + ) + ) + except Exception as e: # Stop the consumption of messages and end the stream. - # NOTE: we also need to put a GroupChatTermination event here because when the group chat + # NOTE: we also need to put a GroupChatTermination event here because when the runtime # has an exception, the group chat manager may not be able to put a GroupChatTermination event in the queue. + # This may not be necessary if the group chat manager is able to handle the exception and put the event in the queue. await self._output_message_queue.put( GroupChatTermination( - message=StopMessage(content="Exception occurred.", source=self._group_chat_manager_name) + message=StopMessage( + content="An exception occurred in the runtime.", source=self._group_chat_manager_name + ), + error=SerializableException.from_exception(e), ) ) @@ -481,11 +497,10 @@ async def stop_runtime() -> None: # Wait for the next message, this will raise an exception if the task is cancelled. message = await message_future if isinstance(message, GroupChatTermination): - # If the message is None, it means the group chat has terminated. - # TODO: how do we handle termination when the runtime is not embedded - # and there is an exception in the group chat? - # The group chat manager may not be able to put a GroupChatTermination event in the queue, - # and this loop will never end. + # If the message contains an error, we need to raise it here. + # This will stop the team and propagate the error. + if message.error is not None: + raise RuntimeError(str(message.error)) stop_reason = message.message.content break yield message diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py index 2f9c0a1e3a3a..0005aff1e729 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py @@ -8,6 +8,7 @@ from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory, StopMessage from ._events import ( GroupChatAgentResponse, + GroupChatError, GroupChatMessage, GroupChatPause, GroupChatRequestPublish, @@ -15,6 +16,7 @@ GroupChatResume, GroupChatStart, GroupChatTermination, + SerializableException, ) from ._sequential_routed_agent import SequentialRoutedAgent @@ -140,58 +142,65 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No @event async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None: - # Append the message to the message thread and construct the delta. - delta: List[BaseAgentEvent | BaseChatMessage] = [] - if message.agent_response.inner_messages is not None: - for inner_message in message.agent_response.inner_messages: - self._message_thread.append(inner_message) - delta.append(inner_message) - self._message_thread.append(message.agent_response.chat_message) - delta.append(message.agent_response.chat_message) - - # Check if the conversation should be terminated. - if self._termination_condition is not None: - stop_message = await self._termination_condition(delta) - if stop_message is not None: - # Reset the termination conditions and turn count. - await self._termination_condition.reset() - self._current_turn = 0 - # Signal termination to the caller of the team. - await self._signal_termination(stop_message) - # Stop the group chat. - return - - # Increment the turn count. - self._current_turn += 1 - # Check if the maximum number of turns has been reached. - if self._max_turns is not None: - if self._current_turn >= self._max_turns: - stop_message = StopMessage( - content=f"Maximum number of turns {self._max_turns} reached.", - source=self._name, - ) - # Reset the termination conditions and turn count. - if self._termination_condition is not None: + try: + # Append the message to the message thread and construct the delta. + delta: List[BaseAgentEvent | BaseChatMessage] = [] + if message.agent_response.inner_messages is not None: + for inner_message in message.agent_response.inner_messages: + self._message_thread.append(inner_message) + delta.append(inner_message) + self._message_thread.append(message.agent_response.chat_message) + delta.append(message.agent_response.chat_message) + + # Check if the conversation should be terminated. + if self._termination_condition is not None: + stop_message = await self._termination_condition(delta) + if stop_message is not None: + # Reset the termination conditions and turn count. await self._termination_condition.reset() - self._current_turn = 0 - # Signal termination to the caller of the team. - await self._signal_termination(stop_message) - # Stop the group chat. - return + self._current_turn = 0 + # Signal termination to the caller of the team. + await self._signal_termination(stop_message) + # Stop the group chat. + return - # Select a speaker to continue the conversation. - speaker_name_future = asyncio.ensure_future(self.select_speaker(self._message_thread)) - # Link the select speaker future to the cancellation token. - ctx.cancellation_token.link_future(speaker_name_future) - speaker_name = await speaker_name_future - if speaker_name not in self._participant_name_to_topic_type: - raise RuntimeError(f"Speaker {speaker_name} not found in participant names.") - speaker_topic_type = self._participant_name_to_topic_type[speaker_name] - await self.publish_message( - GroupChatRequestPublish(), - topic_id=DefaultTopicId(type=speaker_topic_type), - cancellation_token=ctx.cancellation_token, - ) + # Increment the turn count. + self._current_turn += 1 + # Check if the maximum number of turns has been reached. + if self._max_turns is not None: + if self._current_turn >= self._max_turns: + stop_message = StopMessage( + content=f"Maximum number of turns {self._max_turns} reached.", + source=self._name, + ) + # Reset the termination conditions and turn count. + if self._termination_condition is not None: + await self._termination_condition.reset() + self._current_turn = 0 + # Signal termination to the caller of the team. + await self._signal_termination(stop_message) + # Stop the group chat. + return + + # Select a speaker to continue the conversation. + speaker_name_future = asyncio.ensure_future(self.select_speaker(self._message_thread)) + # Link the select speaker future to the cancellation token. + ctx.cancellation_token.link_future(speaker_name_future) + speaker_name = await speaker_name_future + if speaker_name not in self._participant_name_to_topic_type: + raise RuntimeError(f"Speaker {speaker_name} not found in participant names.") + speaker_topic_type = self._participant_name_to_topic_type[speaker_name] + await self.publish_message( + GroupChatRequestPublish(), + topic_id=DefaultTopicId(type=speaker_topic_type), + cancellation_token=ctx.cancellation_token, + ) + except Exception as e: + # Handle the exception and signal termination with an error. + error = SerializableException.from_exception(e) + await self._signal_termination_with_error(error) + # Raise the exception to the runtime. + raise async def _signal_termination(self, message: StopMessage) -> None: termination_event = GroupChatTermination(message=message) @@ -203,11 +212,28 @@ async def _signal_termination(self, message: StopMessage) -> None: # Put the termination event in the output message queue. await self._output_message_queue.put(termination_event) + async def _signal_termination_with_error(self, error: SerializableException) -> None: + termination_event = GroupChatTermination( + message=StopMessage(content="An error occurred in the group chat.", source=self._name), error=error + ) + # Log the termination event. + await self.publish_message( + termination_event, + topic_id=DefaultTopicId(type=self._output_topic_type), + ) + # Put the termination event in the output message queue. + await self._output_message_queue.put(termination_event) + @event async def handle_group_chat_message(self, message: GroupChatMessage, ctx: MessageContext) -> None: """Handle a group chat message by appending the content to its output message queue.""" await self._output_message_queue.put(message.message) + @event + async def handle_group_chat_error(self, message: GroupChatError, ctx: MessageContext) -> None: + """Handle a group chat error by logging the error and signaling termination.""" + await self._signal_termination_with_error(message.error) + @rpc async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> None: """Reset the group chat manager. Calling :meth:`reset` to reset the group chat manager diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py index fa74b8f9852b..69faeb49174c 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py @@ -8,12 +8,14 @@ from ...state import ChatAgentContainerState from ._events import ( GroupChatAgentResponse, + GroupChatError, GroupChatMessage, GroupChatPause, GroupChatRequestPublish, GroupChatReset, GroupChatResume, GroupChatStart, + SerializableException, ) from ._sequential_routed_agent import SequentialRoutedAgent @@ -71,24 +73,36 @@ async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> No async def handle_request(self, message: GroupChatRequestPublish, ctx: MessageContext) -> None: """Handle a content request event by passing the messages in the buffer to the delegate agent and publish the response.""" - # Pass the messages in the buffer to the delegate agent. - response: Response | None = None - async for msg in self._agent.on_messages_stream(self._message_buffer, ctx.cancellation_token): - if isinstance(msg, Response): - await self._log_message(msg.chat_message) - response = msg - else: - await self._log_message(msg) - if response is None: - raise ValueError("The agent did not produce a final response. Check the agent's on_messages_stream method.") - - # Publish the response to the group chat. - self._message_buffer.clear() - await self.publish_message( - GroupChatAgentResponse(agent_response=response), - topic_id=DefaultTopicId(type=self._parent_topic_type), - cancellation_token=ctx.cancellation_token, - ) + try: + # Pass the messages in the buffer to the delegate agent. + response: Response | None = None + async for msg in self._agent.on_messages_stream(self._message_buffer, ctx.cancellation_token): + if isinstance(msg, Response): + await self._log_message(msg.chat_message) + response = msg + else: + await self._log_message(msg) + if response is None: + raise ValueError( + "The agent did not produce a final response. Check the agent's on_messages_stream method." + ) + # Publish the response to the group chat. + self._message_buffer.clear() + await self.publish_message( + GroupChatAgentResponse(agent_response=response), + topic_id=DefaultTopicId(type=self._parent_topic_type), + cancellation_token=ctx.cancellation_token, + ) + except Exception as e: + # Publish the error to the group chat. + error_message = SerializableException.from_exception(e) + await self.publish_message( + GroupChatError(error=error_message), + topic_id=DefaultTopicId(type=self._parent_topic_type), + cancellation_token=ctx.cancellation_token, + ) + # Raise the error to the runtime. + raise def _buffer_message(self, message: BaseChatMessage) -> None: if not self._message_factory.is_registered(message.__class__): diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py index 351701a19858..ca07d87bbe7d 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py @@ -1,3 +1,4 @@ +import traceback from typing import List from pydantic import BaseModel @@ -6,6 +7,34 @@ from ...messages import BaseAgentEvent, BaseChatMessage, StopMessage +class SerializableException(BaseModel): + """A serializable exception.""" + + error_type: str + """The type of error that occurred.""" + + error_message: str + """The error message that describes the error.""" + + traceback: str | None = None + """The traceback of the error, if available.""" + + @classmethod + def from_exception(cls, exc: Exception) -> "SerializableException": + """Create a GroupChatError from an exception.""" + return cls( + error_type=type(exc).__name__, + error_message=str(exc), + traceback="\n".join(traceback.format_exception(type(exc), exc, exc.__traceback__)), + ) + + def __str__(self) -> str: + """Return a string representation of the error, including the traceback if available.""" + if self.traceback: + return f"{self.error_type}: {self.error_message}\nTraceback:\n{self.traceback}" + return f"{self.error_type}: {self.error_message}" + + class GroupChatStart(BaseModel): """A request to start a group chat.""" @@ -39,6 +68,9 @@ class GroupChatTermination(BaseModel): message: StopMessage """The stop message that indicates the reason of termination.""" + error: SerializableException | None = None + """The error that occurred, if any.""" + class GroupChatReset(BaseModel): """A request to reset the agents in the group chat.""" @@ -56,3 +88,10 @@ class GroupChatResume(BaseModel): """A request to resume the group chat.""" ... + + +class GroupChatError(BaseModel): + """A message indicating that an error occurred in the group chat.""" + + error: SerializableException + """The error that occurred.""" diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index 25d64ade8a43..6e04c2b8e50f 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -12,7 +12,7 @@ BaseChatAgent, CodeExecutorAgent, ) -from autogen_agentchat.base import Handoff, Response, TaskResult +from autogen_agentchat.base import Handoff, Response, TaskResult, TerminationCondition from autogen_agentchat.conditions import HandoffTermination, MaxMessageTermination, TextMentionTermination from autogen_agentchat.messages import ( BaseAgentEvent, @@ -103,6 +103,26 @@ async def on_reset(self, cancellation_token: CancellationToken) -> None: self._last_message = None +class _FlakyTermination(TerminationCondition): + def __init__(self, raise_on_count: int) -> None: + self._raise_on_count = raise_on_count + self._count = 0 + + @property + def terminated(self) -> bool: + """Check if the termination condition has been reached""" + return False + + async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None: + self._count += 1 + if self._count == self._raise_on_count: + raise ValueError("I am a flaky termination...") + return None + + async def reset(self) -> None: + pass + + class _UnknownMessageType(BaseChatMessage): content: str @@ -285,7 +305,7 @@ async def test_round_robin_group_chat_unknown_agent_message_type() -> None: agent2 = _UnknownMessageTypeAgent("agent2", "I am an unknown message type agent") termination = TextMentionTermination("TERMINATE") team1 = RoundRobinGroupChat(participants=[agent1, agent2], termination_condition=termination) - with pytest.raises(ValueError, match="Message type .*UnknownMessageType.* not registered"): + with pytest.raises(RuntimeError, match=".* Message type .*UnknownMessageType.* not registered"): await team1.run(task=TextMessage(content="Write a program that prints 'Hello, world!'", source="user")) @@ -457,10 +477,8 @@ async def test_round_robin_group_chat_with_resume_and_reset(runtime: AgentRuntim assert result.stop_reason is not None -# TODO: add runtime fixture for testing with custom runtime once the issue regarding -# hanging on exception is resolved. @pytest.mark.asyncio -async def test_round_robin_group_chat_with_exception_raised() -> None: +async def test_round_robin_group_chat_with_exception_raised_from_agent(runtime: AgentRuntime | None) -> None: agent_1 = _EchoAgent("agent_1", description="echo agent 1") agent_2 = _FlakyAgent("agent_2", description="echo agent 2") agent_3 = _EchoAgent("agent_3", description="echo agent 3") @@ -468,9 +486,29 @@ async def test_round_robin_group_chat_with_exception_raised() -> None: team = RoundRobinGroupChat( participants=[agent_1, agent_2, agent_3], termination_condition=termination, + runtime=runtime, + ) + + with pytest.raises(RuntimeError, match="I am a flaky agent..."): + await team.run( + task="Write a program that prints 'Hello, world!'", + ) + + +@pytest.mark.asyncio +async def test_round_robin_group_chat_with_exception_raised_from_termination_condition( + runtime: AgentRuntime | None, +) -> None: + agent_1 = _EchoAgent("agent_1", description="echo agent 1") + agent_2 = _FlakyAgent("agent_2", description="echo agent 2") + agent_3 = _EchoAgent("agent_3", description="echo agent 3") + team = RoundRobinGroupChat( + participants=[agent_1, agent_2, agent_3], + termination_condition=_FlakyTermination(raise_on_count=1), + runtime=runtime, ) - with pytest.raises(ValueError, match="I am a flaky agent..."): + with pytest.raises(Exception, match="I am a flaky termination..."): await team.run( task="Write a program that prints 'Hello, world!'", )