Skip to content

[FEAT] summarized chat completion context #6217

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
364ddf6
FEAT: build Just draft
SongChiYoung Apr 5, 2025
29f29de
FEAT: End of design summarized cat completion
SongChiYoung Apr 5, 2025
03353ef
FEAT: add conditions
SongChiYoung Apr 6, 2025
fda9bed
Merge remote-tracking branch 'upstream/main' into feature/_summarized…
SongChiYoung Apr 6, 2025
2a651a5
Merge remote-tracking branch 'upstream/main' into feature/_summarized…
SongChiYoung Apr 6, 2025
e88443b
FEAT: Add TokenUsageMessageCompletion
SongChiYoung Apr 6, 2025
33e9595
Merge remote-tracking branch 'upstream/main' into feature/_summarized…
SongChiYoung Apr 7, 2025
ec63df3
feat: maybe, add all of conditions
SongChiYoung Apr 7, 2025
9e4fef9
FIX: fix errors
SongChiYoung Apr 7, 2025
f772ed4
FEAT: buffered_summary add for included summary func
SongChiYoung Apr 7, 2025
d063f15
Merge remote-tracking branch 'upstream/main' into feature/_summarized…
SongChiYoung Apr 7, 2025
b337c13
Merge remote-tracking branch 'upstream/main' into feature/_summarized…
SongChiYoung Apr 10, 2025
df60122
FEAT: add bufferd_summary_context
SongChiYoung Apr 10, 2025
37ef9bf
FEAT: serialize aware
SongChiYoung Apr 10, 2025
c536f8f
FEAT: pyright and mypy aware
SongChiYoung Apr 10, 2025
88894d3
FIX: Docstring fixed
SongChiYoung Apr 10, 2025
d8d18a8
TEST: add summary model context test
SongChiYoung Apr 10, 2025
90486e0
CHOR: test code pyright/mypy
SongChiYoung Apr 10, 2025
b8ddf8e
FEAT: add summary engines test cases
SongChiYoung Apr 10, 2025
5686a69
FEAT: test cov done
SongChiYoung Apr 10, 2025
daf0ee8
Merge remote-tracking branch 'upstream/main' into feature/_summarized…
SongChiYoung Apr 11, 2025
5ac3fb9
Merge branch 'main' into feature/_summarized_chat_completion_context
SongChiYoung Apr 12, 2025
4e3c53d
FEAT: LLM summary engine is now appear
SongChiYoung Apr 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
from ._buffered_chat_completion_context import BufferedChatCompletionContext
from ._chat_completion_context import ChatCompletionContext, ChatCompletionContextState
from ._head_and_tail_chat_completion_context import HeadAndTailChatCompletionContext
from ._summarized_chat_completion_context import (
SummarizedChatCompletionContext,
)
from ._token_limited_chat_completion_context import TokenLimitedChatCompletionContext
from ._unbounded_chat_completion_context import (
UnboundedChatCompletionContext,
)
from .conditions import (
ContextMessage,
SummarizngFunction,
TriggerMessage,
)

__all__ = [
"ChatCompletionContext",
@@ -13,4 +21,8 @@
"BufferedChatCompletionContext",
"TokenLimitedChatCompletionContext",
"HeadAndTailChatCompletionContext",
"SummarizedChatCompletionContext",
"ContextMessage",
"TriggerMessage",
"SummarizngFunction",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from typing import List

from pydantic import BaseModel
from typing_extensions import Self

from autogen_core import ComponentModel

from .._component_config import Component
from ..models import LLMMessage
from ..tools._base import BaseTool
from ._chat_completion_context import ChatCompletionContext
from .conditions import MessageCompletionCondition, SummarizngFunction, SummaryFunction


class SummarizedChatCompletionContextConfig(BaseModel):
summarizing_func: ComponentModel
summarizing_condition: ComponentModel
initial_messages: List[LLMMessage] | None = None
non_summarized_messages: List[LLMMessage] | None = None


class SummarizedChatCompletionContext(ChatCompletionContext, Component[SummarizedChatCompletionContextConfig]):
"""A summarized chat completion context that summarizes the messages in the context
using a summarizing function. The summarizing function is set at initialization.
The summarizing condition is used to determine when to summarize the messages.
Args:
summarizing_func (Callable[[List[LLMMessage]], List[LLMMessage]]): The function to summarize the messages.
summarizing_condition (MessageCompletionCondition): The condition to determine when to summarize the messages.
initial_messages (List[LLMMessage] | None): The initial messages.
Example:
.. code-block:: python
from typing import List
from autogen_core.model_context import SummarizedChatCompletionContext
from autogen_core.models import LLMMessage
def summarizing_func(messages: List[LLMMessage]) -> List[LLMMessage]:
# Implement your summarizing function here.
return messages
summarizing_condition = MessageCompletionCondition()
context = SummarizedChatCompletionContext(
summarizing_func=summarizing_func,
summarizing_condition=summarizing_condition,
)
.. code-block:: python
import asyncio
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_agentchat.agents import AssistantAgent
from autogen_core.model_context import SummarizedChatCompletionContext
from autogen_core.model_context.conditions import MaxMessageCompletion
from autogen_ext.summary import buffered_summary, buffered_summarized_chat_completion_context
client = OpenAIChatCompletionClient(model="claude-3-haiku-20240307")
context = SummarizedChatCompletionContext(
summarizing_func=buffered_summary(buffer_count=2), summarizing_condition=MaxMessageCompletion(max_messages=2)
)
agent = AssistantAgent(
"helper", model_client=client, system_message="You are a helpful agent", model_context=context
)
"""

component_config_schema = SummarizedChatCompletionContextConfig
component_provider_override = "autogen_core.model_context.SummarizedChatCompletionContext"

def __init__(
self,
summarizing_func: SummaryFunction | SummarizngFunction,
summarizing_condition: MessageCompletionCondition,
initial_messages: List[LLMMessage] | None = None,
non_summarized_messages: List[LLMMessage] | None = None,
) -> None:
super().__init__(initial_messages)

self._non_summarized_messages: List[LLMMessage] = []
if non_summarized_messages is not None:
self._non_summarized_messages.extend(non_summarized_messages)

self._non_summarized_messages.extend(self._messages)

self._summarizing_func: SummaryFunction
if isinstance(summarizing_func, BaseTool):
# If the summarizing function is a tool, use it directly.
self._summarizing_func = summarizing_func
elif callable(summarizing_func):
self._summarizing_func = SummaryFunction(func=summarizing_func)
else:
raise ValueError("summarizing_func must be a callable or a tool.")
self._summarizing_condition = summarizing_condition

async def add_message(self, message: LLMMessage) -> None:
"""Add a message to the context."""
self._non_summarized_messages.append(message)
await super().add_message(message)

# Check if the summarizing condition is met.
await self._summarizing_condition(self._messages)
if self._summarizing_condition.triggered:
# If the condition is met, summarize the messages.
await self.summary()
await self._summarizing_condition.reset()

async def get_messages(self) -> List[LLMMessage]:
return self._messages

async def summary(self) -> None:
"""Summarize the messages in the context using the summarizing function."""
summarized_message = await self._summarizing_func.run(self._messages, self._non_summarized_messages)
self._messages = summarized_message

def _to_config(self) -> SummarizedChatCompletionContextConfig:
return SummarizedChatCompletionContextConfig(
summarizing_func=self._summarizing_func.dump_component(),
summarizing_condition=self._summarizing_condition.dump_component(),
initial_messages=self._initial_messages,
)

@classmethod
def _from_config(cls, config: SummarizedChatCompletionContextConfig) -> Self:
return cls(
summarizing_func=SummaryFunction.load_component(config.summarizing_func),
summarizing_condition=MessageCompletionCondition.load_component(config.summarizing_condition),
initial_messages=config.initial_messages,
non_summarized_messages=config.non_summarized_messages,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from ._base_condition import (
AndMessageCompletionCondition,
MessageCompletionCondition,
MessageCompletionException,
OrMessageCompletionCondition,
)
from ._contidions import (
ExternalMessageCompletion,
FunctionCallMessageCompletion,
MaxMessageCompletion,
SourceMatchMessageCompletion,
StopMessageCompletion,
TextMentionMessageCompletion,
TextMessageMessageCompletion,
TimeoutMessageCompletion,
TokenUsageMessageCompletion,
)
from ._summary_function import (
SummaryFunction,
)
from ._types import (
ContextMessage,
SummarizngFunction,
TriggerMessage,
)

__all__ = [
"MessageCompletionCondition",
"AndMessageCompletionCondition",
"OrMessageCompletionCondition",
"MessageCompletionException",
"ContextMessage",
"TriggerMessage",
"SummarizngFunction",
"StopMessageCompletion",
"MaxMessageCompletion",
"TextMentionMessageCompletion",
"TokenUsageMessageCompletion",
"TimeoutMessageCompletion",
"ExternalMessageCompletion",
"SourceMatchMessageCompletion",
"TextMessageMessageCompletion",
"FunctionCallMessageCompletion",
"SummaryFunction",
]

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import functools
import warnings
from textwrap import dedent
from typing import Any, List, Sequence

from pydantic import BaseModel
from typing_extensions import Self

from ..._component_config import Component
from ..._function_utils import (
get_typed_signature,
)
from ...code_executor._func_with_reqs import Import, import_to_str, to_code
from ...models import LLMMessage
from .base import BaseSummaryFunction, BaseSummaryAgent
from ._types import SummarizngFunction


class SummaryFunctionConfig(BaseModel):
"""Configuration for a summary function."""

source_code: str | None = None
agent: BaseSummaryAgent | None = None
name: str
global_imports: Sequence[Import]


class SummaryFunction(BaseSummaryFunction, Component[SummaryFunctionConfig]):
component_provider_override = "autogen_core.model_context.conditions.SummaryFunction"
component_config_schema = SummaryFunctionConfig

def __init__(
self,
func: SummarizngFunction | None = None,
agent: BaseSummaryAgent | None = None,
name: str | None = None,
global_imports: Sequence[Import] = [],
strict: bool = False,
) -> None:
self._func = func
self._agent = agent
self._global_imports = global_imports
if func is not None:
self._signature = get_typed_signature(func)
func_name = name or func.func.__name__ if isinstance(func, functools.partial) else name or func.__name__
if agent is not None:
if not isinstance(agent, BaseSummaryAgent):
raise TypeError(f"Expected a BaseChatAgent but got {type(agent)}")
func_name = name or agent.name
if func is None and agent is None:
raise ValueError("Either a function or an agent must be provided.")
if func is not None and agent is not None:
raise ValueError("Only one of a function or an agent can be provided.")
super().__init__(func_name)

async def run(self, messages: List[LLMMessage], non_summary_messages: List[LLMMessage]) -> List[LLMMessage]:
if self._func in not None:
result = self._func(messages, non_summary_messages)
if self._agent is not None:
result = await self._agent.run(task=messages, original_task=non_summary_messages)
return result

def _to_config(self) -> SummaryFunctionConfig:
if self._func is None:
return SummaryFunctionConfig(
source_code=dedent(to_code(self._func)),
global_imports=self._global_imports,
name=self.name,
)
if self._agent is not None:
return SummaryFunctionConfig(
agent=self._agent,
global_imports=self._global_imports,
name=self.name,
)

@classmethod
def _from_config(cls, config: SummaryFunctionConfig) -> Self:
exec_globals: dict[str, Any] = {}

# Execute imports first
for import_stmt in config.global_imports:
import_code = import_to_str(import_stmt)
try:
exec(import_code, exec_globals)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
f"Failed to import {import_code}: Module not found. Please ensure the module is installed."
) from e
except ImportError as e:
raise ImportError(f"Failed to import {import_code}: {str(e)}") from e
except Exception as e:
raise RuntimeError(f"Unexpected error while importing {import_code}: {str(e)}") from e

if config.source_code is not None:
warnings.warn(
"\n⚠️ SECURITY WARNING ⚠️\n"
"Loading a FunctionTool from config will execute code to import the provided global imports and and function code.\n"
"Only load configs from TRUSTED sources to prevent arbitrary code execution.",
UserWarning,
stacklevel=2,
)

# Execute function code
try:
exec(config.source_code, exec_globals)
func_name = config.source_code.split("def ")[1].split("(")[0]
except Exception as e:
raise ValueError(f"Could not compile and load function: {e}") from e

# Get function and verify it's callable
func: SummarizngFunction = exec_globals[func_name]
if not callable(func):
raise TypeError(f"Expected function but got {type(func)}")

return cls(func=func, name=config.name, global_imports=config.global_imports)
if config.agent is not None:
return cls(agent=config.agent, name=config.name, global_imports=config.global_imports)
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Callable, List, Literal, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated

from ...models import (
AssistantMessage,
FunctionExecutionResultMessage,
LLMMessage,
SystemMessage,
UserMessage,
)


class TriggerMessage(BaseModel):
"""A message requesting trigger of a completion context."""

content: str
source: str
type: Literal["TriggerMessage"] = "TriggerMessage"


BaseContextMessage = Union[UserMessage, AssistantMessage]
BaseContextMessageTypes = (UserMessage, AssistantMessage)

LLMMessageInstance = (SystemMessage, UserMessage, AssistantMessage, FunctionExecutionResultMessage)

ContextMessage = Annotated[Union[LLMMessage, TriggerMessage], Field(discriminator="type")]

SummarizngFunction = Callable[[List[LLMMessage], List[LLMMessage]], List[LLMMessage]]
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from ._base_condition import (
MessageCompletionException,
MessageCompletionCondition,
AndMessageCompletionCondition,
OrMessageCompletionCondition,

)
from ._base_summary_function import BaseSummaryFunction
from ._base_summary_agent import BaseSummaryAgent

__all__ = [
"MessageCompletionException",
"MessageCompletionCondition",
"AndMessageCompletionCondition",
"OrMessageCompletionCondition",
"BaseSummaryFunction",
"BaseSummaryAgent"
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import asyncio
from abc import ABC, abstractmethod
from typing import List, Sequence

from pydantic import BaseModel
from typing_extensions import Self

from autogen_core import Component, ComponentBase, ComponentModel

from .._types import ContextMessage, TriggerMessage


class MessageCompletionException(BaseException): ...


class MessageCompletionCondition(ABC, ComponentBase[BaseModel]):
"""A stateful condition that determines when a message completion should be triggered.
A message completion condition is a callable that takes a sequence of ContextMessage objects
since the last time the condition was called, and returns a TriggerMessage if the
conversation should be terminated, or None otherwise.
Once a message completion condition has been reached, it must be reset before it can be used again.
Message completion conditions can be combined using the AND and OR operators.
Example:
.. code-block:: python
import asyncio
from autogen_core.model_context.conditions import (
MaxMessageCompletion,
TextMentionMessageCompletion,
)
async def main() -> None:
# Terminate the conversation after 10 turns or if the text "SUMMARY" is mentioned.
cond1 = MaxMessageCompletion(10) | TextMentionMessageCompletion("SUMMARY")
# Terminate the conversation after 10 turns and if the text "SUMMARY" is mentioned.
cond2 = MaxMessageCompletion(10) & TextMentionMessageCompletion("SUMMARY")
# ...
# Reset the message completion condition.
await cond1.reset()
await cond2.reset()
asyncio.run(main())
"""

component_type = "message_completion_condition"

@property
@abstractmethod
def triggered(self) -> bool:
"""Check if the trigger condition has been reached"""
...

@abstractmethod
async def __call__(self, messages: Sequence[ContextMessage]) -> TriggerMessage | None:
"""Check if the message completion should be triggered based on the messages received since the last call.
Args:
messages (Sequence[ContextMessage]): The messages received since the last call.
Returns:
TriggerMessage | None: The trigger message if the condition is met, or None if not.
Raises:
MessageCompletionException: If the message completion condition has already been reached."""
...

@abstractmethod
async def reset(self) -> None:
"""Reset the model completion condition."""
...

def __and__(self, other: "MessageCompletionCondition") -> "MessageCompletionCondition":
"""Combine two trigger conditions with an AND operation."""
return AndMessageCompletionCondition(self, other)

def __or__(self, other: "MessageCompletionCondition") -> "MessageCompletionCondition":
"""Combine two trigger conditions with an OR operation."""
return OrMessageCompletionCondition(self, other)


class AndMessageCompletionConditionConfig(BaseModel):
conditions: List[ComponentModel]


class AndMessageCompletionCondition(MessageCompletionCondition, Component[AndMessageCompletionConditionConfig]):
component_config_schema = AndMessageCompletionConditionConfig
component_type = "trigger"
component_provider_override = "autogen_core.model_context.conditions.AndMessageCompletionCondition"

def __init__(self, *conditions: MessageCompletionCondition) -> None:
self._conditions = conditions
self._trigger_messages: List[TriggerMessage] = []

@property
def triggered(self) -> bool:
return all(condition.triggered for condition in self._conditions)

async def __call__(self, messages: Sequence[ContextMessage]) -> TriggerMessage | None:
if self.triggered:
raise MessageCompletionException("Message completion condition has already been reached.")
# Check all remaining conditions.
trigger_messages = await asyncio.gather(
*[condition(messages) for condition in self._conditions if not condition.triggered]
)
# Collect stop messages.
for trigger_message in trigger_messages:
if trigger_message is not None:
self._trigger_messages.append(trigger_message)
if any(trigger_message is None for trigger_message in trigger_messages):
# If any remaining condition has not reached termination, it is not terminated.
return None
content = ", ".join(trigger_message.content for trigger_message in self._trigger_messages)
source = ", ".join(trigger_message.source for trigger_message in self._trigger_messages)
return TriggerMessage(content=content, source=source)

async def reset(self) -> None:
for condition in self._conditions:
await condition.reset()
self._trigger_messages.clear()

def _to_config(self) -> AndMessageCompletionConditionConfig:
"""Convert the AND trigger condition to a config."""
return AndMessageCompletionConditionConfig(
conditions=[condition.dump_component() for condition in self._conditions]
)

@classmethod
def _from_config(cls, config: AndMessageCompletionConditionConfig) -> Self:
"""Create an AND trigger condition from a config."""
conditions = [
MessageCompletionCondition.load_component(condition_model) for condition_model in config.conditions
]
return cls(*conditions)


class OrMessageCompletionConditionConfig(BaseModel):
conditions: List[ComponentModel]
"""List of termination conditions where any one being satisfied is sufficient."""


class OrMessageCompletionCondition(MessageCompletionCondition, Component[OrMessageCompletionConditionConfig]):
component_config_schema = OrMessageCompletionConditionConfig
component_type = "trigger"
component_provider_override = "autogen_core.model_context.conditions.OrTerminationCondition"

def __init__(self, *conditions: MessageCompletionCondition) -> None:
self._conditions = conditions

@property
def triggered(self) -> bool:
return any(condition.triggered for condition in self._conditions)

async def __call__(self, messages: Sequence[ContextMessage]) -> TriggerMessage | None:
if self.triggered:
raise RuntimeError("Message completion condition has already been reached")
trigger_messages = await asyncio.gather(*[condition(messages) for condition in self._conditions])
trigger_messages_filter = [
trigger_message for trigger_message in trigger_messages if trigger_message is not None
]
if len(trigger_messages_filter) > 0:
content = ", ".join(trigger_message.content for trigger_message in trigger_messages_filter)
source = ", ".join(trigger_message.source for trigger_message in trigger_messages_filter)
return TriggerMessage(content=content, source=source)
return None

async def reset(self) -> None:
for condition in self._conditions:
await condition.reset()

def _to_config(self) -> OrMessageCompletionConditionConfig:
"""Convert the OR trigger condition to a config."""
return OrMessageCompletionConditionConfig(
conditions=[condition.dump_component() for condition in self._conditions]
)

@classmethod
def _from_config(cls, config: OrMessageCompletionConditionConfig) -> Self:
"""Create an OR trigger condition from a config."""
conditions = [
MessageCompletionCondition.load_component(condition_model) for condition_model in config.conditions
]
return cls(*conditions)
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from abc import ABC, abstractmethod
from typing import List, Sequence
from pydantic import BaseModel

from autogen_core.models import ChatCompletionClient
from autogen_core import ComponentBase, CancellationToken
from ....models import LLMMessage, SystemMessage
from autogen_agentchat.utils import remove_images

class BaseSummaryAgent(ABC, ComponentBase[BaseModel]):
"""
Base class for summary agents.
"""

component_type = "summary_agent"

def __init__(
self,
name: str,
model_client: ChatCompletionClient,
cancellation_token: CancellationToken | None = None,
*,
system_message: str,
):
self._name = name
self._model_client = model_client
self._system_message = [SystemMessage(content=system_message)]
self._cancellation_token = cancellation_token


@property
def name(self) -> str:
"""The name of the agent."""
return self._name

@abstractmethod
def run(
self,
task: List[LLMMessage] | None = None,
original_task: List[LLMMessage] | None = None,
) -> List[LLMMessage]:
"""
Run the summary agent.
Args:
task: The task to run.
original_task: The original task to run.
Returns:
The result of the run.
"""
...

@staticmethod
def _get_compatible_context(model_client: ChatCompletionClient, messages: List[LLMMessage]) -> Sequence[LLMMessage]:
"""Ensure that the messages are compatible with the underlying client, by removing images if needed."""
if model_client.model_info["vision"]:
return messages
else:
return remove_images(messages)
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import logging
from abc import ABC, abstractmethod
from typing import Any, List, Mapping

from pydantic import BaseModel

from .... import EVENT_LOGGER_NAME
from ...._component_config import ComponentBase
from ....models import LLMMessage

logger = logging.getLogger(EVENT_LOGGER_NAME)


class BaseSummaryFunction(ABC, ComponentBase[BaseModel]):
component_type = "summary_function"

def __init__(
self,
name: str,
) -> None:
self._name = name

@property
def name(self) -> str:
return self._name

@abstractmethod
def run(self, messages: List[LLMMessage], non_summary_messages: List[LLMMessage]) -> List[LLMMessage]: ...

def save_state_json(self) -> Mapping[str, Any]:
return {}

def load_state_json(self, state: Mapping[str, Any]) -> None:
pass
459 changes: 459 additions & 0 deletions python/packages/autogen-core/tests/test_summary_conditions.py

Large diffs are not rendered by default.

129 changes: 129 additions & 0 deletions python/packages/autogen-core/tests/test_summary_model_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import functools
from typing import Any, Generator, List, cast
from unittest.mock import AsyncMock, Mock

import pytest
from autogen_core.code_executor import ImportFromModule
from autogen_core.model_context import SummarizedChatCompletionContext
from autogen_core.model_context.conditions import ExternalMessageCompletion, SummaryFunction
from autogen_core.models import AssistantMessage, LLMMessage, UserMessage


def mock_summarizing_func(messages: List[LLMMessage], non_summarized_messages: List[LLMMessage]) -> List[LLMMessage]:
return [UserMessage(source="user", content="summarized")]


@pytest.fixture
def mock_condition() -> Generator[Any, None, None]:
mock = AsyncMock(spec=ExternalMessageCompletion)
mock.triggered = False # type: ignore
mock.reset = AsyncMock() # type: ignore
yield mock


@pytest.mark.asyncio
async def test_initialize_with_none_initial_messages() -> None:
context = SummarizedChatCompletionContext(
summarizing_func=mock_summarizing_func, summarizing_condition=ExternalMessageCompletion(), initial_messages=None
)
messages = await context.get_messages()
assert messages == []


@pytest.mark.asyncio
async def test_initialize_with_initial_messages() -> None:
initial_msgs = cast(
List[LLMMessage],
[
UserMessage(source="user", content="test1"),
AssistantMessage(source="assistant", content="test2"),
],
)
context = SummarizedChatCompletionContext(
summarizing_func=mock_summarizing_func,
summarizing_condition=ExternalMessageCompletion(),
initial_messages=initial_msgs,
)
messages = await context.get_messages()
assert messages == initial_msgs


@pytest.mark.asyncio
async def test_add_message(mock_condition: Any) -> None:
context = SummarizedChatCompletionContext(
summarizing_func=mock_summarizing_func, summarizing_condition=mock_condition
)

message = UserMessage(source="user", content="test")
await context.add_message(message)

messages = await context.get_messages()
assert message in messages
mock_condition.assert_called_once()


@pytest.mark.asyncio
async def test_summary_called_when_condition_triggered(mock_condition: Any) -> None:
mock_condition.triggered = True
context = SummarizedChatCompletionContext(
summarizing_func=mock_summarizing_func, summarizing_condition=mock_condition
)

await context.add_message(UserMessage(source="user", content="test"))
messages = await context.get_messages()
mock_condition.reset.assert_called_once()
assert len(messages) == 1
assert messages == [UserMessage(source="user", content="summarized")]


## test summary function
def test_summary_function_init() -> None:
def sample_func(messages: List[LLMMessage], non_summary_messages: List[LLMMessage]) -> List[LLMMessage]:
return messages

# Test with basic function
sf = SummaryFunction(sample_func)
assert sf.name == "sample_func"

# Test with custom name
sf = SummaryFunction(sample_func, name="custom_name")
assert sf.name == "custom_name"

# Test with partial function
partial_func = functools.partial(sample_func)
sf = SummaryFunction(partial_func)
assert sf.name == "sample_func"


def test_summary_function_run() -> None:
mock_messages = cast(List[LLMMessage], [Mock(spec=LLMMessage)])
mock_non_summary = cast(List[LLMMessage], [Mock(spec=LLMMessage)])

def sample_func(messages: List[LLMMessage], non_summary_messages: List[LLMMessage]) -> List[LLMMessage]:
return messages

sf = SummaryFunction(sample_func)
result = sf.run(mock_messages, mock_non_summary)
assert result == mock_messages


def test_summary_function_dump_and_load() -> None:
def sample_func(messages: List[LLMMessage], non_summary_messages: List[LLMMessage]) -> List[LLMMessage]:
return messages

import_list = ImportFromModule("typing", ["List"])
import_llmmessage = ImportFromModule("autogen_core.models", ["LLMMessage"])

sf = SummaryFunction(sample_func, global_imports=[import_list, import_llmmessage])
config = sf.dump_component()

assert config.config["name"] == "sample_func"
assert (
config.config["source_code"]
== "def sample_func(messages: List[LLMMessage], non_summary_messages: List[LLMMessage]) -> List[LLMMessage]:\n return messages\n"
)

loaded_sf = SummaryFunction.load_component(config)
assert loaded_sf.name == "sample_func"
assert loaded_sf._func.__code__.co_code == sf._func.__code__.co_code # pyright: ignore[reportPrivateUsage]
assert loaded_sf._signature == sf._signature # pyright: ignore[reportPrivateUsage]
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .buffered_summary import buffered_summarized_chat_completion_context, buffered_summary

__all__ = [
"buffered_summary",
"buffered_summarized_chat_completion_context",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import List

from autogen_core.model_context import (
SummarizedChatCompletionContext,
SummarizngFunction,
)
from autogen_core.model_context.conditions import MaxMessageCompletion
from autogen_core.models import LLMMessage


def buffered_summary(buffer_count: int) -> SummarizngFunction:
"""Create a buffered summary function.
This function summarizes the last `buffer_count` messages.
It is used to create a buffered summarized chat completion context.
Args:
buffer_count (int): The size of the buffer.
Returns:
SummarizngFunction: The buffered summary function.
"""

def _buffered_summary(
messages: List[LLMMessage],
non_summarized_messages: List[LLMMessage],
) -> List[LLMMessage]:
"""Summarize the last `buffer_count` messages."""
if len(messages) > buffer_count:
return messages[-buffer_count:]
return messages

return _buffered_summary


def buffered_summarized_chat_completion_context(
buffer_count: int,
max_messages: int | None = None,
initial_messages: List[LLMMessage] | None = None,
) -> SummarizedChatCompletionContext:
"""Build a buffered summarized chat completion context.
Args:
buffer_count (int): The size of the buffer.
trigger_count (int | None): The size of the trigger. When is None, the trigger count is set to the buffer count.
initial_messages (List[LLMMessage] | None): The initial messages.
Returns:
SummarizedChatCompletionContext: The buffered summarized chat completion context.
"""

if max_messages is None:
max_messages = buffer_count

return SummarizedChatCompletionContext(
summarizing_func=buffered_summary(buffer_count),
summarizing_condition=MaxMessageCompletion(
max_messages=max_messages,
),
initial_messages=initial_messages,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from typing import Dict, List, Optional, Union

from autogen_core.model_context.conditions.base import BaseSummaryAgent
from autogen_core.models import ChatCompletionClient, LLMMessage, AssistantMessage
from autogen_core import CancellationToken, ComponentModel

from pydantic import BaseModel


class BufferedSummaryAgentConfig(BaseModel):
"""The declarative configuration for the buffered summary agent agent."""

name: str
model_client: ComponentModel
system_message: str | None = None
cancellation_token: CancellationToken | None = None
model_client_stream: bool = False

class BufferedSummaryAgent(BaseSummaryAgent):
component_config_schema = BufferedSummaryAgentConfig
component_provider_override = "autogen_ext.summary.BufferedSummaryAgent"
"""A buffered summary agent that summarizes the messages in the context
using a LLM.
"""

def __init__(
self,
name: str,
model_client: ChatCompletionClient,
summary_start: int = 0,
summary_end: int = 0,
*,
system_message: str | None = None,
cancellation_token: CancellationToken | None = None
) -> None:
if system_message is None:
summary_prompt="Summarize the conversation so far for your own memory",
summary_format="This portion of conversation has been summarized as follow: {summary}",
system_message = f"{summary_prompt}\n{summary_format}"
super().__init__(
name=name,
model_client=model_client,
system_message=system_message,
cancellation_token=cancellation_token,
)

self._summary_start = summary_start
self._summary_end = summary_end

async def run(
self,
task: List[LLMMessage] | None = None,
original_task: List[LLMMessage] | None = None,
) -> List[LLMMessage]:
"""Run the summary agent."""
if task is None:
task = []
if self._summary_start > 0 and self._summary_end < 0:
task = task[self._summary_start:self._summary_end]
elif self._summary_start > 0:
task = task[self._summary_start:]
elif self._summary_end < 0:
task = task[:self._summary_end]

task = self._system_message + task
task = BaseSummaryAgent._get_compatible_context(
self._model_client, task
)

result = await self._model_client.create(
messages=task,
cancellation_token=self._cancellation_token,
)

return [AssistantMessage(content=result.content, source="summary")]