Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,20 @@
class Team(ABC, TaskRunner, ComponentBase[BaseModel]):
component_type = "team"

@property
@abstractmethod
def name(self) -> str:
"""The name of the team. This is used by team to uniquely identify itself
in a larger team of teams."""
...

@property
@abstractmethod
def description(self) -> str:
"""A description of the team. This is used to provide context about the
team and its purpose to its parent orchestrator."""
...

@abstractmethod
async def reset(self) -> None:
"""Reset the team and all its participants to its initial state."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,34 @@
class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):
"""The base class for group chat teams.

In a group chat team, participants share context by publishing their messages
to all other participants.

If an :class:`~autogen_agentchat.base.ChatAgent` is a participant,
the :class:`~autogen_agentchat.messages.BaseChatMessage` from the agent response's
:attr:`~autogen_agentchat.base.Response.chat_message` will be published
to other participants in the group chat.

If a :class:`~autogen_agentchat.base.Team` is a participant,
the :class:`~autogen_agentchat.messages.BaseChatMessage`
from the team result' :attr:`~autogen_agentchat.base.TaskResult.messages` will be published
to other participants in the group chat.

To implement a group chat team, first create a subclass of :class:`BaseGroupChatManager` and then
create a subclass of :class:`BaseGroupChat` that uses the group chat manager.

This base class provides the mapping between the agents of the AgentChat API
and the agent runtime of the Core API, and handles high-level features like
running, pausing, resuming, and resetting the team.
"""

component_type = "team"

def __init__(
self,
participants: List[ChatAgent],
name: str,
description: str,
participants: List[ChatAgent | Team],
group_chat_manager_name: str,
group_chat_manager_class: type[SequentialRoutedAgent],
termination_condition: TerminationCondition | None = None,
Expand All @@ -57,6 +76,8 @@ def __init__(
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
emit_team_events: bool = False,
):
self._name = name
self._description = description
if len(participants) == 0:
raise ValueError("At least one participant is required.")
if len(participants) != len(set(participant.name for participant in participants)):
Expand All @@ -71,14 +92,15 @@ def __init__(
self._message_factory.register(message_type)

for agent in participants:
for message_type in agent.produced_message_types:
try:
is_registered = self._message_factory.is_registered(message_type) # type: ignore[reportUnknownArgumentType]
if issubclass(message_type, StructuredMessage) and not is_registered:
self._message_factory.register(message_type) # type: ignore[reportUnknownArgumentType]
except TypeError:
# Not a class or not a valid subclassable type (skip)
pass
if isinstance(agent, ChatAgent):
for message_type in agent.produced_message_types:
try:
is_registered = self._message_factory.is_registered(message_type) # type: ignore[reportUnknownArgumentType]
if issubclass(message_type, StructuredMessage) and not is_registered:
self._message_factory.register(message_type) # type: ignore[reportUnknownArgumentType]
except TypeError:
# Not a class or not a valid subclassable type (skip)
pass

# The team ID is a UUID that is used to identify the team and its participants
# in the agent runtime. It is used to create unique topic types for each participant.
Expand Down Expand Up @@ -128,6 +150,16 @@ def __init__(
# Flag to track if the team events should be emitted.
self._emit_team_events = emit_team_events

@property
def name(self) -> str:
"""The name of the group chat team."""
return self._name

@property
def description(self) -> str:
"""A description of the group chat team."""
return self._description

@abstractmethod
def _create_group_chat_manager_factory(
self,
Expand All @@ -147,7 +179,7 @@ def _create_participant_factory(
self,
parent_topic_type: str,
output_topic_type: str,
agent: ChatAgent,
agent: ChatAgent | Team,
message_factory: MessageFactory,
) -> Callable[[], ChatAgentContainer]:
def _factory() -> ChatAgentContainer:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
GroupChatReset,
GroupChatResume,
GroupChatStart,
GroupChatTeamResponse,
GroupChatTermination,
SerializableException,
)
Expand Down Expand Up @@ -52,6 +53,7 @@ def __init__(
sequential_message_types=[
GroupChatStart,
GroupChatAgentResponse,
GroupChatTeamResponse,
GroupChatMessage,
GroupChatReset,
],
Expand Down Expand Up @@ -130,20 +132,25 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No
await self._transition_to_next_speakers(ctx.cancellation_token)

@event
async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None:
async def handle_agent_response(
self, message: GroupChatAgentResponse | GroupChatTeamResponse, ctx: MessageContext
) -> None:
try:
# Construct the detla from the agent response.
delta: List[BaseAgentEvent | BaseChatMessage] = []
if message.agent_response.inner_messages is not None:
for inner_message in message.agent_response.inner_messages:
delta.append(inner_message)
delta.append(message.agent_response.chat_message)
if isinstance(message, GroupChatAgentResponse):
if message.response.inner_messages is not None:
for inner_message in message.response.inner_messages:
delta.append(inner_message)
delta.append(message.response.chat_message)
else:
delta.extend(message.result.messages)

# Append the messages to the message thread.
await self.update_message_thread(delta)

# Remove the agent from the active speakers list.
self._active_speakers.remove(message.agent_name)
self._active_speakers.remove(message.name)
if len(self._active_speakers) > 0:
# If there are still active speakers, return without doing anything.
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage, MessageFactory

from ...base import ChatAgent, Response
from ...base import ChatAgent, Response, TaskResult, Team
from ...state import ChatAgentContainerState
from ._events import (
GroupChatAgentResponse,
Expand All @@ -15,26 +15,27 @@
GroupChatReset,
GroupChatResume,
GroupChatStart,
GroupChatTeamResponse,
SerializableException,
)
from ._sequential_routed_agent import SequentialRoutedAgent


class ChatAgentContainer(SequentialRoutedAgent):
"""A core agent class that delegates message handling to an
:class:`autogen_agentchat.base.ChatAgent` so that it can be used in a
group chat team.
:class:`autogen_agentchat.base.ChatAgent` or :class:`autogen_agentchat.base.Team`
so that it can be used in a group chat team.

Args:
parent_topic_type (str): The topic type of the parent orchestrator.
output_topic_type (str): The topic type for the output.
agent (ChatAgent): The agent to delegate message handling to.
agent (ChatAgent | Team): The agent or team to delegate message handling to.
message_factory (MessageFactory): The message factory to use for
creating messages from JSON data.
"""

def __init__(
self, parent_topic_type: str, output_topic_type: str, agent: ChatAgent, message_factory: MessageFactory
self, parent_topic_type: str, output_topic_type: str, agent: ChatAgent | Team, message_factory: MessageFactory
) -> None:
super().__init__(
description=agent.description,
Expand All @@ -43,6 +44,7 @@ def __init__(
GroupChatRequestPublish,
GroupChatReset,
GroupChatAgentResponse,
GroupChatTeamResponse,
],
)
self._parent_topic_type = parent_topic_type
Expand All @@ -61,40 +63,50 @@ async def handle_start(self, message: GroupChatStart, ctx: MessageContext) -> No
@event
async def handle_agent_response(self, message: GroupChatAgentResponse, ctx: MessageContext) -> None:
"""Handle an agent response event by appending the content to the buffer."""
self._buffer_message(message.agent_response.chat_message)
self._buffer_message(message.response.chat_message)

@event
async def handle_team_response(self, message: GroupChatTeamResponse, ctx: MessageContext) -> None:
"""Handle a team response event by appending the content to the buffer."""
for msg in message.result.messages:
if isinstance(msg, BaseChatMessage):
self._buffer_message(msg)

@rpc
async def handle_reset(self, message: GroupChatReset, ctx: MessageContext) -> None:
"""Handle a reset event by resetting the agent."""
self._message_buffer.clear()
await self._agent.on_reset(ctx.cancellation_token)
if isinstance(self._agent, Team):
# If the agent is a team, reset the team.
await self._agent.reset()
else:
await self._agent.on_reset(ctx.cancellation_token)

@event
async def handle_request(self, message: GroupChatRequestPublish, ctx: MessageContext) -> None:
"""Handle a content request event by passing the messages in the buffer
to the delegate agent and publish the response."""
with trace_invoke_agent_span(
agent_name=self._agent.name,
agent_description=self._agent.description,
agent_id=str(self.id),
):
if isinstance(self._agent, Team):
try:
# Pass the messages in the buffer to the delegate agent.
response: Response | None = None
async for msg in self._agent.on_messages_stream(self._message_buffer, ctx.cancellation_token):
if isinstance(msg, Response):
await self._log_message(msg.chat_message)
response = msg
stream = self._agent.run_stream(
task=self._message_buffer,
cancellation_token=ctx.cancellation_token,
output_task_messages=False,
)
result: TaskResult | None = None
async for team_event in stream:
if isinstance(team_event, TaskResult):
result = team_event
else:
await self._log_message(msg)
if response is None:
raise ValueError(
"The agent did not produce a final response. Check the agent's on_messages_stream method."
await self._log_message(team_event)
if result is None:
raise RuntimeError(
"The team did not produce a final TaskResult. Check the team's run_stream method."
)
# Publish the response to the group chat.
self._message_buffer.clear()
# Publish the team response to the group chat.
await self.publish_message(
GroupChatAgentResponse(agent_response=response, agent_name=self._agent.name),
GroupChatTeamResponse(result=result, name=self._agent.name),
topic_id=DefaultTopicId(type=self._parent_topic_type),
cancellation_token=ctx.cancellation_token,
)
Expand All @@ -108,6 +120,43 @@ async def handle_request(self, message: GroupChatRequestPublish, ctx: MessageCon
)
# Raise the error to the runtime.
raise
else:
# If the agent is not a team, handle it as a single agent.
with trace_invoke_agent_span(
agent_name=self._agent.name,
agent_description=self._agent.description,
agent_id=str(self.id),
):
try:
# Pass the messages in the buffer to the delegate agent.
response: Response | None = None
async for msg in self._agent.on_messages_stream(self._message_buffer, ctx.cancellation_token):
if isinstance(msg, Response):
await self._log_message(msg.chat_message)
response = msg
else:
await self._log_message(msg)
if response is None:
raise RuntimeError(
"The agent did not produce a final response. Check the agent's on_messages_stream method."
)
# Publish the response to the group chat.
self._message_buffer.clear()
await self.publish_message(
GroupChatAgentResponse(response=response, name=self._agent.name),
topic_id=DefaultTopicId(type=self._parent_topic_type),
cancellation_token=ctx.cancellation_token,
)
except Exception as e:
# Publish the error to the group chat.
error_message = SerializableException.from_exception(e)
await self.publish_message(
GroupChatError(error=error_message),
topic_id=DefaultTopicId(type=self._parent_topic_type),
cancellation_token=ctx.cancellation_token,
)
# Raise the error to the runtime.
raise

def _buffer_message(self, message: BaseChatMessage) -> None:
if not self._message_factory.is_registered(message.__class__):
Expand All @@ -127,12 +176,20 @@ async def _log_message(self, message: BaseAgentEvent | BaseChatMessage) -> None:
@rpc
async def handle_pause(self, message: GroupChatPause, ctx: MessageContext) -> None:
"""Handle a pause event by pausing the agent."""
await self._agent.on_pause(ctx.cancellation_token)
if isinstance(self._agent, Team):
# If the agent is a team, pause the team.
await self._agent.pause()
else:
await self._agent.on_pause(ctx.cancellation_token)

@rpc
async def handle_resume(self, message: GroupChatResume, ctx: MessageContext) -> None:
"""Handle a resume event by resuming the agent."""
await self._agent.on_resume(ctx.cancellation_token)
if isinstance(self._agent, Team):
# If the agent is a team, resume the team.
await self._agent.resume()
else:
await self._agent.on_resume(ctx.cancellation_token)

async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
raise ValueError(f"Unhandled message in agent container: {type(message)}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pydantic import BaseModel

from ...base import Response
from ...base import Response, TaskResult
from ...messages import BaseAgentEvent, BaseChatMessage, StopMessage


Expand Down Expand Up @@ -48,13 +48,23 @@ class GroupChatStart(BaseModel):
class GroupChatAgentResponse(BaseModel):
"""A response published to a group chat."""

agent_response: Response
response: Response
"""The response from an agent."""

agent_name: str
name: str
"""The name of the agent that produced the response."""


class GroupChatTeamResponse(BaseModel):
"""A response published to a group chat from a team."""

result: TaskResult
"""The result from a team."""

name: str
"""The name of the team that produced the response."""


class GroupChatRequestPublish(BaseModel):
"""A request to publish a message to a group chat."""

Expand Down
Loading
Loading