Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Callable, Dict, List, Mapping, Sequence
Expand All @@ -15,7 +14,6 @@
)
from pydantic import BaseModel, ValidationError

from ... import EVENT_LOGGER_NAME
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
from ...messages import (
BaseAgentEvent,
Expand All @@ -27,11 +25,16 @@
)
from ...state import TeamState
from ._chat_agent_container import ChatAgentContainer
from ._events import GroupChatPause, GroupChatReset, GroupChatResume, GroupChatStart, GroupChatTermination
from ._events import (
GroupChatPause,
GroupChatReset,
GroupChatResume,
GroupChatStart,
GroupChatTermination,
SerializableException,
)
from ._sequential_routed_agent import SequentialRoutedAgent

event_logger = logging.getLogger(EVENT_LOGGER_NAME)


class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
"""The base class for group chat teams.
Expand Down Expand Up @@ -447,13 +450,26 @@ async def stop_runtime() -> None:
try:
# This will propagate any exceptions raised.
await self._runtime.stop_when_idle()
finally:
# Put a termination message in the queue to indicate that the group chat is stopped for whatever reason
# but not due to an exception.
await self._output_message_queue.put(
GroupChatTermination(
message=StopMessage(
content="The group chat is stopped.", source=self._group_chat_manager_name
)
)
)
except Exception as e:
# Stop the consumption of messages and end the stream.
# NOTE: we also need to put a GroupChatTermination event here because when the group chat
# NOTE: we also need to put a GroupChatTermination event here because when the runtime
# has an exception, the group chat manager may not be able to put a GroupChatTermination event in the queue.
# This may not be necessary if the group chat manager is able to handle the exception and put the event in the queue.
await self._output_message_queue.put(
GroupChatTermination(
message=StopMessage(content="Exception occurred.", source=self._group_chat_manager_name)
message=StopMessage(
content="An exception occurred in the runtime.", source=self._group_chat_manager_name
),
error=SerializableException.from_exception(e),
)
)

Expand Down Expand Up @@ -481,11 +497,10 @@ async def stop_runtime() -> None:
# Wait for the next message, this will raise an exception if the task is cancelled.
message = await message_future
if isinstance(message, GroupChatTermination):
# If the message is None, it means the group chat has terminated.
# TODO: how do we handle termination when the runtime is not embedded
# and there is an exception in the group chat?
# The group chat manager may not be able to put a GroupChatTermination event in the queue,
# and this loop will never end.
# If the message contains an error, we need to raise it here.
# This will stop the team and propagate the error.
if message.error is not None:
raise RuntimeError(str(message.error))
stop_reason = message.message.content
break
yield message
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
from ...messages import BaseAgentEvent, BaseChatMessage, MessageFactory, StopMessage
from ._events import (
GroupChatAgentResponse,
GroupChatError,
GroupChatMessage,
GroupChatPause,
GroupChatRequestPublish,
GroupChatReset,
GroupChatResume,
GroupChatStart,
GroupChatTermination,
SerializableException,
)
from ._sequential_routed_agent import SequentialRoutedAgent

Expand Down Expand Up @@ -140,58 +142,65 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No

@event
async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None:
# Append the message to the message thread and construct the delta.
delta: List[BaseAgentEvent | BaseChatMessage] = []
if message.agent_response.inner_messages is not None:
for inner_message in message.agent_response.inner_messages:
self._message_thread.append(inner_message)
delta.append(inner_message)
self._message_thread.append(message.agent_response.chat_message)
delta.append(message.agent_response.chat_message)

# Check if the conversation should be terminated.
if self._termination_condition is not None:
stop_message = await self._termination_condition(delta)
if stop_message is not None:
# Reset the termination conditions and turn count.
await self._termination_condition.reset()
self._current_turn = 0
# Signal termination to the caller of the team.
await self._signal_termination(stop_message)
# Stop the group chat.
return

# Increment the turn count.
self._current_turn += 1
# Check if the maximum number of turns has been reached.
if self._max_turns is not None:
if self._current_turn >= self._max_turns:
stop_message = StopMessage(
content=f"Maximum number of turns {self._max_turns} reached.",
source=self._name,
)
# Reset the termination conditions and turn count.
if self._termination_condition is not None:
try:
# Append the message to the message thread and construct the delta.
delta: List[BaseAgentEvent | BaseChatMessage] = []
if message.agent_response.inner_messages is not None:
for inner_message in message.agent_response.inner_messages:
self._message_thread.append(inner_message)
delta.append(inner_message)
self._message_thread.append(message.agent_response.chat_message)
delta.append(message.agent_response.chat_message)

# Check if the conversation should be terminated.
if self._termination_condition is not None:
stop_message = await self._termination_condition(delta)
if stop_message is not None:
# Reset the termination conditions and turn count.
await self._termination_condition.reset()
self._current_turn = 0
# Signal termination to the caller of the team.
await self._signal_termination(stop_message)
# Stop the group chat.
return
self._current_turn = 0
# Signal termination to the caller of the team.
await self._signal_termination(stop_message)
# Stop the group chat.
return

# Select a speaker to continue the conversation.
speaker_name_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
# Link the select speaker future to the cancellation token.
ctx.cancellation_token.link_future(speaker_name_future)
speaker_name = await speaker_name_future
if speaker_name not in self._participant_name_to_topic_type:
raise RuntimeError(f"Speaker {speaker_name} not found in participant names.")
speaker_topic_type = self._participant_name_to_topic_type[speaker_name]
await self.publish_message(
GroupChatRequestPublish(),
topic_id=DefaultTopicId(type=speaker_topic_type),
cancellation_token=ctx.cancellation_token,
)
# Increment the turn count.
self._current_turn += 1
# Check if the maximum number of turns has been reached.
if self._max_turns is not None:
if self._current_turn >= self._max_turns:
stop_message = StopMessage(
content=f"Maximum number of turns {self._max_turns} reached.",
source=self._name,
)
# Reset the termination conditions and turn count.
if self._termination_condition is not None:
await self._termination_condition.reset()
self._current_turn = 0
# Signal termination to the caller of the team.
await self._signal_termination(stop_message)
# Stop the group chat.
return

# Select a speaker to continue the conversation.
speaker_name_future = asyncio.ensure_future(self.select_speaker(self._message_thread))
# Link the select speaker future to the cancellation token.
ctx.cancellation_token.link_future(speaker_name_future)
speaker_name = await speaker_name_future
if speaker_name not in self._participant_name_to_topic_type:
raise RuntimeError(f"Speaker {speaker_name} not found in participant names.")
speaker_topic_type = self._participant_name_to_topic_type[speaker_name]
await self.publish_message(
GroupChatRequestPublish(),
topic_id=DefaultTopicId(type=speaker_topic_type),
cancellation_token=ctx.cancellation_token,
)
except Exception as e:
# Handle the exception and signal termination with an error.
error = SerializableException.from_exception(e)
await self._signal_termination_with_error(error)
# Raise the exception to the runtime.
raise

async def _signal_termination(self, message: StopMessage) -> None:
termination_event = GroupChatTermination(message=message)
Expand All @@ -203,11 +212,28 @@ async def _signal_termination(self, message: StopMessage) -> None:
# Put the termination event in the output message queue.
await self._output_message_queue.put(termination_event)

async def _signal_termination_with_error(self, error: SerializableException) -> None:
termination_event = GroupChatTermination(
message=StopMessage(content="An error occurred in the group chat.", source=self._name), error=error
)
# Log the termination event.
await self.publish_message(
termination_event,
topic_id=DefaultTopicId(type=self._output_topic_type),
)
# Put the termination event in the output message queue.
await self._output_message_queue.put(termination_event)

@event
async def handle_group_chat_message(self, message: GroupChatMessage, ctx: MessageContext) -> None:
"""Handle a group chat message by appending the content to its output message queue."""
await self._output_message_queue.put(message.message)

@event
async def handle_group_chat_error(self, message: GroupChatError, ctx: MessageContext) -> None:
"""Handle a group chat error by logging the error and signaling termination."""
await self._signal_termination_with_error(message.error)

@rpc
async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> None:
"""Reset the group chat manager. Calling :meth:`reset` to reset the group chat manager
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
from ...state import ChatAgentContainerState
from ._events import (
GroupChatAgentResponse,
GroupChatError,
GroupChatMessage,
GroupChatPause,
GroupChatRequestPublish,
GroupChatReset,
GroupChatResume,
GroupChatStart,
SerializableException,
)
from ._sequential_routed_agent import SequentialRoutedAgent

Expand Down Expand Up @@ -71,24 +73,36 @@ async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> No
async def handle_request(self, message: GroupChatRequestPublish, ctx: MessageContext) -> None:
"""Handle a content request event by passing the messages in the buffer
to the delegate agent and publish the response."""
# Pass the messages in the buffer to the delegate agent.
response: Response | None = None
async for msg in self._agent.on_messages_stream(self._message_buffer, ctx.cancellation_token):
if isinstance(msg, Response):
await self._log_message(msg.chat_message)
response = msg
else:
await self._log_message(msg)
if response is None:
raise ValueError("The agent did not produce a final response. Check the agent's on_messages_stream method.")

# Publish the response to the group chat.
self._message_buffer.clear()
await self.publish_message(
GroupChatAgentResponse(agent_response=response),
topic_id=DefaultTopicId(type=self._parent_topic_type),
cancellation_token=ctx.cancellation_token,
)
try:
# Pass the messages in the buffer to the delegate agent.
response: Response | None = None
async for msg in self._agent.on_messages_stream(self._message_buffer, ctx.cancellation_token):
if isinstance(msg, Response):
await self._log_message(msg.chat_message)
response = msg
else:
await self._log_message(msg)
if response is None:
raise ValueError(
"The agent did not produce a final response. Check the agent's on_messages_stream method."
)
# Publish the response to the group chat.
self._message_buffer.clear()
await self.publish_message(
GroupChatAgentResponse(agent_response=response),
topic_id=DefaultTopicId(type=self._parent_topic_type),
cancellation_token=ctx.cancellation_token,
)
except Exception as e:
# Publish the error to the group chat.
error_message = SerializableException.from_exception(e)
await self.publish_message(
GroupChatError(error=error_message),
topic_id=DefaultTopicId(type=self._parent_topic_type),
cancellation_token=ctx.cancellation_token,
)
# Raise the error to the runtime.
raise

def _buffer_message(self, message: BaseChatMessage) -> None:
if not self._message_factory.is_registered(message.__class__):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import traceback
from typing import List

from pydantic import BaseModel
Expand All @@ -6,6 +7,34 @@
from ...messages import BaseAgentEvent, BaseChatMessage, StopMessage


class SerializableException(BaseModel):
"""A serializable exception."""

error_type: str
"""The type of error that occurred."""

error_message: str
"""The error message that describes the error."""

traceback: str | None = None
"""The traceback of the error, if available."""

@classmethod
def from_exception(cls, exc: Exception) -> "SerializableException":
"""Create a GroupChatError from an exception."""
return cls(
error_type=type(exc).__name__,
error_message=str(exc),
traceback="\n".join(traceback.format_exception(type(exc), exc, exc.__traceback__)),
)

def __str__(self) -> str:
"""Return a string representation of the error, including the traceback if available."""
if self.traceback:
return f"{self.error_type}: {self.error_message}\nTraceback:\n{self.traceback}"
return f"{self.error_type}: {self.error_message}"


class GroupChatStart(BaseModel):
"""A request to start a group chat."""

Expand Down Expand Up @@ -39,6 +68,9 @@ class GroupChatTermination(BaseModel):
message: StopMessage
"""The stop message that indicates the reason of termination."""

error: SerializableException | None = None
"""The error that occurred, if any."""


class GroupChatReset(BaseModel):
"""A request to reset the agents in the group chat."""
Expand All @@ -56,3 +88,10 @@ class GroupChatResume(BaseModel):
"""A request to resume the group chat."""

...


class GroupChatError(BaseModel):
"""A message indicating that an error occurred in the group chat."""

error: SerializableException
"""The error that occurred."""
Loading
Loading