@@ -1337,7 +1337,7 @@ def _from_config(cls, config: AssistantAgentConfig) -> Self:
            model_client=ChatCompletionClient.load_component(config.model_client),
            tools=[BaseTool.load_component(tool) for tool in config.tools] if config.tools else None,
            handoffs=config.handoffs,
-           model_context=None,
+           model_context=ChatCompletionContext.load_component(config.model_context) if config.model_context else None,
            memory=[Memory.load_component(memory) for memory in config.memory] if config.memory else None,
            description=config.description,
            system_message=config.system_message,
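Before this change, `_from_config` ignored any serialized model context and the loaded agent fell back to the default. The new line relies on the standard component dump/load round trip; a minimal sketch using only the API already visible in this diff:

# Minimal sketch of the component round trip used by the new line above (illustrative only).
from autogen_core.model_context import BufferedChatCompletionContext, ChatCompletionContext

context = BufferedChatCompletionContext(buffer_size=5)
component_model = context.dump_component()  # what config.model_context holds after serialization
restored = ChatCompletionContext.load_component(component_model)
assert isinstance(restored, BufferedChatCompletionContext)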
@@ -369,6 +369,13 @@ def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
        """The types of messages that the code executor agent produces."""
        return (TextMessage,)

+    @property
+    def model_context(self) -> ChatCompletionContext:
+        """
+        The model context in use by the agent.
+        """
+        return self._model_context
+
    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
        async for message in self.on_messages_stream(messages, cancellation_token):
            if isinstance(message, Response):
@@ -566,7 +573,7 @@ def _from_config(cls, config: CodeExecutorAgentConfig) -> Self:
            sources=config.sources,
            system_message=config.system_message,
            model_client_stream=config.model_client_stream,
-           model_context=None,
+           model_context=ChatCompletionContext.load_component(config.model_context) if config.model_context else None,
        )

    @staticmethod
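The new read-only `model_context` property is what the serialization tests below use to compare contexts. A small usage sketch, assuming a locally constructed agent (the names are illustrative, not from the PR):

# Illustrative use of the new property (not part of the PR diff itself).
from autogen_agentchat.agents import CodeExecutorAgent
from autogen_core.model_context import BufferedChatCompletionContext
from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor

coder = CodeExecutorAgent(
    name="coder",
    code_executor=LocalCommandLineCodeExecutor(),
    model_context=BufferedChatCompletionContext(buffer_size=5),
)
# The property exposes the context that was passed in (or the default one),
# without reaching into the private _model_context attribute.
assert isinstance(coder.model_context, BufferedChatCompletionContext)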
@@ -286,6 +286,7 @@ def _to_config(self) -> SocietyOfMindAgentConfig:
            description=self.description,
            instruction=self._instruction,
            response_prompt=self._response_prompt,
+           model_context=self._model_context.dump_component(),
        )

    @classmethod
@@ -299,4 +300,5 @@ def _from_config(cls, config: SocietyOfMindAgentConfig) -> Self:
            description=config.description or cls.DEFAULT_DESCRIPTION,
            instruction=config.instruction or cls.DEFAULT_INSTRUCTION,
            response_prompt=config.response_prompt or cls.DEFAULT_RESPONSE_PROMPT,
+           model_context=ChatCompletionContext.load_component(config.model_context) if config.model_context else None,
        )
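Together these two changes let a SocietyOfMindAgent carry its configured context through dump/load instead of silently reverting to the default. A minimal sketch of the round trip, mirroring the tests added below (the names are just for the example):

# Illustrative only: mirrors the new tests below.
from autogen_agentchat.agents import AssistantAgent, SocietyOfMindAgent
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_core.model_context import BufferedChatCompletionContext
from autogen_ext.models.replay import ReplayChatCompletionClient

inner1 = AssistantAgent(name="inner1", model_client=ReplayChatCompletionClient([]))
inner2 = AssistantAgent(name="inner2", model_client=ReplayChatCompletionClient([]))
team = RoundRobinGroupChat(participants=[inner1, inner2])
som = SocietyOfMindAgent(
    name="society_of_mind",
    model_client=ReplayChatCompletionClient([]),
    team=team,
    model_context=BufferedChatCompletionContext(buffer_size=5),
)

# dump_component() now embeds the model context in the serialized config,
# so load_component() restores the same context type rather than the default unbounded one.
restored = SocietyOfMindAgent.load_component(som.dump_component())
assert isinstance(restored.model_context, BufferedChatCompletionContext)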
126 changes: 126 additions & 0 deletions python/packages/autogen-agentchat/tests/test_agent.py
@@ -0,0 +1,126 @@
import pytest
from autogen_agentchat.agents import (
    AssistantAgent,
    CodeExecutorAgent,
    SocietyOfMindAgent,
)
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_core.model_context import (
    BufferedChatCompletionContext,
    ChatCompletionContext,
    HeadAndTailChatCompletionContext,
    TokenLimitedChatCompletionContext,
    UnboundedChatCompletionContext,
)
from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
from autogen_ext.models.replay import ReplayChatCompletionClient


@pytest.mark.parametrize(
    "model_context_class",
    [
        UnboundedChatCompletionContext(),
        BufferedChatCompletionContext(buffer_size=5),
        TokenLimitedChatCompletionContext(model_client=ReplayChatCompletionClient([]), token_limit=5),
        HeadAndTailChatCompletionContext(head_size=3, tail_size=2),
    ],
)
def test_serialize_and_deserialize_model_context_on_assistant_agent(model_context_class: ChatCompletionContext) -> None:
    """Test the serialization and deserialization of the model context on the AssistantAgent."""
    agent = AssistantAgent(
        name="assistant",
        model_client=ReplayChatCompletionClient([]),
        description="An assistant agent.",
        model_context=model_context_class,
    )

    # Serialize the agent
    serialized_agent = agent.dump_component()
    # Deserialize the agent
    deserialized_agent = AssistantAgent.load_component(serialized_agent)

    # Check that the deserialized agent has the same model context as the original agent
    original_model_context = agent.model_context
    deserialized_model_context = deserialized_agent.model_context

    assert isinstance(original_model_context, type(deserialized_model_context))
    assert isinstance(deserialized_model_context, type(original_model_context))
    assert original_model_context.dump_component() == deserialized_model_context.dump_component()


@pytest.mark.parametrize(
    "model_context_class",
    [
        UnboundedChatCompletionContext(),
        BufferedChatCompletionContext(buffer_size=5),
        TokenLimitedChatCompletionContext(model_client=ReplayChatCompletionClient([]), token_limit=5),
        HeadAndTailChatCompletionContext(head_size=3, tail_size=2),
    ],
)
def test_serialize_and_deserialize_model_context_on_society_of_mind_agent(
    model_context_class: ChatCompletionContext,
) -> None:
    """Test the serialization and deserialization of the model context on the SocietyOfMindAgent."""
    agent1 = AssistantAgent(
        name="assistant1", model_client=ReplayChatCompletionClient([]), description="An assistant agent."
    )
    agent2 = AssistantAgent(
        name="assistant2", model_client=ReplayChatCompletionClient([]), description="An assistant agent."
    )
    team = RoundRobinGroupChat(
        participants=[agent1, agent2],
    )
    agent = SocietyOfMindAgent(
        name="assistant",
        model_client=ReplayChatCompletionClient([]),
        description="An assistant agent.",
        team=team,
        model_context=model_context_class,
    )

    # Serialize the agent
    serialized_agent = agent.dump_component()
    # Deserialize the agent
    deserialized_agent = SocietyOfMindAgent.load_component(serialized_agent)

    # Check that the deserialized agent has the same model context as the original agent
    original_model_context = agent.model_context
    deserialized_model_context = deserialized_agent.model_context

    assert isinstance(original_model_context, type(deserialized_model_context))
    assert isinstance(deserialized_model_context, type(original_model_context))
    assert original_model_context.dump_component() == deserialized_model_context.dump_component()


@pytest.mark.parametrize(
    "model_context_class",
    [
        UnboundedChatCompletionContext(),
        BufferedChatCompletionContext(buffer_size=5),
        TokenLimitedChatCompletionContext(model_client=ReplayChatCompletionClient([]), token_limit=5),
        HeadAndTailChatCompletionContext(head_size=3, tail_size=2),
    ],
)
def test_serialize_and_deserialize_model_context_on_code_executor_agent(
    model_context_class: ChatCompletionContext,
) -> None:
    """Test the serialization and deserialization of the model context on the CodeExecutorAgent."""
    agent = CodeExecutorAgent(
        name="assistant",
        code_executor=LocalCommandLineCodeExecutor(),
        description="An assistant agent.",
        model_context=model_context_class,
    )

    # Serialize the agent
    serialized_agent = agent.dump_component()
    # Deserialize the agent
    deserialized_agent = CodeExecutorAgent.load_component(serialized_agent)

    # Check that the deserialized agent has the same model context as the original agent
    original_model_context = agent.model_context
    deserialized_model_context = deserialized_agent.model_context

    assert isinstance(original_model_context, type(deserialized_model_context))
    assert isinstance(deserialized_model_context, type(original_model_context))
    assert original_model_context.dump_component() == deserialized_model_context.dump_component()
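The three tests above repeat the same dump/reload/compare sequence; a shared helper along these lines could factor it out (a sketch only, not part of the PR; the helper name is made up):

# Hypothetical helper, not part of this PR.
from typing import Any

from autogen_core.model_context import ChatCompletionContext


def assert_model_context_roundtrip(agent_cls: type, agent: Any, original: ChatCompletionContext) -> None:
    """Dump the agent, reload it, and check the model context survives the round trip."""
    deserialized_agent = agent_cls.load_component(agent.dump_component())
    restored = deserialized_agent.model_context
    assert type(restored) is type(original)
    assert restored.dump_component() == original.dump_component()

Each test body would then reduce to the agent construction plus a single assert_model_context_roundtrip(...) call.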