diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index f5ba404d5fe4..f066e082236f 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -1337,7 +1337,7 @@ def _from_config(cls, config: AssistantAgentConfig) -> Self: model_client=ChatCompletionClient.load_component(config.model_client), tools=[BaseTool.load_component(tool) for tool in config.tools] if config.tools else None, handoffs=config.handoffs, - model_context=None, + model_context=ChatCompletionContext.load_component(config.model_context) if config.model_context else None, memory=[Memory.load_component(memory) for memory in config.memory] if config.memory else None, description=config.description, system_message=config.system_message, diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_code_executor_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_code_executor_agent.py index 8006164e5dec..60ffaa60cb37 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_code_executor_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_code_executor_agent.py @@ -369,6 +369,13 @@ def produced_message_types(self) -> Sequence[type[BaseChatMessage]]: """The types of messages that the code executor agent produces.""" return (TextMessage,) + @property + def model_context(self) -> ChatCompletionContext: + """ + The model context in use by the agent. + """ + return self._model_context + async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response: async for message in self.on_messages_stream(messages, cancellation_token): if isinstance(message, Response): @@ -566,7 +573,7 @@ def _from_config(cls, config: CodeExecutorAgentConfig) -> Self: sources=config.sources, system_message=config.system_message, model_client_stream=config.model_client_stream, - model_context=None, + model_context=ChatCompletionContext.load_component(config.model_context) if config.model_context else None, ) @staticmethod diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_society_of_mind_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_society_of_mind_agent.py index 5bcf8365174a..9ed9606362c6 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_society_of_mind_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_society_of_mind_agent.py @@ -286,6 +286,7 @@ def _to_config(self) -> SocietyOfMindAgentConfig: description=self.description, instruction=self._instruction, response_prompt=self._response_prompt, + model_context=self._model_context.dump_component(), ) @classmethod @@ -299,4 +300,5 @@ def _from_config(cls, config: SocietyOfMindAgentConfig) -> Self: description=config.description or cls.DEFAULT_DESCRIPTION, instruction=config.instruction or cls.DEFAULT_INSTRUCTION, response_prompt=config.response_prompt or cls.DEFAULT_RESPONSE_PROMPT, + model_context=ChatCompletionContext.load_component(config.model_context) if config.model_context else None, ) diff --git a/python/packages/autogen-agentchat/tests/test_agent.py b/python/packages/autogen-agentchat/tests/test_agent.py new file mode 100644 index 000000000000..605f85448aa6 --- /dev/null +++ b/python/packages/autogen-agentchat/tests/test_agent.py @@ -0,0 +1,126 @@ +import pytest +from autogen_agentchat.agents import ( + AssistantAgent, + CodeExecutorAgent, + SocietyOfMindAgent, +) +from autogen_agentchat.teams import RoundRobinGroupChat +from autogen_core.model_context import ( + BufferedChatCompletionContext, + ChatCompletionContext, + HeadAndTailChatCompletionContext, + TokenLimitedChatCompletionContext, + UnboundedChatCompletionContext, +) +from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor +from autogen_ext.models.replay import ReplayChatCompletionClient + + +@pytest.mark.parametrize( + "model_context_class", + [ + UnboundedChatCompletionContext(), + BufferedChatCompletionContext(buffer_size=5), + TokenLimitedChatCompletionContext(model_client=ReplayChatCompletionClient([]), token_limit=5), + HeadAndTailChatCompletionContext(head_size=3, tail_size=2), + ], +) +def test_serialize_and_deserialize_model_context_on_assistant_agent(model_context_class: ChatCompletionContext) -> None: + """Test the serialization and deserialization of the message context on the AssistantAgent.""" + agent = AssistantAgent( + name="assistant", + model_client=ReplayChatCompletionClient([]), + description="An assistant agent.", + model_context=model_context_class, + ) + + # Serialize the agent + serialized_agent = agent.dump_component() + # Deserialize the agent + deserialized_agent = AssistantAgent.load_component(serialized_agent) + + # Check that the deserialized agent has the same model context as the original agent + original_model_context = agent.model_context + deserialized_model_context = deserialized_agent.model_context + + assert isinstance(original_model_context, type(deserialized_model_context)) + assert isinstance(deserialized_model_context, type(original_model_context)) + assert original_model_context.dump_component() == deserialized_model_context.dump_component() + + +@pytest.mark.parametrize( + "model_context_class", + [ + UnboundedChatCompletionContext(), + BufferedChatCompletionContext(buffer_size=5), + TokenLimitedChatCompletionContext(model_client=ReplayChatCompletionClient([]), token_limit=5), + HeadAndTailChatCompletionContext(head_size=3, tail_size=2), + ], +) +def test_serialize_and_deserialize_model_context_on_society_of_mind_agent( + model_context_class: ChatCompletionContext, +) -> None: + """Test the serialization and deserialization of the message context on the AssistantAgent.""" + agent1 = AssistantAgent( + name="assistant1", model_client=ReplayChatCompletionClient([]), description="An assistant agent." + ) + agent2 = AssistantAgent( + name="assistant2", model_client=ReplayChatCompletionClient([]), description="An assistant agent." + ) + team = RoundRobinGroupChat( + participants=[agent1, agent2], + ) + agent = SocietyOfMindAgent( + name="assistant", + model_client=ReplayChatCompletionClient([]), + description="An assistant agent.", + team=team, + model_context=model_context_class, + ) + + # Serialize the agent + serialized_agent = agent.dump_component() + # Deserialize the agent + deserialized_agent = SocietyOfMindAgent.load_component(serialized_agent) + + # Check that the deserialized agent has the same model context as the original agent + original_model_context = agent.model_context + deserialized_model_context = deserialized_agent.model_context + + assert isinstance(original_model_context, type(deserialized_model_context)) + assert isinstance(deserialized_model_context, type(original_model_context)) + assert original_model_context.dump_component() == deserialized_model_context.dump_component() + + +@pytest.mark.parametrize( + "model_context_class", + [ + UnboundedChatCompletionContext(), + BufferedChatCompletionContext(buffer_size=5), + TokenLimitedChatCompletionContext(model_client=ReplayChatCompletionClient([]), token_limit=5), + HeadAndTailChatCompletionContext(head_size=3, tail_size=2), + ], +) +def test_serialize_and_deserialize_model_context_on_code_executor_agent( + model_context_class: ChatCompletionContext, +) -> None: + """Test the serialization and deserialization of the message context on the AssistantAgent.""" + agent = CodeExecutorAgent( + name="assistant", + code_executor=LocalCommandLineCodeExecutor(), + description="An assistant agent.", + model_context=model_context_class, + ) + + # Serialize the agent + serialized_agent = agent.dump_component() + # Deserialize the agent + deserialized_agent = CodeExecutorAgent.load_component(serialized_agent) + + # Check that the deserialized agent has the same model context as the original agent + original_model_context = agent.model_context + deserialized_model_context = deserialized_agent.model_context + + assert isinstance(original_model_context, type(deserialized_model_context)) + assert isinstance(deserialized_model_context, type(original_model_context)) + assert original_model_context.dump_component() == deserialized_model_context.dump_component()