Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
965d9a4
FEAT: selector group chat can use streaming
SongChiYoung Apr 12, 2025
d42a382
Merge remote-tracking branch 'upstream/main' into feature/model_clien…
SongChiYoung Apr 13, 2025
7aff3b6
FIX: delete useless if block
SongChiYoung Apr 13, 2025
99194d4
Merge branch 'main' into feature/model_client_streaming_from_the_sele…
SongChiYoung Apr 15, 2025
7e24527
Merge branch 'main' into feature/model_client_streaming_from_the_sele…
SongChiYoung Apr 16, 2025
55a1341
Merge
SongChiYoung Apr 17, 2025
fda2d28
clean
SongChiYoung Apr 17, 2025
31d0d66
done - maybe need to add a testcase
SongChiYoung Apr 17, 2025
29f2c0f
Merge remote-tracking branch 'upstream/main' into feature/model_clien…
SongChiYoung Apr 17, 2025
cd9e000
Merge branch 'main' into feature/model_client_streaming_from_the_sele…
SongChiYoung Apr 18, 2025
d563032
FIX: adding full message of content of stream.
SongChiYoung Apr 19, 2025
3486d6d
Clean, Add test
SongChiYoung Apr 19, 2025
548171a
Merge branch 'main' into feature/model_client_streaming_from_the_sele…
SongChiYoung Apr 19, 2025
4abae5a
Apply suggestions from code review
ekzhu Apr 21, 2025
e68dbbf
Update python/packages/autogen-agentchat/src/autogen_agentchat/teams/…
ekzhu Apr 21, 2025
4fabfce
Merge branch 'main' into feature/model_client_streaming_from_the_sele…
ekzhu Apr 21, 2025
b711739
Merge branch 'main' into feature/model_client_streaming_from_the_sele…
SongChiYoung Apr 21, 2025
05c250a
Fix tests
ekzhu Apr 21, 2025
c9f7582
fix
ekzhu Apr 21, 2025
ef67c17
Merge branch 'main' into feature/model_client_streaming_from_the_sele…
ekzhu Apr 21, 2025
f698469
Fix
ekzhu Apr 21, 2025
94a8f9b
Merge branch 'main' into feature/model_client_streaming_from_the_sele…
ekzhu Apr 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,18 @@
return str(self.content)


class SelectorEvent(BaseAgentEvent):
    """An event emitted from the `SelectorGroupChat`.

    Carries the full text of the selector model client's response that was
    used to pick the next speaker.
    """

    content: str
    """The content of the event."""

    # Literal discriminator identifying this event type (used by the
    # message union for (de)serialization).
    type: Literal["SelectorEvent"] = "SelectorEvent"

    def to_text(self) -> str:
        """Return the event content as plain text."""
        return str(self.content)

Check warning on line 543 in python/packages/autogen-agentchat/src/autogen_agentchat/messages.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-agentchat/src/autogen_agentchat/messages.py#L543

Added line #L543 was not covered by tests


class MessageFactory:
""":meta private:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Sequence, Union, cast

from autogen_core import AgentRuntime, Component, ComponentModel
from autogen_core.models import AssistantMessage, ChatCompletionClient, ModelFamily, SystemMessage, UserMessage
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
CreateResult,
ModelFamily,
SystemMessage,
UserMessage,
)
from pydantic import BaseModel
from typing_extensions import Self

Expand All @@ -16,6 +23,8 @@
BaseAgentEvent,
BaseChatMessage,
MessageFactory,
ModelClientStreamingChunkEvent,
SelectorEvent,
)
from ...state import SelectorManagerState
from ._base_group_chat import BaseGroupChat
Expand Down Expand Up @@ -56,6 +65,7 @@ def __init__(
max_selector_attempts: int,
candidate_func: Optional[CandidateFuncType],
emit_team_events: bool,
model_client_streaming: bool = False,
) -> None:
super().__init__(
name,
Expand All @@ -79,6 +89,7 @@ def __init__(
self._max_selector_attempts = max_selector_attempts
self._candidate_func = candidate_func
self._is_candidate_func_async = iscoroutinefunction(self._candidate_func)
self._model_client_streaming = model_client_streaming

async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
pass
Expand Down Expand Up @@ -194,7 +205,26 @@ async def _select_speaker(self, roles: str, participants: List[str], history: st
num_attempts = 0
while num_attempts < max_attempts:
num_attempts += 1
response = await self._model_client.create(messages=select_speaker_messages)
if self._model_client_streaming:
chunk: CreateResult | str = ""
async for _chunk in self._model_client.create_stream(messages=select_speaker_messages):
chunk = _chunk
if self._emit_team_events:
if isinstance(chunk, str):
await self._output_message_queue.put(
ModelClientStreamingChunkEvent(content=cast(str, _chunk), source=self._name)
)
else:
assert isinstance(chunk, CreateResult)
assert isinstance(chunk.content, str)
await self._output_message_queue.put(
SelectorEvent(content=chunk.content, source=self._name)
)
# The last chunk must be CreateResult.
assert isinstance(chunk, CreateResult)
response = chunk
else:
response = await self._model_client.create(messages=select_speaker_messages)
assert isinstance(response.content, str)
select_speaker_messages.append(AssistantMessage(content=response.content, source="selector"))
# NOTE: we use all participant names to check for mentions, even if the previous speaker is not allowed.
Expand Down Expand Up @@ -281,6 +311,7 @@ class SelectorGroupChatConfig(BaseModel):
# selector_func: ComponentModel | None
max_selector_attempts: int = 3
emit_team_events: bool = False
model_client_streaming: bool = False


class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
Expand Down Expand Up @@ -311,6 +342,7 @@ class SelectorGroupChat(BaseGroupChat, Component[SelectorGroupChatConfig]):
selection using model. If the function returns an empty list or `None`, `SelectorGroupChat` will raise a `ValueError`.
This function is only used if `selector_func` is not set. The `allow_repeated_speaker` will be ignored if set.
emit_team_events (bool, optional): Whether to emit team events through :meth:`BaseGroupChat.run_stream`. Defaults to False.
model_client_streaming (bool, optional): Whether to use streaming for the model client. (This is useful for reasoning models like QwQ). Defaults to False.

Raises:
ValueError: If the number of participants is less than two or if the selector prompt is invalid.
Expand Down Expand Up @@ -453,6 +485,7 @@ def __init__(
candidate_func: Optional[CandidateFuncType] = None,
custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None,
emit_team_events: bool = False,
model_client_streaming: bool = False,
):
super().__init__(
participants,
Expand All @@ -473,6 +506,7 @@ def __init__(
self._selector_func = selector_func
self._max_selector_attempts = max_selector_attempts
self._candidate_func = candidate_func
self._model_client_streaming = model_client_streaming

def _create_group_chat_manager_factory(
self,
Expand Down Expand Up @@ -505,6 +539,7 @@ def _create_group_chat_manager_factory(
self._max_selector_attempts,
self._candidate_func,
self._emit_team_events,
self._model_client_streaming,
)

def _to_config(self) -> SelectorGroupChatConfig:
Expand All @@ -518,6 +553,7 @@ def _to_config(self) -> SelectorGroupChatConfig:
max_selector_attempts=self._max_selector_attempts,
# selector_func=self._selector_func.dump_component() if self._selector_func else None,
emit_team_events=self._emit_team_events,
model_client_streaming=self._model_client_streaming,
)

@classmethod
Expand All @@ -536,4 +572,5 @@ def _from_config(cls, config: SelectorGroupChatConfig) -> Self:
# if config.selector_func
# else None,
emit_team_events=config.emit_team_events,
model_client_streaming=config.model_client_streaming,
)
60 changes: 59 additions & 1 deletion python/packages/autogen-agentchat/tests/test_group_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,19 @@
CodeExecutorAgent,
)
from autogen_agentchat.base import Handoff, Response, TaskResult, TerminationCondition
from autogen_agentchat.conditions import HandoffTermination, MaxMessageTermination, TextMentionTermination
from autogen_agentchat.conditions import (
HandoffTermination,
MaxMessageTermination,
StopMessageTermination,
TextMentionTermination,
)
from autogen_agentchat.messages import (
BaseAgentEvent,
BaseChatMessage,
HandoffMessage,
ModelClientStreamingChunkEvent,
MultiModalMessage,
SelectorEvent,
SelectSpeakerEvent,
StopMessage,
StructuredMessage,
Expand Down Expand Up @@ -1698,3 +1705,54 @@ async def test_structured_message_state_roundtrip(runtime: AgentRuntime | None)
)

assert manager1._message_thread == manager2._message_thread # pyright: ignore


@pytest.mark.asyncio
async def test_selector_group_chat_streaming(runtime: AgentRuntime | None) -> None:
    """With ``model_client_streaming=True`` the selector emits chunk events
    whose concatenation equals the content of the following SelectorEvent,
    and the non-chunk message sequence matches a plain (non-streaming) run.
    """
    model_client = ReplayChatCompletionClient(["the agent should be agent2"])
    stop_agent = _StopAgent("agent2", description="stop agent 2", stop_at=0)
    echo_agent = _EchoAgent("agent3", description="echo agent 3")
    team = SelectorGroupChat(
        participants=[stop_agent, echo_agent],
        model_client=model_client,
        termination_condition=StopMessageTermination(),
        runtime=runtime,
        emit_team_events=True,
        model_client_streaming=True,
    )
    task = "Write a program that prints 'Hello, world!'"
    result = await team.run(task=task)

    # Expected sequence: task echo, selector event, speaker selection, stop.
    assert len(result.messages) == 4
    first, second, third, fourth = result.messages
    assert isinstance(first, TextMessage)
    assert isinstance(second, SelectorEvent)
    assert isinstance(third, SelectSpeakerEvent)
    assert isinstance(fourth, StopMessage)

    assert first.content == "Write a program that prints 'Hello, world!'"
    assert second.content == "the agent should be agent2"
    assert third.content == ["agent2"]
    assert fourth.source == "agent2"
    assert result.stop_reason is not None and result.stop_reason == "Stop message received"

    # Streaming run: accumulate chunk events and verify they join up to the
    # SelectorEvent content; every non-chunk message must equal its
    # counterpart from the non-streaming run above.
    await team.reset()
    model_client.reset()
    expected_index = 0
    pending_chunks: List[str] = []
    async for item in team.run_stream(task=task):
        if isinstance(item, TaskResult):
            assert item == result
        elif isinstance(item, ModelClientStreamingChunkEvent):
            pending_chunks.append(item.content)
        else:
            if pending_chunks:
                # The event immediately after a run of chunks carries the
                # full concatenated text.
                assert isinstance(item, SelectorEvent)
                assert item.content == "".join(pending_chunks)
                pending_chunks = []
            assert item == result.messages[expected_index]
            expected_index += 1
Loading